@jax-js/jax 0.1.4 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +10 -7
- package/dist/{backend-Bu9GY6sK.cjs → backend-D7s-Retx.cjs} +122 -8
- package/dist/{backend-tngXtWe4.js → backend-Dx6Ob2D1.js} +111 -9
- package/dist/index.cjs +1059 -208
- package/dist/index.d.cts +429 -21
- package/dist/index.d.ts +429 -21
- package/dist/index.js +1059 -209
- package/dist/webgl-CLLvzJlO.js +522 -0
- package/dist/webgl-CyfzNW8T.cjs +522 -0
- package/dist/{webgpu-ChVgx3b6.js → webgpu-C-VfevQW.js} +296 -3
- package/dist/{webgpu-Oj3Kd-kd.cjs → webgpu-rraa6dfz.cjs} +296 -3
- package/package.json +1 -1
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-Dx6Ob2D1.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -414,10 +414,301 @@ function createArgsort(device, type) {
|
|
|
414
414
|
const batches = prod(shape.slice(0, -1));
|
|
415
415
|
return bitonicSortShader(device, dtype, n, batches, true);
|
|
416
416
|
}
|
|
417
|
+
/**
 * Generate a triangular solve shader.
 *
 * Solves A @ X.T = B.T for X, where A is upper-triangular, using a
 * parallelized back-substitution. One workgroup handles one
 * (matrix, right-hand side) pair:
 *   1. Copy b to x.
 *   2. For j = n-1 down to 0:
 *      - Thread 0 divides x[j] by a[j,j] (skipped for a unit diagonal).
 *      - All threads subtract x[j] * a[i,j] from x[i] for i < j in parallel.
 *
 * @param device GPU device; only `limits.maxComputeWorkgroupSizeX` is read.
 * @param type Routine type; reads `inputDtypes[0]`, `inputShapes[0]` (A) and
 *   `inputShapes[1]` (B).
 * @param params Routine parameters; a truthy `unitDiagonal` makes the shader
 *   treat the diagonal of A as ones.
 * @returns A one-element array describing the shader: WGSL code, 2 inputs,
 *   1 output, no uniform, and a single dispatch pass.
 */
