@jax-js/jax 0.0.3 → 0.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +96 -22
- package/dist/{backend-BqDtPGaR.js → backend-CdcTZEOF.js} +325 -153
- package/dist/{backend-D2C4MJRP.cjs → backend-yEU0L_ig.cjs} +350 -154
- package/dist/index.cjs +977 -354
- package/dist/index.d.cts +479 -88
- package/dist/index.d.ts +479 -88
- package/dist/index.js +964 -345
- package/dist/{webgpu-CNg9JGva.js → webgpu-CM-xNYzW.js} +9 -3
- package/dist/{webgpu-fqhx41TC.cjs → webgpu-CNOpiO5T.cjs} +9 -3
- package/package.json +15 -4
|
@@ -35,7 +35,22 @@ var PPrint = class PPrint {
|
|
|
35
35
|
//#endregion
|
|
36
36
|
//#region src/utils.ts
|
|
37
37
|
/** @file Generic programming utilities with no dependencies on library code. */
|
|
38
|
-
|
|
38
|
+
let DEBUG = 0;
/**
 * Change the library's diagnostic verbosity by updating the exported `DEBUG`
 * binding.
 *
 * Documented levels:
 *   1. JIT compile logs
 *   2. Shader code
 *   3. Expressions and metadata
 *   4. JIT programs, tuning details
 *   5. Most verbose operation traces
 *
 * Experimental: the meaning of each level may change; do not rely on it in
 * production.
 *
 * @param {number} level - New debug level.
 */
function setDebug(level) {
  // Reassigns the module-level `let`; ES-module importers of `DEBUG` see the
  // new value through the live binding.
  DEBUG = level;
}
|
|
39
54
|
function unzip2(pairs) {
|
|
40
55
|
const lst1 = [];
|
|
41
56
|
const lst2 = [];
|
|
@@ -109,9 +124,23 @@ function isNumberPair(x) {
|
|
|
109
124
|
}
|
|
110
125
|
/**
 * Validate `axis` against an array rank and resolve it to a non-negative index.
 *
 * Negative axes count from the end, so `-1` refers to the last axis.
 *
 * @param {number} axis - Axis to validate; accepted range is [-ndim, ndim).
 * @param {number} ndim - Number of dimensions of the array.
 * @returns {number} The equivalent axis in [0, ndim).
 * @throws {Error} When `axis` falls outside [-ndim, ndim).
 */
function checkAxis(axis, ndim) {
  const outOfBounds = axis < -ndim || axis >= ndim;
  if (outOfBounds) {
    throw new Error(`Axis ${axis} out of bounds for array of dimension ${ndim}`);
  }
  return axis >= 0 ? axis : axis + ndim;
}
|
|
130
|
+
/**
 * Normalize a function's `axis` argument into a sorted list of distinct,
 * non-negative axes.
 *
 * - `null` or `undefined`: every axis, `range(ndim)`.
 * - a single number: bounds-checked (negatives resolved) and wrapped in a list.
 * - an iterable of numbers: each one bounds-checked and resolved; duplicates
 *   after resolution (e.g. `0` and `-ndim`) are rejected.
 *
 * @param {number | Iterable<number> | null | undefined} axis - Requested axes.
 * @param {number} ndim - Number of dimensions of the array.
 * @returns {number[]} Axes in ascending numeric order.
 * @throws {Error} On an out-of-bounds axis (from `checkAxis`) or a duplicate axis.
 */
function normalizeAxis(axis, ndim) {
  // `== null` also accepts `undefined`, so callers can omit the argument to
  // mean "all axes" (previously `undefined` crashed in the for-of below).
  if (axis == null) return range(ndim);
  if (typeof axis === "number") return [checkAxis(axis, ndim)];
  const seen = new Set();
  for (const a of axis) {
    const ca = checkAxis(a, ndim);
    if (seen.has(ca)) throw new Error(`Duplicate axis ${ca} passed to function`);
    seen.add(ca);
  }
  // A numeric comparator is required: the default sort is lexicographic and
  // would order [10, 2] as [10, 2] for arrays with more than 10 dimensions.
  return [...seen].sort((x, y) => x - y);
}
|
|
115
144
|
function range(start, stop, step = 1) {
|
|
116
145
|
if (stop === void 0) {
|
|
117
146
|
stop = start;
|
|
@@ -177,6 +206,35 @@ function findPow2(hint, max) {
|
|
|
177
206
|
while (ret < hint && 2 * ret <= max) ret *= 2;
|
|
178
207
|
return ret;
|
|
179
208
|
}
|
|
209
|
+
/**
 * Compute the NumPy-style broadcast of two array shapes.
 *
 * Shapes are aligned at their trailing (rightmost) dimensions. A dimension
 * missing from the shorter shape behaves like `1`. Two aligned dimensions are
 * compatible when they are equal or when one of them is `1`, and the result
 * takes the other one.
 *
 * <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
 *
 * @param {number[]} a - First shape.
 * @param {number[]} b - Second shape.
 * @returns {number[]} The broadcast shape, with length `max(a.length, b.length)`.
 * @throws {TypeError} If some aligned pair of dimensions is incompatible.
 */
function generalBroadcast(a, b) {
  const ndim = Math.max(a.length, b.length);
  const result = new Array(ndim);
  // `k` counts dimensions from the right: k = 1 is the last axis of both shapes.
  for (let k = 1; k <= ndim; k++) {
    const da = k <= a.length ? a[a.length - k] : 1;
    const db = k <= b.length ? b[b.length - k] : 1;
    if (da === db || db === 1) {
      result[ndim - k] = da;
    } else if (da === 1) {
      result[ndim - k] = db;
    } else {
      throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
    }
  }
  return result;
}
|
|
180
238
|
function recursiveFlatten(ar) {
|
|
181
239
|
if (!Array.isArray(ar)) return [ar];
|
|
182
240
|
return ar.flat(Infinity);
|
|
@@ -186,6 +244,7 @@ function strip1(str) {
|
|
|
186
244
|
if (str[0] === "(" && str[str.length - 1] === ")") return str.slice(1, -1);
|
|
187
245
|
return str;
|
|
188
246
|
}
|
|
247
|
+
/**
 * Shared 8-byte scratch view, used by `FpHash#update` to reinterpret a
 * non-integer number as the raw bits of its float64 encoding.
 */
const _stagingbuf = /* @__PURE__ */ new DataView(new Float64Array(1).buffer);
|
|
189
248
|
/**
|
|
190
249
|
* Polynomial hashes modulo p are good at avoiding collisions in expectation.
|
|
191
250
|
* Probability-wise, it's good enough to be used for something like
|
|
@@ -200,22 +259,26 @@ var FpHash = class FpHash {
|
|
|
200
259
|
const modulus = 3189051996290219n;
|
|
201
260
|
this.value = (this.value * base + x) % modulus;
|
|
202
261
|
}
|
|
203
|
-
update(
|
|
204
|
-
|
|
205
|
-
|
|
262
|
+
update(x) {
|
|
263
|
+
if (typeof x === "string") {
|
|
264
|
+
this.#update(BigInt(x.length));
|
|
265
|
+
for (let i = 0; i < x.length; i++) this.#update(BigInt(199 + x.charCodeAt(i)));
|
|
266
|
+
} else if (typeof x === "number") if (Number.isInteger(x)) this.#update(68265653n ^ BigInt(x));
|
|
206
267
|
else {
|
|
207
|
-
|
|
208
|
-
this.#update(
|
|
268
|
+
_stagingbuf.setFloat64(0, x, true);
|
|
269
|
+
this.#update(_stagingbuf.getBigUint64(0, true));
|
|
209
270
|
}
|
|
210
271
|
else if (typeof x === "boolean") this.#update(x ? 69069841n : 63640693n);
|
|
211
272
|
else if (typeof x === "bigint") this.#update(x ^ 71657401n);
|
|
212
273
|
else if (x === null) this.#update(37832657n);
|
|
213
274
|
else if (x === void 0) this.#update(18145117n);
|
|
214
|
-
else
|
|
275
|
+
else x.hash(this);
|
|
215
276
|
return this;
|
|
216
277
|
}
|
|
217
278
|
static hash(...values) {
|
|
218
|
-
|
|
279
|
+
const h = new FpHash();
|
|
280
|
+
for (const x of values) h.update(x);
|
|
281
|
+
return h.value;
|
|
219
282
|
}
|
|
220
283
|
};
|
|
221
284
|
/** Run a function while caching it inline inside a `Map`. */
|
|
@@ -250,6 +313,41 @@ const byteWidth = (dtype) => {
|
|
|
250
313
|
}
|
|
251
314
|
};
|
|
252
315
|
/** True when `dtype` is one of the floating-point dtypes (float32 or float16). */
const isFloatDtype = (dtype) => {
  switch (dtype) {
    case DType.Float32:
    case DType.Float16:
      return true;
    default:
      return false;
  }
};
|
|
316
|
+
/**
 * Return the common dtype of `dtype1` and `dtype2` under this library's
 * promotion lattice.
 *
 * The lattice is a single chain; the result is whichever input sits higher:
 *
 * ```text
 * bool -> uint32 -> int32 -> float16 -> float32
 * ```
 *
 * Examples:
 * - `promoteTypes(bool, int32)` → `int32`
 * - `promoteTypes(uint32, int32)` → `int32`
 * - `promoteTypes(int32, float16)` → `float16`
 * - `promoteTypes(float16, float32)` → `float32`
 * - `promoteTypes(uint32, float32)` → `float32`
 *
 * Identical inputs are returned unchanged. If either dtype is not on the
 * chain, the comparison is false and `dtype2` is returned.
 *
 * Weakly typed values (JS numbers, which default to float32 but should adopt
 * the dtype of the array they meet) are not modeled by this function. Callers
 * presumably resolve them before calling; TODO confirm.
 */
function promoteTypes(dtype1, dtype2) {
  if (dtype1 === dtype2) return dtype1;
  const chain = [DType.Bool, DType.Uint32, DType.Int32, DType.Float16, DType.Float32];
  const position = Object.fromEntries(chain.map((dt, idx) => [dt, idx]));
  return position[dtype2] < position[dtype1] ? dtype1 : dtype2;
}
|
|
253
351
|
function dtypedArray(dtype, data) {
|
|
254
352
|
const { buffer, byteLength, byteOffset } = data;
|
|
255
353
|
const length = byteLength / byteWidth(dtype);
|
|
@@ -319,6 +417,12 @@ var AluExp = class AluExp {
|
|
|
319
417
|
static cos(a) {
|
|
320
418
|
return new AluExp(AluOp.Cos, a.dtype, [a]);
|
|
321
419
|
}
|
|
420
|
+
static asin(a) {
|
|
421
|
+
return new AluExp(AluOp.Asin, a.dtype, [a]);
|
|
422
|
+
}
|
|
423
|
+
static atan(a) {
|
|
424
|
+
return new AluExp(AluOp.Atan, a.dtype, [a]);
|
|
425
|
+
}
|
|
322
426
|
static exp(a) {
|
|
323
427
|
return new AluExp(AluOp.Exp, a.dtype, [a]);
|
|
324
428
|
}
|
|
@@ -402,8 +506,11 @@ var AluExp = class AluExp {
|
|
|
402
506
|
getHash() {
|
|
403
507
|
if (this.#hash !== void 0) return this.#hash;
|
|
404
508
|
const hasher = new FpHash();
|
|
405
|
-
hasher.update(this.op
|
|
406
|
-
hasher.update(this.
|
|
509
|
+
hasher.update(this.op);
|
|
510
|
+
hasher.update(this.dtype);
|
|
511
|
+
hasher.update(JSON.stringify(this.arg));
|
|
512
|
+
hasher.update(this.src.length);
|
|
513
|
+
for (const s of this.src) hasher.update(s);
|
|
407
514
|
this.#hash = hasher.value;
|
|
408
515
|
return this.#hash;
|
|
409
516
|
}
|
|
@@ -475,10 +582,16 @@ var AluExp = class AluExp {
|
|
|
475
582
|
ret = [Math.max(src[0].min, src[1].min), Math.max(src[0].max, src[1].max)];
|
|
476
583
|
break;
|
|
477
584
|
case AluOp.Sin:
|
|
478
|
-
ret = [
|
|
585
|
+
ret = [-1, 1];
|
|
479
586
|
break;
|
|
480
587
|
case AluOp.Cos:
|
|
481
|
-
ret = [
|
|
588
|
+
ret = [-1, 1];
|
|
589
|
+
break;
|
|
590
|
+
case AluOp.Asin:
|
|
591
|
+
ret = [-Math.PI / 2, Math.PI / 2];
|
|
592
|
+
break;
|
|
593
|
+
case AluOp.Atan:
|
|
594
|
+
ret = [-Math.PI / 2, Math.PI / 2];
|
|
482
595
|
break;
|
|
483
596
|
case AluOp.Exp:
|
|
484
597
|
ret = [Math.exp(src[0].min), Math.exp(src[0].max)];
|
|
@@ -594,11 +707,12 @@ var AluExp = class AluExp {
|
|
|
594
707
|
simplify(cache = /* @__PURE__ */ new Map()) {
|
|
595
708
|
if (this.#simplified !== void 0) return this.#simplified;
|
|
596
709
|
const hash = this.getHash();
|
|
597
|
-
|
|
710
|
+
const prevCachedValue = cache.get(hash);
|
|
711
|
+
if (prevCachedValue !== void 0) return this.#simplified = prevCachedValue;
|
|
598
712
|
const simplified = this.#simplifyInner(cache);
|
|
599
713
|
const simplifiedHash = simplified.getHash();
|
|
600
|
-
|
|
601
|
-
|
|
714
|
+
const prevSimplified = cache.get(simplifiedHash);
|
|
715
|
+
if (prevSimplified !== void 0) {
|
|
602
716
|
cache.set(hash, prevSimplified);
|
|
603
717
|
this.#simplified = prevSimplified;
|
|
604
718
|
return prevSimplified;
|
|
@@ -802,6 +916,8 @@ var AluExp = class AluExp {
|
|
|
802
916
|
switch (this.op) {
|
|
803
917
|
case AluOp.Sin: return Math.sin(x);
|
|
804
918
|
case AluOp.Cos: return Math.cos(x);
|
|
919
|
+
case AluOp.Asin: return Math.asin(x);
|
|
920
|
+
case AluOp.Atan: return Math.atan(x);
|
|
805
921
|
case AluOp.Exp: return Math.exp(x);
|
|
806
922
|
case AluOp.Log: return Math.log(x);
|
|
807
923
|
case AluOp.Sqrt: return Math.sqrt(x);
|
|
@@ -981,6 +1097,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
981
1097
|
AluOp$1["Max"] = "Max";
|
|
982
1098
|
AluOp$1["Sin"] = "Sin";
|
|
983
1099
|
AluOp$1["Cos"] = "Cos";
|
|
1100
|
+
AluOp$1["Asin"] = "Asin";
|
|
1101
|
+
AluOp$1["Atan"] = "Atan";
|
|
984
1102
|
AluOp$1["Exp"] = "Exp";
|
|
985
1103
|
AluOp$1["Log"] = "Log";
|
|
986
1104
|
AluOp$1["Sqrt"] = "Sqrt";
|
|
@@ -1011,6 +1129,8 @@ const AluGroup = {
|
|
|
1011
1129
|
Unary: new Set([
|
|
1012
1130
|
AluOp.Sin,
|
|
1013
1131
|
AluOp.Cos,
|
|
1132
|
+
AluOp.Asin,
|
|
1133
|
+
AluOp.Atan,
|
|
1014
1134
|
AluOp.Exp,
|
|
1015
1135
|
AluOp.Log,
|
|
1016
1136
|
AluOp.Sqrt,
|
|
@@ -1034,6 +1154,8 @@ const AluGroup = {
|
|
|
1034
1154
|
RequiredFloat: new Set([
|
|
1035
1155
|
AluOp.Sin,
|
|
1036
1156
|
AluOp.Cos,
|
|
1157
|
+
AluOp.Asin,
|
|
1158
|
+
AluOp.Atan,
|
|
1037
1159
|
AluOp.Exp,
|
|
1038
1160
|
AluOp.Log,
|
|
1039
1161
|
AluOp.Sqrt,
|
|
@@ -1065,7 +1187,7 @@ var Kernel = class {
|
|
|
1065
1187
|
this.exp = exp.simplify();
|
|
1066
1188
|
}
|
|
1067
1189
|
hash(state) {
|
|
1068
|
-
state.update(this.nargs
|
|
1190
|
+
state.update(this.nargs).update(this.size).update(this.exp).update(this.reduction);
|
|
1069
1191
|
}
|
|
1070
1192
|
pprint() {
|
|
1071
1193
|
let details = PPrint.pp(`exp = ${this.exp}`);
|
|
@@ -1111,7 +1233,7 @@ var Reduction = class {
|
|
|
1111
1233
|
this.epilogue = epilogue.simplify();
|
|
1112
1234
|
}
|
|
1113
1235
|
hash(state) {
|
|
1114
|
-
state.update(this.dtype
|
|
1236
|
+
state.update(this.dtype).update(this.op).update(this.size).update(this.epilogue);
|
|
1115
1237
|
}
|
|
1116
1238
|
toString() {
|
|
1117
1239
|
return `${this.op}{${this.size}} -> ${this.epilogue}`;
|
|
@@ -2283,78 +2405,92 @@ function wasm_log(cg) {
|
|
|
2283
2405
|
});
|
|
2284
2406
|
}
|
|
2285
2407
|
/**
 * Common helper to approximate sin(x) and cos(x) inside a generated WASM
 * function.
 *
 * Reads local 0 (the enclosing function's f32 parameter x) and declares its own
 * scratch locals. It leaves the operand stack as it found it. It returns the
 * indices of the locals the caller combines by quadrant:
 *   q  (i32) - quadrant index of the reduced argument
 *   sz (f32) - approximation of sin(z)
 *   cz (f32) - approximation of cos(z)
 *
 * Method: reduce to y in [-π, π] via y = x - round(x/(2π))*(2π). Then take
 * the quadrant q = round(y/(π/2)) and z = y - q*(π/2), and evaluate two
 * polynomials on z:
 *   sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
 *   cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
 *
 * `i32.trunc_f32_s` traps on NaN or out-of-range input. The reduction produces
 * NaN for x = ±Infinity or NaN, and f32 rounding can leave |y| far above π for
 * very large |x|. The float quadrant is therefore replaced by 0 unless
 * |qf| <= 4. For a well-reduced y (|qf| <= 2) this is a no-op. For NaN input,
 * NaN now propagates to the result instead of aborting the whole kernel.
 */
function _sincos(cg) {
  const y = cg.local.declare(cg.f32);
  const qf = cg.local.declare(cg.f32);
  const q = cg.local.declare(cg.i32);
  const z = cg.local.declare(cg.f32);
  const z2 = cg.local.declare(cg.f32);
  const sz = cg.local.declare(cg.f32);
  const cz = cg.local.declare(cg.f32);
  // y = x - round(x / 2π) * 2π
  cg.local.get(0);
  cg.local.get(0);
  cg.f32.const(1 / (2 * Math.PI));
  cg.f32.mul();
  cg.f32.nearest();
  cg.local.tee(qf);
  cg.f32.const(2 * Math.PI);
  cg.f32.mul();
  cg.f32.sub();
  cg.local.set(y);
  // qf = round(y / (π/2))
  cg.local.get(y);
  cg.f32.const(2 / Math.PI);
  cg.f32.mul();
  cg.f32.nearest();
  cg.local.set(qf);
  // qf = (4 >= |qf|) ? qf : 0. NaN fails the comparison, so it also maps to 0,
  // which keeps the truncation below from trapping.
  cg.local.get(qf);
  cg.f32.const(0);
  cg.f32.const(4);
  cg.local.get(qf);
  cg.f32.abs();
  cg.f32.ge();
  cg.select();
  cg.local.tee(qf);
  cg.i32.trunc_f32_s();
  cg.local.set(q);
  // z = y - qf * π/2, z2 = z * z
  cg.local.get(y);
  cg.local.get(qf);
  cg.f32.const(Math.PI / 2);
  cg.f32.mul();
  cg.f32.sub();
  cg.local.tee(z);
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(z2);
  // sz = z * (((-1/5040 * z2 + 1/120) * z2 - 1/6) * z2 + 1)  (Horner form)
  cg.f32.const(-1 / 5040);
  cg.local.get(z2);
  cg.f32.mul();
  cg.f32.const(1 / 120);
  cg.f32.add();
  cg.local.get(z2);
  cg.f32.mul();
  cg.f32.const(-1 / 6);
  cg.f32.add();
  cg.local.get(z2);
  cg.f32.mul();
  cg.f32.const(1);
  cg.f32.add();
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(sz);
  // cz = ((-1/720 * z2 + 1/24) * z2 - 1/2) * z2 + 1  (Horner form)
  cg.f32.const(-1 / 720);
  cg.local.get(z2);
  cg.f32.mul();
  cg.f32.const(1 / 24);
  cg.f32.add();
  cg.local.get(z2);
  cg.f32.mul();
  cg.f32.const(-1 / 2);
  cg.f32.add();
  cg.local.get(z2);
  cg.f32.mul();
  cg.f32.const(1);
  cg.f32.add();
  cg.local.set(cz);
  return {
    q,
    sz,
    cz
  };
}
|
|
2485
|
+
/**
|
|
2486
|
+
* Approximate sin(x).
|
|
2487
|
+
*
|
|
2488
|
+
* Quadrant mapping: k=q mod 4: 0: +sz, 1: +cz, 2: -sz, 3: -cz
|
|
2291
2489
|
*/
|
|
2292
2490
|
function wasm_sin(cg) {
|
|
2293
2491
|
return cg.function([cg.f32], [cg.f32], () => {
|
|
2294
|
-
const
|
|
2295
|
-
const qf = cg.local.declare(cg.f32);
|
|
2296
|
-
const q = cg.local.declare(cg.i32);
|
|
2297
|
-
const z = cg.local.declare(cg.f32);
|
|
2298
|
-
const z2 = cg.local.declare(cg.f32);
|
|
2299
|
-
const sz = cg.local.declare(cg.f32);
|
|
2300
|
-
const cz = cg.local.declare(cg.f32);
|
|
2492
|
+
const { q, sz, cz } = _sincos(cg);
|
|
2301
2493
|
const mag = cg.local.declare(cg.f32);
|
|
2302
|
-
cg.local.get(0);
|
|
2303
|
-
cg.local.get(0);
|
|
2304
|
-
cg.f32.const(1 / (2 * Math.PI));
|
|
2305
|
-
cg.f32.mul();
|
|
2306
|
-
cg.f32.nearest();
|
|
2307
|
-
cg.local.tee(qf);
|
|
2308
|
-
cg.f32.const(2 * Math.PI);
|
|
2309
|
-
cg.f32.mul();
|
|
2310
|
-
cg.f32.sub();
|
|
2311
|
-
cg.local.set(y);
|
|
2312
|
-
cg.local.get(y);
|
|
2313
|
-
cg.f32.const(2 / Math.PI);
|
|
2314
|
-
cg.f32.mul();
|
|
2315
|
-
cg.f32.nearest();
|
|
2316
|
-
cg.local.tee(qf);
|
|
2317
|
-
cg.i32.trunc_f32_s();
|
|
2318
|
-
cg.local.set(q);
|
|
2319
|
-
cg.local.get(y);
|
|
2320
|
-
cg.local.get(qf);
|
|
2321
|
-
cg.f32.const(Math.PI / 2);
|
|
2322
|
-
cg.f32.mul();
|
|
2323
|
-
cg.f32.sub();
|
|
2324
|
-
cg.local.tee(z);
|
|
2325
|
-
cg.local.get(z);
|
|
2326
|
-
cg.f32.mul();
|
|
2327
|
-
cg.local.set(z2);
|
|
2328
|
-
cg.f32.const(-1 / 5040);
|
|
2329
|
-
cg.local.get(z2);
|
|
2330
|
-
cg.f32.mul();
|
|
2331
|
-
cg.f32.const(1 / 120);
|
|
2332
|
-
cg.f32.add();
|
|
2333
|
-
cg.local.get(z2);
|
|
2334
|
-
cg.f32.mul();
|
|
2335
|
-
cg.f32.const(-1 / 6);
|
|
2336
|
-
cg.f32.add();
|
|
2337
|
-
cg.local.get(z2);
|
|
2338
|
-
cg.f32.mul();
|
|
2339
|
-
cg.f32.const(1);
|
|
2340
|
-
cg.f32.add();
|
|
2341
|
-
cg.local.get(z);
|
|
2342
|
-
cg.f32.mul();
|
|
2343
|
-
cg.local.set(sz);
|
|
2344
|
-
cg.f32.const(-1 / 720);
|
|
2345
|
-
cg.local.get(z2);
|
|
2346
|
-
cg.f32.mul();
|
|
2347
|
-
cg.f32.const(1 / 24);
|
|
2348
|
-
cg.f32.add();
|
|
2349
|
-
cg.local.get(z2);
|
|
2350
|
-
cg.f32.mul();
|
|
2351
|
-
cg.f32.const(-1 / 2);
|
|
2352
|
-
cg.f32.add();
|
|
2353
|
-
cg.local.get(z2);
|
|
2354
|
-
cg.f32.mul();
|
|
2355
|
-
cg.f32.const(1);
|
|
2356
|
-
cg.f32.add();
|
|
2357
|
-
cg.local.set(cz);
|
|
2358
2494
|
cg.local.get(cz);
|
|
2359
2495
|
cg.local.get(sz);
|
|
2360
2496
|
cg.local.get(q);
|
|
@@ -2373,75 +2509,12 @@ function wasm_sin(cg) {
|
|
|
2373
2509
|
/**
|
|
2374
2510
|
* Approximate cos(x).
|
|
2375
2511
|
*
|
|
2376
|
-
*
|
|
2377
|
-
* k=q mod 4: 0: +cz, 1: -sz, 2: -cz, 3: +sz
|
|
2512
|
+
* Quadrant mapping: k=q mod 4: 0: +cz, 1: -sz, 2: -cz, 3: +sz
|
|
2378
2513
|
*/
|
|
2379
2514
|
function wasm_cos(cg) {
|
|
2380
2515
|
return cg.function([cg.f32], [cg.f32], () => {
|
|
2381
|
-
const
|
|
2382
|
-
const qf = cg.local.declare(cg.f32);
|
|
2383
|
-
const q = cg.local.declare(cg.i32);
|
|
2384
|
-
const z = cg.local.declare(cg.f32);
|
|
2385
|
-
const z2 = cg.local.declare(cg.f32);
|
|
2386
|
-
const sz = cg.local.declare(cg.f32);
|
|
2387
|
-
const cz = cg.local.declare(cg.f32);
|
|
2516
|
+
const { q, sz, cz } = _sincos(cg);
|
|
2388
2517
|
const mag = cg.local.declare(cg.f32);
|
|
2389
|
-
cg.local.get(0);
|
|
2390
|
-
cg.local.get(0);
|
|
2391
|
-
cg.f32.const(1 / (2 * Math.PI));
|
|
2392
|
-
cg.f32.mul();
|
|
2393
|
-
cg.f32.nearest();
|
|
2394
|
-
cg.local.tee(qf);
|
|
2395
|
-
cg.f32.const(2 * Math.PI);
|
|
2396
|
-
cg.f32.mul();
|
|
2397
|
-
cg.f32.sub();
|
|
2398
|
-
cg.local.set(y);
|
|
2399
|
-
cg.local.get(y);
|
|
2400
|
-
cg.f32.const(2 / Math.PI);
|
|
2401
|
-
cg.f32.mul();
|
|
2402
|
-
cg.f32.nearest();
|
|
2403
|
-
cg.local.tee(qf);
|
|
2404
|
-
cg.i32.trunc_f32_s();
|
|
2405
|
-
cg.local.set(q);
|
|
2406
|
-
cg.local.get(y);
|
|
2407
|
-
cg.local.get(qf);
|
|
2408
|
-
cg.f32.const(Math.PI / 2);
|
|
2409
|
-
cg.f32.mul();
|
|
2410
|
-
cg.f32.sub();
|
|
2411
|
-
cg.local.tee(z);
|
|
2412
|
-
cg.local.get(z);
|
|
2413
|
-
cg.f32.mul();
|
|
2414
|
-
cg.local.set(z2);
|
|
2415
|
-
cg.f32.const(-1 / 5040);
|
|
2416
|
-
cg.local.get(z2);
|
|
2417
|
-
cg.f32.mul();
|
|
2418
|
-
cg.f32.const(1 / 120);
|
|
2419
|
-
cg.f32.add();
|
|
2420
|
-
cg.local.get(z2);
|
|
2421
|
-
cg.f32.mul();
|
|
2422
|
-
cg.f32.const(-1 / 6);
|
|
2423
|
-
cg.f32.add();
|
|
2424
|
-
cg.local.get(z2);
|
|
2425
|
-
cg.f32.mul();
|
|
2426
|
-
cg.f32.const(1);
|
|
2427
|
-
cg.f32.add();
|
|
2428
|
-
cg.local.get(z);
|
|
2429
|
-
cg.f32.mul();
|
|
2430
|
-
cg.local.set(sz);
|
|
2431
|
-
cg.f32.const(-1 / 720);
|
|
2432
|
-
cg.local.get(z2);
|
|
2433
|
-
cg.f32.mul();
|
|
2434
|
-
cg.f32.const(1 / 24);
|
|
2435
|
-
cg.f32.add();
|
|
2436
|
-
cg.local.get(z2);
|
|
2437
|
-
cg.f32.mul();
|
|
2438
|
-
cg.f32.const(-1 / 2);
|
|
2439
|
-
cg.f32.add();
|
|
2440
|
-
cg.local.get(z2);
|
|
2441
|
-
cg.f32.mul();
|
|
2442
|
-
cg.f32.const(1);
|
|
2443
|
-
cg.f32.add();
|
|
2444
|
-
cg.local.set(cz);
|
|
2445
2518
|
cg.local.get(sz);
|
|
2446
2519
|
cg.local.get(cz);
|
|
2447
2520
|
cg.local.get(q);
|
|
@@ -2459,6 +2532,100 @@ function wasm_cos(cg) {
|
|
|
2459
2532
|
cg.select();
|
|
2460
2533
|
});
|
|
2461
2534
|
}
|
|
2535
|
+
/**
 * Emit inline WebAssembly that replaces the f32 on top of the operand stack
 * with an approximation of arctan of that value.
 *
 * Stack effect: [x] -> [atan(x)]. This does not create a WASM function. It
 * declares five f32 locals in the function currently being generated and
 * emits straight-line code into it (used by wasm_atan and wasm_asin).
 *
 * Method:
 *   z = |x| >= 1 ? 1/|x| : |x|         (fold the argument into [0, 1])
 *   p = z * P(z^2) / Q(z^2)            (rational approximation of atan(z))
 *       P(u) = 0.999998614341 + 0.661705427875*u + 0.0415796528637*u^2
 *       Q(u) = 1 + 0.994987933645*u + 0.173698870181*u^2
 *   r = |x| >= 1 ? π/2 - p : p         (atan(t) = π/2 - atan(1/t) for t > 0)
 *   result = copysign(r, x)            (atan is odd)
 */
function _atan(cg) {
  const x = cg.local.declare(cg.f32);
  const abs_x = cg.local.declare(cg.f32);
  const z = cg.local.declare(cg.f32);
  const z2 = cg.local.declare(cg.f32);
  const p = cg.local.declare(cg.f32);
  // Pop the input into `x` and cache |x|.
  cg.local.set(x);
  cg.local.get(x);
  cg.f32.abs();
  cg.local.set(abs_x);
  // z = select(1/|x|, |x|, |x| >= 1). 1/|x| is computed unconditionally; for
  // x = 0 it is +inf, but the select discards it.
  cg.f32.const(1);
  cg.local.get(abs_x);
  cg.f32.div();
  cg.local.get(abs_x);
  cg.local.get(abs_x);
  cg.f32.const(1);
  cg.f32.ge();
  cg.select();
  cg.local.set(z);
  cg.local.get(z);
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(z2);
  // Numerator P(z2), Horner form: (A2*z2 + A1)*z2 + A0.
  cg.f32.const(.0415796528637);
  cg.local.get(z2);
  cg.f32.mul();
  cg.f32.const(.661705427875);
  cg.f32.add();
  cg.local.get(z2);
  cg.f32.mul();
  cg.f32.const(.999998614341);
  cg.f32.add();
  // Denominator Q(z2), Horner form: (B2*z2 + B1)*z2 + 1.
  cg.f32.const(.173698870181);
  cg.local.get(z2);
  cg.f32.mul();
  cg.f32.const(.994987933645);
  cg.f32.add();
  cg.local.get(z2);
  cg.f32.mul();
  cg.f32.const(1);
  cg.f32.add();
  // p = z * P / Q  ≈ atan(z) for z in [0, 1].
  cg.f32.div();
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(p);
  // select(π/2 - p, p, |x| >= 1): undo the reciprocal fold.
  cg.f32.const(Math.PI / 2);
  cg.local.get(p);
  cg.f32.sub();
  cg.local.get(p);
  cg.local.get(abs_x);
  cg.f32.const(1);
  cg.f32.ge();
  cg.select();
  // Restore the sign of the original input; the result stays on the stack.
  cg.local.get(x);
  cg.f32.copysign();
}
|
|
2592
|
+
/**
 * Build a standalone generated WASM function `(f32) -> f32` approximating
 * atan(x).
 *
 * The body pushes the parameter and lets the inline `_atan` helper replace it
 * with the result. That helper folds |x| >= 1 through atan(x) = sign(x)*π/2 -
 * atan(1/x) and evaluates a rational approximation z * P(z^2) / Q(z^2), where
 * P and Q are degree-2 polynomials with fitted coefficients. The reported max
 * error is about 5e-7 on [0, 1].
 */
function wasm_atan(cg) {
  const emitBody = () => {
    cg.local.get(0);
    _atan(cg);
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
|
|
2607
|
+
/**
 * Build a standalone generated WASM function `(f32) -> f32` approximating
 * asin(x).
 *
 * Method: asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))), reusing the inline
 * `_atan` helper. For |x| <= 1 the half-angle form keeps the atan argument in
 * [-1, 1]; at x = ±1 it is exactly ±1, which gives ±π/2. For |x| > 1 the sqrt
 * argument is negative, so the result is expected to be NaN, like Math.asin.
 */
function wasm_asin(cg) {
  return cg.function([cg.f32], [cg.f32], () => {
    // Numerator x stays on the stack underneath the denominator.
    cg.local.get(0);
    // Denominator: 1 + sqrt(1 - x*x).
    cg.f32.const(1);
    cg.local.get(0);
    cg.local.get(0);
    cg.f32.mul();
    cg.f32.sub();
    cg.f32.sqrt();
    cg.f32.const(1);
    cg.f32.add();
    cg.f32.div();
    // atan of the half-angle tangent, then double it.
    _atan(cg);
    cg.f32.const(2);
    cg.f32.mul();
  });
}
|
|
2462
2629
|
/**
|
|
2463
2630
|
* Threefry2x32 pseudorandom number generator.
|
|
2464
2631
|
*
|
|
@@ -3338,6 +3505,8 @@ function codegenWasm(kernel) {
|
|
|
3338
3505
|
const funcs = {};
|
|
3339
3506
|
if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
|
|
3340
3507
|
if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
|
|
3508
|
+
if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
|
|
3509
|
+
if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
|
|
3341
3510
|
if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
|
|
3342
3511
|
if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
|
|
3343
3512
|
if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
|
|
@@ -3489,6 +3658,8 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3489
3658
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3490
3659
|
} else if (AluGroup.Unary.has(op)) if (op === AluOp.Sin) gen(src[0]), cg.call(funcs.sin);
|
|
3491
3660
|
else if (op === AluOp.Cos) gen(src[0]), cg.call(funcs.cos);
|
|
3661
|
+
else if (op === AluOp.Asin) gen(src[0]), cg.call(funcs.asin);
|
|
3662
|
+
else if (op === AluOp.Atan) gen(src[0]), cg.call(funcs.atan);
|
|
3492
3663
|
else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
|
|
3493
3664
|
else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
|
|
3494
3665
|
else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
|
|
@@ -3587,10 +3758,11 @@ let defaultBackend = "wasm";
|
|
|
3587
3758
|
/**
 * Registry of ready-to-use backends keyed by device name. The "cpu" and "wasm"
 * backends are constructed eagerly, in that order, when the module loads.
 */
const initializedBackends = new Map([
  ["cpu", new CpuBackend()],
  ["wasm", new WasmBackend()],
]);
|
|
3590
|
-
/**
 * Get or set the default device used for new arrays.
 *
 * @param {string} [device] - Device to make the default. Omit it to only read
 *   the current default.
 * @returns {string} The default device after the call.
 * @throws {Error} If `device` is given but that backend has not been
 *   initialized.
 */
function defaultDevice(device) {
  if (device === void 0) return defaultBackend;
  if (!initializedBackends.has(device)) {
    throw new Error(`Backend not initialized: ${device}`);
  }
  defaultBackend = device;
  return defaultBackend;
}
|
|
3595
3767
|
/**
|
|
3596
3768
|
* Initialize `jax-js` library backends.
|
|
@@ -3617,7 +3789,7 @@ async function createBackend(device) {
|
|
|
3617
3789
|
if (!navigator.gpu) return null;
|
|
3618
3790
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
3619
3791
|
if (!adapter) return null;
|
|
3620
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
3792
|
+
const { WebGPUBackend } = await import("./webgpu-CM-xNYzW.js");
|
|
3621
3793
|
const importantLimits = [
|
|
3622
3794
|
"maxBufferSize",
|
|
3623
3795
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -3670,4 +3842,4 @@ var UnsupportedOpError = class extends Error {
|
|
|
3670
3842
|
};
|
|
3671
3843
|
|
|
3672
3844
|
//#endregion
|
|
3673
|
-
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, devices, dtypedArray, dtypedJsArray, findPow2, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, partitionList, prod, range, recursiveFlatten, rep, runWithCache,
|
|
3845
|
+
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, union, unravelAlu, unzip2, zip, zipn };
|