function createTriangularSolve(device, type, params) {
	const dtype = type.inputDtypes[0];
	const aShape = type.inputShapes[0];
	const bShape = type.inputShapes[1];
	const n = aShape[aShape.length - 1];
	const numRhs = bShape[bShape.length - 2];
	const numMatrices = prod(aShape.slice(0, -2));
	const needsF16 = dtype === DType.Float16;
	const ty = dtypeToWgsl(dtype, true);
	const workgroupSize = findPow2(n, device.limits.maxComputeWorkgroupSizeX);
	// Tolerate a missing params object; absent means "use the stored diagonal".
	const unitDiagonal = Boolean(params && params.unitDiagonal);
	const pivotExpr = unitDiagonal ? "x[bx_base + j]" : `x[bx_base + j] / a[a_base + j * ${n}u + j]`;
	const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

@group(0) @binding(0) var<storage, read> a: array<${ty}>;
@group(0) @binding(1) var<storage, read> b: array<${ty}>;
@group(0) @binding(2) var<storage, read_write> x: array<${ty}>;

// Shared memory for the current pivot value x[j]
var<workgroup> x_j: ${ty};

@compute @workgroup_size(${workgroupSize})
fn main(
  @builtin(workgroup_id) wg_id: vec3<u32>,
  @builtin(local_invocation_id) local_id: vec3<u32>,
) {
  let wg_idx = wg_id.x + wg_id.y * ${gridOffsetY}u;
  let mat_idx = wg_idx / ${numRhs}u;
  let rhs_idx = wg_idx % ${numRhs}u;

  if (mat_idx >= ${numMatrices}u) {
    return;
  }

  let a_base = mat_idx * ${n * n}u;
  let bx_base = (mat_idx * ${numRhs}u + rhs_idx) * ${n}u;
  let tid = local_id.x;

  // Step 1: Copy b to x (threads collaborate)
  for (var idx = tid; idx < ${n}u; idx += ${workgroupSize}u) {
    x[bx_base + idx] = b[bx_base + idx];
  }
  storageBarrier();

  // Step 2: Back-substitution from j = n-1 down to 0.
  // Count jj down from n so no "n - 1" literal is emitted (it would be
  // the invalid "-1u" for an empty matrix).
  for (var jj = ${n}u; jj > 0u; jj--) {
    let j = jj - 1u;

    // Thread 0 computes x[j] = x[j] / a[j,j]
    if (tid == 0u) {
      x_j = ${pivotExpr};
      x[bx_base + j] = x_j;
    }
    workgroupBarrier(); // Sync shared memory x_j

    // All threads subtract x[j] * a[i,j] from x[i] for i < j
    for (var i = tid; i < j; i += ${workgroupSize}u) {
      x[bx_base + i] -= x_j * a[a_base + i * ${n}u + j];
    }
    workgroupBarrier();
    storageBarrier();
  }
}
`.trim();
	const totalWorkgroups = numMatrices * numRhs;
	const grid = calculateGrid(totalWorkgroups);
	return [{
		code,
		numInputs: 2,
		numOutputs: 1,
		hasUniform: false,
		passes: [{ grid }]
	}];
}
|
|
501
|
+
/**
 * Build the WebGPU shader for a batched Cholesky decomposition.
 *
 * Produces the lower-triangular factor L with A = L * L^T for every matrix
 * in the batch, one workgroup per matrix, using the column-by-column
 * Cholesky-Crout scheme. Entries above the diagonal are written as zero.
 * NOTE(review): the shader divides by sqrt(diagonal), so a positive-definite
 * input is presumably expected — confirm against callers.
 *
 * Per column j:
 *   1. Each thread accumulates the partial sum for its rows i >= j.
 *   2. Thread 0 takes the square root of the diagonal entry and shares it.
 *   3. Each thread scales its rows i > j by the shared diagonal value.
 *
 * @param device GPU device; only `limits.maxComputeWorkgroupSizeX` is read.
 * @param type Routine type; reads `inputDtypes[0]` and `inputShapes[0]`.
 * @returns A one-element array describing the shader: WGSL code, 1 input,
 *   1 output, no uniform, and a single dispatch pass.
 */
function createCholesky(device, type) {
	const elemDtype = type.inputDtypes[0];
	const matShape = type.inputShapes[0];
	const dim = matShape[matShape.length - 1];
	const batchCount = prod(matShape.slice(0, -2));
	const f16Directive = elemDtype === DType.Float16 ? "enable f16;" : "";
	const wgslTy = dtypeToWgsl(elemDtype, true);
	const threads = findPow2(dim, device.limits.maxComputeWorkgroupSizeX);
	const matSize = dim * dim;
	const wgsl = `
${f16Directive}
${headerWgsl}

@group(0) @binding(0) var<storage, read> input: array<${wgslTy}>;
@group(0) @binding(1) var<storage, read_write> output: array<${wgslTy}>;

// Shared memory for the diagonal element
var<workgroup> L_jj: ${wgslTy};

@compute @workgroup_size(${threads})
fn main(
  @builtin(workgroup_id) wg_id: vec3<u32>,
  @builtin(local_invocation_id) local_id: vec3<u32>,
) {
  let batch = wg_id.x + wg_id.y * ${gridOffsetY}u;
  if (batch >= ${batchCount}u) {
    return;
  }

  let base = batch * ${matSize}u;
  let tid = local_id.x;

  // Zero out output and copy lower triangle from input (threads collaborate)
  for (var idx = tid; idx < ${matSize}u; idx += ${threads}u) {
    let row = idx / ${dim}u;
    let col = idx % ${dim}u;
    output[base + idx] = select(0, input[base + idx], col <= row);
  }
  storageBarrier();

  // Cholesky-Crout algorithm: process column by column
  for (var j = 0u; j < ${dim}u; j++) {
    // Step 1: All threads compute sum for their rows i >= j in parallel
    // sum = A[i][j] - sum(L[i][k] * L[j][k] for k < j)
    for (var i = j + tid; i < ${dim}u; i += ${threads}u) {
      var sum = output[base + i * ${dim}u + j];
      for (var k = 0u; k < j; k++) {
        sum -= output[base + i * ${dim}u + k] * output[base + j * ${dim}u + k];
      }
      output[base + i * ${dim}u + j] = sum;
    }
    storageBarrier();

    // Step 2: Thread 0 computes L[j][j] = sqrt(output[j][j])
    if (tid == 0u) {
      L_jj = sqrt(output[base + j * ${dim}u + j]);
      output[base + j * ${dim}u + j] = L_jj;
    }
    workgroupBarrier();

    // Step 3: All threads divide output[i][j] by L[j][j] for i > j
    for (var i = j + 1u + tid; i < ${dim}u; i += ${threads}u) {
      output[base + i * ${dim}u + j] /= L_jj;
    }
    storageBarrier();
  }
}
`.trim();
	// One workgroup per matrix in the batch.
	const shader = {
		code: wgsl,
		numInputs: 1,
		numOutputs: 1,
		hasUniform: false,
		passes: [{ grid: calculateGrid(batchCount) }]
	};
	return [shader];
}
|
|
589
|
+
/**
 * Generate an LU decomposition shader with partial pivoting.
 *
 * Computes PA = LU where P is a permutation matrix, L is lower triangular
 * with unit diagonal, and U is upper triangular. L (below the diagonal) and
 * U are stored packed in the `lu` output. One workgroup handles one matrix.
 *
 * For each column j < min(m, n):
 *   1. Thread 0 finds the pivot row (max absolute value in column j, rows >= j).
 *   2. Rows j and the pivot row are swapped (and the permutation updated).
 *   3. L[i][j] = A[i][j] / A[j][j] for i > j. When the pivot is exactly zero
 *      the divide is skipped so a singular column does not inject NaN/Inf.
 *   4. Submatrix update: A[i][k] -= L[i][j] * A[j][k] for i > j, k > j.
 *
 * @param device GPU device; only `limits.maxComputeWorkgroupSizeX` is read.
 * @param type Routine type; reads `inputDtypes[0]` and `inputShapes[0]`
 *   (trailing dimensions m x n, leading dimensions are the batch).
 * @returns A one-element array describing the shader: WGSL code, 1 input,
 *   3 outputs (`lu`, `pivots` with min(m, n) entries per matrix, `perm` with
 *   m entries per matrix), no uniform, and a single dispatch pass.
 */
function createLU(device, type) {
	const dtype = type.inputDtypes[0];
	const shape = type.inputShapes[0];
	const m = shape[shape.length - 2];
	const n = shape[shape.length - 1];
	const r = Math.min(m, n);
	const batches = prod(shape.slice(0, -2));
	const needsF16 = dtype === DType.Float16;
	const ty = dtypeToWgsl(dtype, true);
	const workgroupSize = findPow2(Math.max(m, n), device.limits.maxComputeWorkgroupSizeX);
	const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

@group(0) @binding(0) var<storage, read> input: array<${ty}>;
@group(0) @binding(1) var<storage, read_write> lu: array<${ty}>;
@group(0) @binding(2) var<storage, read_write> pivots: array<i32>;
@group(0) @binding(3) var<storage, read_write> perm: array<i32>;

var<workgroup> pivot_row: u32;
var<workgroup> pivot_val: ${ty};

@compute @workgroup_size(${workgroupSize})
fn main(
  @builtin(workgroup_id) wg_id: vec3<u32>,
  @builtin(local_invocation_id) local_id: vec3<u32>,
) {
  let batch = wg_id.x + wg_id.y * ${gridOffsetY}u;
  if (batch >= ${batches}u) {
    return;
  }

  let lu_base = batch * ${m * n}u;
  let piv_base = batch * ${r}u;
  let perm_base = batch * ${m}u;
  let tid = local_id.x;

  // Copy input to lu
  for (var idx = tid; idx < ${m * n}u; idx += ${workgroupSize}u) {
    lu[lu_base + idx] = input[lu_base + idx];
  }
  // Initialize permutation
  for (var idx = tid; idx < ${m}u; idx += ${workgroupSize}u) {
    perm[perm_base + idx] = i32(idx);
  }
  storageBarrier();

  // LU decomposition with partial pivoting
  for (var j = 0u; j < ${r}u; j++) {
    // Step 1: Thread 0 finds pivot (max abs value in column j, rows >= j)
    if (tid == 0u) {
      var max_val = abs(lu[lu_base + j * ${n}u + j]);
      var max_row = j;
      for (var i = j + 1u; i < ${m}u; i++) {
        let val = abs(lu[lu_base + i * ${n}u + j]);
        if (val > max_val) {
          max_val = val;
          max_row = i;
        }
      }
      pivot_row = max_row;
      pivot_val = lu[lu_base + max_row * ${n}u + j];
      pivots[piv_base + j] = i32(max_row);
    }
    workgroupBarrier();

    // Step 2: Swap rows j and pivot_row (threads collaborate)
    let pr = pivot_row;
    if (pr != j) {
      for (var col = tid; col < ${n}u; col += ${workgroupSize}u) {
        let tmp = lu[lu_base + j * ${n}u + col];
        lu[lu_base + j * ${n}u + col] = lu[lu_base + pr * ${n}u + col];
        lu[lu_base + pr * ${n}u + col] = tmp;
      }
      if (tid == 0u) {
        let tmp_p = perm[perm_base + j];
        perm[perm_base + j] = perm[perm_base + pr];
        perm[perm_base + pr] = tmp_p;
      }
    }
    storageBarrier();

    // Step 3: Compute L[i][j] and update submatrix
    // Each thread handles one row i > j
    for (var i = j + 1u + tid; i < ${m}u; i += ${workgroupSize}u) {
      var factor = lu[lu_base + i * ${n}u + j];
      // The pivot has the largest magnitude in the column, so a zero pivot
      // means the remaining column is zero: skip the divide rather than
      // computing 0/0 and spreading NaN through the trailing submatrix.
      if (pivot_val != ${ty}(0)) {
        factor = factor / pivot_val;
      }
      lu[lu_base + i * ${n}u + j] = factor; // L[i][j]
      for (var k = j + 1u; k < ${n}u; k++) {
        lu[lu_base + i * ${n}u + k] -= factor * lu[lu_base + j * ${n}u + k];
      }
    }
    storageBarrier();
  }
}
`.trim();
	const grid = calculateGrid(batches);
	return [{
		code,
		numInputs: 1,
		numOutputs: 3,
		hasUniform: false,
		passes: [{ grid }]
	}];
}
|
|
417
705
|
/**
 * Select the shader generator that matches a routine.
 *
 * Dispatches on `routine.name` and forwards `routine.type` (plus
 * `routine.params` for the triangular solve) to the matching generator.
 *
 * @param device GPU device handed through to the generator.
 * @param routine Routine descriptor with `name`, `type` and, where used, `params`.
 * @returns Whatever the selected generator returns.
 * @throws {UnsupportedRoutineError} When no generator handles `routine.name`.
 */
function createRoutineShader(device, routine) {
	const kind = routine.name;
	if (kind === Routines.Sort) {
		return createSort(device, routine.type);
	}
	if (kind === Routines.Argsort) {
		return createArgsort(device, routine.type);
	}
	if (kind === Routines.TriangularSolve) {
		return createTriangularSolve(device, routine.type, routine.params);
	}
	if (kind === Routines.Cholesky) {
		return createCholesky(device, routine.type);
	}
	if (kind === Routines.LU) {
		return createLU(device, routine.type);
	}
	throw new UnsupportedRoutineError(routine.name, "webgpu");
}
|
|
@@ -675,8 +966,10 @@ function pipelineSource(device, kernel) {
|
|
|
675
966
|
else source = `(${a} * ${b})`;
|
|
676
967
|
else if (op === AluOp.Idiv) source = isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
|
|
677
968
|
else if (op === AluOp.Mod) source = `(${a} % ${b})`;
|
|
678
|
-
else if (op === AluOp.Min) source = `
|
|
679
|
-
else
|
|
969
|
+
else if (op === AluOp.Min) if (dtype === DType.Bool) source = `(${a} && ${b})`;
|
|
970
|
+
else source = `min(${strip1(a)}, ${strip1(b)})`;
|
|
971
|
+
else if (op === AluOp.Max) if (dtype === DType.Bool) source = `(${a} || ${b})`;
|
|
972
|
+
else source = `max(${strip1(a)}, ${strip1(b)})`;
|
|
680
973
|
else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
|
|
681
974
|
else if (op === AluOp.Cmpne) if (isFloatDtype(src[0].dtype)) {
|
|
682
975
|
const x = isGensym(a) ? a : gensym();
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-
|
|
1
|
+
const require_backend = require('./backend-D7s-Retx.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -414,10 +414,301 @@ function createArgsort(device, type) {
|
|
|
414
414
|
const batches = require_backend.prod(shape.slice(0, -1));
|
|
415
415
|
return bitonicSortShader(device, dtype, n, batches, true);
|
|
416
416
|
}
|
|
417
|
+
/**
 * Generate a triangular solve shader.
 *
 * Solves A @ X.T = B.T for X, where A is upper-triangular, using a
 * parallelized back-substitution. One workgroup handles one
 * (matrix, right-hand side) pair:
 *   1. Copy b to x.
 *   2. For j = n-1 down to 0:
 *      - Thread 0 divides x[j] by a[j,j] (skipped for a unit diagonal).
 *      - All threads subtract x[j] * a[i,j] from x[i] for i < j in parallel.
 *
 * @param device GPU device; only `limits.maxComputeWorkgroupSizeX` is read.
 * @param type Routine type; reads `inputDtypes[0]`, `inputShapes[0]` (A) and
 *   `inputShapes[1]` (B).
 * @param params Routine parameters; a truthy `unitDiagonal` makes the shader
 *   treat the diagonal of A as ones.
 * @returns A one-element array describing the shader: WGSL code, 2 inputs,
 *   1 output, no uniform, and a single dispatch pass.
 */
function createTriangularSolve(device, type, params) {
	const dtype = type.inputDtypes[0];
	const aShape = type.inputShapes[0];
	const bShape = type.inputShapes[1];
	const n = aShape[aShape.length - 1];
	const numRhs = bShape[bShape.length - 2];
	const numMatrices = require_backend.prod(aShape.slice(0, -2));
	const needsF16 = dtype === require_backend.DType.Float16;
	const ty = dtypeToWgsl(dtype, true);
	const workgroupSize = require_backend.findPow2(n, device.limits.maxComputeWorkgroupSizeX);
	// Tolerate a missing params object; absent means "use the stored diagonal".
	const unitDiagonal = Boolean(params && params.unitDiagonal);
	const pivotExpr = unitDiagonal ? "x[bx_base + j]" : `x[bx_base + j] / a[a_base + j * ${n}u + j]`;
	const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

@group(0) @binding(0) var<storage, read> a: array<${ty}>;
@group(0) @binding(1) var<storage, read> b: array<${ty}>;
@group(0) @binding(2) var<storage, read_write> x: array<${ty}>;

// Shared memory for the current pivot value x[j]
var<workgroup> x_j: ${ty};

@compute @workgroup_size(${workgroupSize})
fn main(
  @builtin(workgroup_id) wg_id: vec3<u32>,
  @builtin(local_invocation_id) local_id: vec3<u32>,
) {
  let wg_idx = wg_id.x + wg_id.y * ${gridOffsetY}u;
  let mat_idx = wg_idx / ${numRhs}u;
  let rhs_idx = wg_idx % ${numRhs}u;

  if (mat_idx >= ${numMatrices}u) {
    return;
  }

  let a_base = mat_idx * ${n * n}u;
  let bx_base = (mat_idx * ${numRhs}u + rhs_idx) * ${n}u;
  let tid = local_id.x;

  // Step 1: Copy b to x (threads collaborate)
  for (var idx = tid; idx < ${n}u; idx += ${workgroupSize}u) {
    x[bx_base + idx] = b[bx_base + idx];
  }
  storageBarrier();

  // Step 2: Back-substitution from j = n-1 down to 0.
  // Count jj down from n so no "n - 1" literal is emitted (it would be
  // the invalid "-1u" for an empty matrix).
  for (var jj = ${n}u; jj > 0u; jj--) {
    let j = jj - 1u;

    // Thread 0 computes x[j] = x[j] / a[j,j]
    if (tid == 0u) {
      x_j = ${pivotExpr};
      x[bx_base + j] = x_j;
    }
    workgroupBarrier(); // Sync shared memory x_j

    // All threads subtract x[j] * a[i,j] from x[i] for i < j
    for (var i = tid; i < j; i += ${workgroupSize}u) {
      x[bx_base + i] -= x_j * a[a_base + i * ${n}u + j];
    }
    workgroupBarrier();
    storageBarrier();
  }
}
`.trim();
	const totalWorkgroups = numMatrices * numRhs;
	const grid = calculateGrid(totalWorkgroups);
	return [{
		code,
		numInputs: 2,
		numOutputs: 1,
		hasUniform: false,
		passes: [{ grid }]
	}];
}
|
|
501
|
+
/**
 * Build the WebGPU shader for a batched Cholesky decomposition.
 *
 * Produces the lower-triangular factor L with A = L * L^T for every matrix
 * in the batch, one workgroup per matrix, using the column-by-column
 * Cholesky-Crout scheme. Entries above the diagonal are written as zero.
 * NOTE(review): the shader divides by sqrt(diagonal), so a positive-definite
 * input is presumably expected — confirm against callers.
 *
 * Per column j:
 *   1. Each thread accumulates the partial sum for its rows i >= j.
 *   2. Thread 0 takes the square root of the diagonal entry and shares it.
 *   3. Each thread scales its rows i > j by the shared diagonal value.
 *
 * @param device GPU device; only `limits.maxComputeWorkgroupSizeX` is read.
 * @param type Routine type; reads `inputDtypes[0]` and `inputShapes[0]`.
 * @returns A one-element array describing the shader: WGSL code, 1 input,
 *   1 output, no uniform, and a single dispatch pass.
 */
function createCholesky(device, type) {
	const elemDtype = type.inputDtypes[0];
	const matShape = type.inputShapes[0];
	const dim = matShape[matShape.length - 1];
	const batchCount = require_backend.prod(matShape.slice(0, -2));
	const f16Directive = elemDtype === require_backend.DType.Float16 ? "enable f16;" : "";
	const wgslTy = dtypeToWgsl(elemDtype, true);
	const threads = require_backend.findPow2(dim, device.limits.maxComputeWorkgroupSizeX);
	const matSize = dim * dim;
	const wgsl = `
${f16Directive}
${headerWgsl}

@group(0) @binding(0) var<storage, read> input: array<${wgslTy}>;
@group(0) @binding(1) var<storage, read_write> output: array<${wgslTy}>;

// Shared memory for the diagonal element
var<workgroup> L_jj: ${wgslTy};

@compute @workgroup_size(${threads})
fn main(
  @builtin(workgroup_id) wg_id: vec3<u32>,
  @builtin(local_invocation_id) local_id: vec3<u32>,
) {
  let batch = wg_id.x + wg_id.y * ${gridOffsetY}u;
  if (batch >= ${batchCount}u) {
    return;
  }

  let base = batch * ${matSize}u;
  let tid = local_id.x;

  // Zero out output and copy lower triangle from input (threads collaborate)
  for (var idx = tid; idx < ${matSize}u; idx += ${threads}u) {
    let row = idx / ${dim}u;
    let col = idx % ${dim}u;
    output[base + idx] = select(0, input[base + idx], col <= row);
  }
  storageBarrier();

  // Cholesky-Crout algorithm: process column by column
  for (var j = 0u; j < ${dim}u; j++) {
    // Step 1: All threads compute sum for their rows i >= j in parallel
    // sum = A[i][j] - sum(L[i][k] * L[j][k] for k < j)
    for (var i = j + tid; i < ${dim}u; i += ${threads}u) {
      var sum = output[base + i * ${dim}u + j];
      for (var k = 0u; k < j; k++) {
        sum -= output[base + i * ${dim}u + k] * output[base + j * ${dim}u + k];
      }
      output[base + i * ${dim}u + j] = sum;
    }
    storageBarrier();

    // Step 2: Thread 0 computes L[j][j] = sqrt(output[j][j])
    if (tid == 0u) {
      L_jj = sqrt(output[base + j * ${dim}u + j]);
      output[base + j * ${dim}u + j] = L_jj;
    }
    workgroupBarrier();

    // Step 3: All threads divide output[i][j] by L[j][j] for i > j
    for (var i = j + 1u + tid; i < ${dim}u; i += ${threads}u) {
      output[base + i * ${dim}u + j] /= L_jj;
    }
    storageBarrier();
  }
}
`.trim();
	// One workgroup per matrix in the batch.
	const shader = {
		code: wgsl,
		numInputs: 1,
		numOutputs: 1,
		hasUniform: false,
		passes: [{ grid: calculateGrid(batchCount) }]
	};
	return [shader];
}
|
|
589
|
+
/**
 * Generate an LU decomposition shader with partial pivoting.
 *
 * Computes PA = LU where P is a permutation matrix, L is lower triangular
 * with unit diagonal, and U is upper triangular. L (below the diagonal) and
 * U are stored packed in the `lu` output. One workgroup handles one matrix.
 *
 * For each column j < min(m, n):
 *   1. Thread 0 finds the pivot row (max absolute value in column j, rows >= j).
 *   2. Rows j and the pivot row are swapped (and the permutation updated).
 *   3. L[i][j] = A[i][j] / A[j][j] for i > j. When the pivot is exactly zero
 *      the divide is skipped so a singular column does not inject NaN/Inf.
 *   4. Submatrix update: A[i][k] -= L[i][j] * A[j][k] for i > j, k > j.
 *
 * @param device GPU device; only `limits.maxComputeWorkgroupSizeX` is read.
 * @param type Routine type; reads `inputDtypes[0]` and `inputShapes[0]`
 *   (trailing dimensions m x n, leading dimensions are the batch).
 * @returns A one-element array describing the shader: WGSL code, 1 input,
 *   3 outputs (`lu`, `pivots` with min(m, n) entries per matrix, `perm` with
 *   m entries per matrix), no uniform, and a single dispatch pass.
 */
function createLU(device, type) {
	const dtype = type.inputDtypes[0];
	const shape = type.inputShapes[0];
	const m = shape[shape.length - 2];
	const n = shape[shape.length - 1];
	const r = Math.min(m, n);
	const batches = require_backend.prod(shape.slice(0, -2));
	const needsF16 = dtype === require_backend.DType.Float16;
	const ty = dtypeToWgsl(dtype, true);
	const workgroupSize = require_backend.findPow2(Math.max(m, n), device.limits.maxComputeWorkgroupSizeX);
	const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

@group(0) @binding(0) var<storage, read> input: array<${ty}>;
@group(0) @binding(1) var<storage, read_write> lu: array<${ty}>;
@group(0) @binding(2) var<storage, read_write> pivots: array<i32>;
@group(0) @binding(3) var<storage, read_write> perm: array<i32>;

var<workgroup> pivot_row: u32;
var<workgroup> pivot_val: ${ty};

@compute @workgroup_size(${workgroupSize})
fn main(
  @builtin(workgroup_id) wg_id: vec3<u32>,
  @builtin(local_invocation_id) local_id: vec3<u32>,
) {
  let batch = wg_id.x + wg_id.y * ${gridOffsetY}u;
  if (batch >= ${batches}u) {
    return;
  }

  let lu_base = batch * ${m * n}u;
  let piv_base = batch * ${r}u;
  let perm_base = batch * ${m}u;
  let tid = local_id.x;

  // Copy input to lu
  for (var idx = tid; idx < ${m * n}u; idx += ${workgroupSize}u) {
    lu[lu_base + idx] = input[lu_base + idx];
  }
  // Initialize permutation
  for (var idx = tid; idx < ${m}u; idx += ${workgroupSize}u) {
    perm[perm_base + idx] = i32(idx);
  }
  storageBarrier();

  // LU decomposition with partial pivoting
  for (var j = 0u; j < ${r}u; j++) {
    // Step 1: Thread 0 finds pivot (max abs value in column j, rows >= j)
    if (tid == 0u) {
      var max_val = abs(lu[lu_base + j * ${n}u + j]);
      var max_row = j;
      for (var i = j + 1u; i < ${m}u; i++) {
        let val = abs(lu[lu_base + i * ${n}u + j]);
        if (val > max_val) {
          max_val = val;
          max_row = i;
        }
      }
      pivot_row = max_row;
      pivot_val = lu[lu_base + max_row * ${n}u + j];
      pivots[piv_base + j] = i32(max_row);
    }
    workgroupBarrier();

    // Step 2: Swap rows j and pivot_row (threads collaborate)
    let pr = pivot_row;
    if (pr != j) {
      for (var col = tid; col < ${n}u; col += ${workgroupSize}u) {
        let tmp = lu[lu_base + j * ${n}u + col];
        lu[lu_base + j * ${n}u + col] = lu[lu_base + pr * ${n}u + col];
        lu[lu_base + pr * ${n}u + col] = tmp;
      }
      if (tid == 0u) {
        let tmp_p = perm[perm_base + j];
        perm[perm_base + j] = perm[perm_base + pr];
        perm[perm_base + pr] = tmp_p;
      }
    }
    storageBarrier();

    // Step 3: Compute L[i][j] and update submatrix
    // Each thread handles one row i > j
    for (var i = j + 1u + tid; i < ${m}u; i += ${workgroupSize}u) {
      var factor = lu[lu_base + i * ${n}u + j];
      // The pivot has the largest magnitude in the column, so a zero pivot
      // means the remaining column is zero: skip the divide rather than
      // computing 0/0 and spreading NaN through the trailing submatrix.
      if (pivot_val != ${ty}(0)) {
        factor = factor / pivot_val;
      }
      lu[lu_base + i * ${n}u + j] = factor; // L[i][j]
      for (var k = j + 1u; k < ${n}u; k++) {
        lu[lu_base + i * ${n}u + k] -= factor * lu[lu_base + j * ${n}u + k];
      }
    }
    storageBarrier();
  }
}
`.trim();
	const grid = calculateGrid(batches);
	return [{
		code,
		numInputs: 1,
		numOutputs: 3,
		hasUniform: false,
		passes: [{ grid }]
	}];
}
|
|
417
705
|
/**
 * Select the shader generator that matches a routine.
 *
 * Dispatches on `routine.name` and forwards `routine.type` (plus
 * `routine.params` for the triangular solve) to the matching generator.
 *
 * @param device GPU device handed through to the generator.
 * @param routine Routine descriptor with `name`, `type` and, where used, `params`.
 * @returns Whatever the selected generator returns.
 * @throws {UnsupportedRoutineError} When no generator handles `routine.name`.
 */
function createRoutineShader(device, routine) {
	const kind = routine.name;
	if (kind === require_backend.Routines.Sort) {
		return createSort(device, routine.type);
	}
	if (kind === require_backend.Routines.Argsort) {
		return createArgsort(device, routine.type);
	}
	if (kind === require_backend.Routines.TriangularSolve) {
		return createTriangularSolve(device, routine.type, routine.params);
	}
	if (kind === require_backend.Routines.Cholesky) {
		return createCholesky(device, routine.type);
	}
	if (kind === require_backend.Routines.LU) {
		return createLU(device, routine.type);
	}
	throw new require_backend.UnsupportedRoutineError(routine.name, "webgpu");
}
|
|
@@ -675,8 +966,10 @@ function pipelineSource(device, kernel) {
|
|
|
675
966
|
else source = `(${a} * ${b})`;
|
|
676
967
|
else if (op === require_backend.AluOp.Idiv) source = require_backend.isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
|
|
677
968
|
else if (op === require_backend.AluOp.Mod) source = `(${a} % ${b})`;
|
|
678
|
-
else if (op === require_backend.AluOp.Min) source = `
|
|
679
|
-
else
|
|
969
|
+
else if (op === require_backend.AluOp.Min) if (dtype === require_backend.DType.Bool) source = `(${a} && ${b})`;
|
|
970
|
+
else source = `min(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
|
|
971
|
+
else if (op === require_backend.AluOp.Max) if (dtype === require_backend.DType.Bool) source = `(${a} || ${b})`;
|
|
972
|
+
else source = `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
|
|
680
973
|
else if (op === require_backend.AluOp.Cmplt) source = `(${a} < ${b})`;
|
|
681
974
|
else if (op === require_backend.AluOp.Cmpne) if (require_backend.isFloatDtype(src[0].dtype)) {
|
|
682
975
|
const x = isGensym(a) ? a : gensym();
|