@jax-js/jax 0.1.4 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-Bu9GY6sK.cjs');
1
+ const require_backend = require('./backend-DziQSaoQ.cjs');
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -414,10 +414,301 @@ function createArgsort(device, type) {
414
414
  const batches = require_backend.prod(shape.slice(0, -1));
415
415
  return bitonicSortShader(device, dtype, n, batches, true);
416
416
  }
417
/**
 * Generate a triangular solve shader.
 *
 * Solves A @ X.T = B.T for X, where A is upper-triangular. One workgroup
 * handles one (matrix, right-hand side) pair and runs a parallelized
 * back-substitution:
 * 1. Copy b to x
 * 2. For j = n-1 down to 0:
 *    - Thread 0 divides x[j] by a[j,j] (skipped when `params.unitDiagonal`)
 *    - All threads subtract x[j] * a[i,j] from x[i] for i < j in parallel
 *
 * @param device - WebGPU device; only `limits.maxComputeWorkgroupSizeX` is read.
 * @param type - Routine type. `inputDtypes[0]` is the element dtype,
 *   `inputShapes[0]` is A (trailing dims `[n, n]`) and `inputShapes[1]` is B,
 *   indexed as `numRhs` rows of length n per matrix (assumes B's batch dims
 *   match A's — TODO confirm with the caller).
 * @param params - Routine params; a truthy `unitDiagonal` treats a[j,j] as 1.
 * @returns A one-element array holding a single-pass shader with inputs
 *   (a, b) and one output x.
 */
function createTriangularSolve(device, type, params) {
  const dtype = type.inputDtypes[0];
  const aShape = type.inputShapes[0];
  const bShape = type.inputShapes[1];
  const n = aShape[aShape.length - 1];
  const numRhs = bShape[bShape.length - 2];
  const numMatrices = require_backend.prod(aShape.slice(0, -2));
  const needsF16 = dtype === require_backend.DType.Float16;
  const ty = dtypeToWgsl(dtype, true);
  const workgroupSize = require_backend.findPow2(n, device.limits.maxComputeWorkgroupSizeX);
  // With a unit diagonal the pivot is x[j] itself; otherwise divide by a[j,j].
  const pivotAssign = params.unitDiagonal
    ? "x_j = x[bx_base + j];"
    : `x_j = x[bx_base + j] / a[a_base + j * ${n}u + j];`;
  // The back-substitution loop counts jp1 = j + 1 from n down to 1 so that no
  // negative u32 literal (e.g. "-1u" when n = 0) is ever emitted: WGSL has no
  // unary minus on u32 and would reject the shader.
  const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

@group(0) @binding(0) var<storage, read> a: array<${ty}>;
@group(0) @binding(1) var<storage, read> b: array<${ty}>;
@group(0) @binding(2) var<storage, read_write> x: array<${ty}>;

// Shared memory for the current pivot value x[j]
var<workgroup> x_j: ${ty};

@compute @workgroup_size(${workgroupSize})
fn main(
  @builtin(workgroup_id) wg_id: vec3<u32>,
  @builtin(local_invocation_id) local_id: vec3<u32>,
) {
  let wg_idx = wg_id.x + wg_id.y * ${gridOffsetY}u;
  let mat_idx = wg_idx / ${numRhs}u;
  let rhs_idx = wg_idx % ${numRhs}u;

  if (mat_idx >= ${numMatrices}u) {
    return;
  }

  let a_base = mat_idx * ${n * n}u;
  let bx_base = (mat_idx * ${numRhs}u + rhs_idx) * ${n}u;
  let tid = local_id.x;

  // Step 1: Copy b to x (threads collaborate)
  for (var idx = tid; idx < ${n}u; idx += ${workgroupSize}u) {
    x[bx_base + idx] = b[bx_base + idx];
  }
  storageBarrier();

  // Step 2: Back-substitution from j = n-1 down to 0 (jp1 = j + 1)
  for (var jp1 = ${n}u; jp1 > 0u; jp1--) {
    let j = jp1 - 1u;

    // Thread 0 computes x[j] = x[j] / a[j,j]
    if (tid == 0u) {
      ${pivotAssign}
      x[bx_base + j] = x_j;
    }
    workgroupBarrier(); // Sync shared memory x_j

    // All threads subtract x[j] * a[i,j] from x[i] for i < j
    for (var i = tid; i < j; i += ${workgroupSize}u) {
      x[bx_base + i] -= x_j * a[a_base + i * ${n}u + j];
    }
    workgroupBarrier();
    storageBarrier();
  }
}
`.trim();
  const totalWorkgroups = numMatrices * numRhs;
  const grid = calculateGrid(totalWorkgroups);
  return [{
    code,
    numInputs: 2,
    numOutputs: 1,
    hasUniform: false,
    passes: [{ grid }]
  }];
}
501
/**
 * Generate a Cholesky decomposition shader.
 *
 * Computes the lower-triangular matrix L such that A = L * L^T for each
 * matrix in the batch, using the column-by-column Cholesky-Crout algorithm
 * with one workgroup per matrix. Only the lower triangle (including the
 * diagonal) of the input is used; the strict upper triangle of the output is
 * written as zero. A must be symmetric positive-definite: a non-positive
 * pivot is not detected, and the sqrt/division then produce non-finite or
 * indeterminate values rather than an error.
 *
 * For each column j:
 * 1. All threads compute their row's sum in parallel and store to output
 * 2. Thread 0 computes L[j][j] = sqrt(output[j][j]) into workgroup memory
 * 3. All threads divide their output[i][j] by L[j][j] in parallel
 *
 * @param device - WebGPU device; only `limits.maxComputeWorkgroupSizeX` is read.
 * @param type - Routine type; `inputDtypes[0]` is the element dtype and
 *   `inputShapes[0]` has trailing dims `[n, n]`.
 * @returns A one-element array holding a single-pass shader with one input
 *   and one output.
 */
function createCholesky(device, type) {
  const dtype = type.inputDtypes[0];
  const shape = type.inputShapes[0];
  const n = shape[shape.length - 1];
  const batches = require_backend.prod(shape.slice(0, -2));
  const needsF16 = dtype === require_backend.DType.Float16;
  const ty = dtypeToWgsl(dtype, true);
  const workgroupSize = require_backend.findPow2(n, device.limits.maxComputeWorkgroupSizeX);
  const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

@group(0) @binding(0) var<storage, read> input: array<${ty}>;
@group(0) @binding(1) var<storage, read_write> output: array<${ty}>;

// Shared memory for the diagonal element
var<workgroup> L_jj: ${ty};

@compute @workgroup_size(${workgroupSize})
fn main(
  @builtin(workgroup_id) wg_id: vec3<u32>,
  @builtin(local_invocation_id) local_id: vec3<u32>,
) {
  let batch = wg_id.x + wg_id.y * ${gridOffsetY}u;
  if (batch >= ${batches}u) {
    return;
  }

  let base = batch * ${n * n}u;
  let tid = local_id.x;

  // Zero out output and copy lower triangle from input (threads collaborate)
  for (var idx = tid; idx < ${n * n}u; idx += ${workgroupSize}u) {
    let row = idx / ${n}u;
    let col = idx % ${n}u;
    output[base + idx] = select(${ty}(0), input[base + idx], col <= row);
  }
  storageBarrier();

  // Cholesky-Crout algorithm: process column by column
  for (var j = 0u; j < ${n}u; j++) {
    // Step 1: All threads compute sum for their rows i >= j in parallel
    // sum = A[i][j] - sum(L[i][k] * L[j][k] for k < j)
    for (var i = j + tid; i < ${n}u; i += ${workgroupSize}u) {
      var sum = output[base + i * ${n}u + j];
      for (var k = 0u; k < j; k++) {
        sum -= output[base + i * ${n}u + k] * output[base + j * ${n}u + k];
      }
      output[base + i * ${n}u + j] = sum;
    }
    storageBarrier();

    // Step 2: Thread 0 computes L[j][j] = sqrt(output[j][j])
    if (tid == 0u) {
      L_jj = sqrt(output[base + j * ${n}u + j]);
      output[base + j * ${n}u + j] = L_jj;
    }
    workgroupBarrier();

    // Step 3: All threads divide output[i][j] by L[j][j] for i > j
    for (var i = j + 1u + tid; i < ${n}u; i += ${workgroupSize}u) {
      output[base + i * ${n}u + j] /= L_jj;
    }
    // L_jj is rewritten by thread 0 in the next iteration; a workgroup
    // barrier orders those reads before that write, since storage barriers
    // only synchronize the storage address space.
    workgroupBarrier();
    storageBarrier();
  }
}
`.trim();
  const grid = calculateGrid(batches);
  return [{
    code,
    numInputs: 1,
    numOutputs: 1,
    hasUniform: false,
    passes: [{ grid }]
  }];
}
589
/**
 * Generate an LU decomposition shader with partial pivoting.
 *
 * Computes PA = LU where P is a permutation matrix, L is lower triangular
 * with unit diagonal, and U is upper triangular. One workgroup processes one
 * m x n matrix of the batch.
 *
 * For each column j < min(m, n):
 * 1. Find pivot row (max absolute value in column j, rows >= j)
 * 2. Swap rows j and pivot row
 * 3. Compute L[i][j] = A[i][j] / A[j][j] for i > j
 * 4. Update submatrix: A[i][k] -= L[i][j] * A[j][k] for i > j, k > j
 *
 * Outputs:
 * - `lu`: L (strictly below the diagonal, unit diagonal implicit) and U
 *   (on and above the diagonal) packed into one m x n matrix.
 * - `pivots`: min(m, n) entries; entry j is the 0-based row swapped with row j.
 * - `perm`: m entries; the resulting row permutation, starting from identity.
 * A zero pivot (singular column) is not detected; the division by it then
 * yields non-finite or indeterminate values.
 *
 * @param device - WebGPU device; only `limits.maxComputeWorkgroupSizeX` is read.
 * @param type - Routine type; `inputDtypes[0]` is the element dtype and
 *   `inputShapes[0]` has trailing dims `[m, n]`.
 * @returns A one-element array holding a single-pass shader with one input
 *   and three outputs (lu, pivots, perm).
 */
function createLU(device, type) {
  const dtype = type.inputDtypes[0];
  const shape = type.inputShapes[0];
  const m = shape[shape.length - 2];
  const n = shape[shape.length - 1];
  const r = Math.min(m, n);
  const batches = require_backend.prod(shape.slice(0, -2));
  const needsF16 = dtype === require_backend.DType.Float16;
  const ty = dtypeToWgsl(dtype, true);
  const workgroupSize = require_backend.findPow2(Math.max(m, n), device.limits.maxComputeWorkgroupSizeX);
  const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

@group(0) @binding(0) var<storage, read> input: array<${ty}>;
@group(0) @binding(1) var<storage, read_write> lu: array<${ty}>;
@group(0) @binding(2) var<storage, read_write> pivots: array<i32>;
@group(0) @binding(3) var<storage, read_write> perm: array<i32>;

var<workgroup> pivot_row: u32;
var<workgroup> pivot_val: ${ty};

@compute @workgroup_size(${workgroupSize})
fn main(
  @builtin(workgroup_id) wg_id: vec3<u32>,
  @builtin(local_invocation_id) local_id: vec3<u32>,
) {
  let batch = wg_id.x + wg_id.y * ${gridOffsetY}u;
  if (batch >= ${batches}u) {
    return;
  }

  let lu_base = batch * ${m * n}u;
  let piv_base = batch * ${r}u;
  let perm_base = batch * ${m}u;
  let tid = local_id.x;

  // Copy input to lu
  for (var idx = tid; idx < ${m * n}u; idx += ${workgroupSize}u) {
    lu[lu_base + idx] = input[lu_base + idx];
  }
  // Initialize permutation
  for (var idx = tid; idx < ${m}u; idx += ${workgroupSize}u) {
    perm[perm_base + idx] = i32(idx);
  }
  storageBarrier();

  // LU decomposition with partial pivoting
  for (var j = 0u; j < ${r}u; j++) {
    // Step 1: Thread 0 finds pivot (max abs value in column j, rows >= j)
    if (tid == 0u) {
      var max_val = abs(lu[lu_base + j * ${n}u + j]);
      var max_row = j;
      for (var i = j + 1u; i < ${m}u; i++) {
        let val = abs(lu[lu_base + i * ${n}u + j]);
        if (val > max_val) {
          max_val = val;
          max_row = i;
        }
      }
      pivot_row = max_row;
      pivot_val = lu[lu_base + max_row * ${n}u + j];
      pivots[piv_base + j] = i32(max_row);
    }
    workgroupBarrier();

    // Step 2: Swap rows j and pivot_row (threads collaborate)
    let pr = pivot_row;
    if (pr != j) {
      for (var col = tid; col < ${n}u; col += ${workgroupSize}u) {
        let tmp = lu[lu_base + j * ${n}u + col];
        lu[lu_base + j * ${n}u + col] = lu[lu_base + pr * ${n}u + col];
        lu[lu_base + pr * ${n}u + col] = tmp;
      }
      if (tid == 0u) {
        let tmp_p = perm[perm_base + j];
        perm[perm_base + j] = perm[perm_base + pr];
        perm[perm_base + pr] = tmp_p;
      }
    }
    storageBarrier();

    // Step 3: Compute L[i][j] and update submatrix
    // Each thread handles one row i > j
    for (var i = j + 1u + tid; i < ${m}u; i += ${workgroupSize}u) {
      let factor = lu[lu_base + i * ${n}u + j] / pivot_val;
      lu[lu_base + i * ${n}u + j] = factor; // L[i][j]
      for (var k = j + 1u; k < ${n}u; k++) {
        lu[lu_base + i * ${n}u + k] -= factor * lu[lu_base + j * ${n}u + k];
      }
    }
    // pivot_row and pivot_val are rewritten by thread 0 in the next
    // iteration; a workgroup barrier orders the reads above before that
    // write, since storage barriers only synchronize the storage address space.
    workgroupBarrier();
    storageBarrier();
  }
}
`.trim();
  const grid = calculateGrid(batches);
  return [{
    code,
    numInputs: 1,
    numOutputs: 3,
    hasUniform: false,
    passes: [{ grid }]
  }];
}
417
705
/**
 * Build the WebGPU shader(s) for a routine by dispatching on `routine.name`
 * to the matching generator function.
 *
 * @param device - WebGPU device, forwarded unchanged to the generator.
 * @param routine - Routine descriptor; `name` selects the generator, `type`
 *   is forwarded, and `params` is forwarded only for triangular solve.
 * @returns Whatever the selected generator returns.
 * @throws {UnsupportedRoutineError} When the routine has no WebGPU generator.
 */
function createRoutineShader(device, routine) {
  const kind = routine.name;
  const { Routines } = require_backend;
  if (kind === Routines.Sort) return createSort(device, routine.type);
  if (kind === Routines.Argsort) return createArgsort(device, routine.type);
  if (kind === Routines.TriangularSolve) {
    return createTriangularSolve(device, routine.type, routine.params);
  }
  if (kind === Routines.Cholesky) return createCholesky(device, routine.type);
  if (kind === Routines.LU) return createLU(device, routine.type);
  throw new require_backend.UnsupportedRoutineError(kind, "webgpu");
}
@@ -675,8 +966,10 @@ function pipelineSource(device, kernel) {
675
966
  else source = `(${a} * ${b})`;
676
967
  else if (op === require_backend.AluOp.Idiv) source = require_backend.isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
677
968
  else if (op === require_backend.AluOp.Mod) source = `(${a} % ${b})`;
678
- else if (op === require_backend.AluOp.Min) source = `min(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
679
- else if (op === require_backend.AluOp.Max) source = `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
969
+ else if (op === require_backend.AluOp.Min) if (dtype === require_backend.DType.Bool) source = `(${a} && ${b})`;
970
+ else source = `min(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
971
+ else if (op === require_backend.AluOp.Max) if (dtype === require_backend.DType.Bool) source = `(${a} || ${b})`;
972
+ else source = `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
680
973
  else if (op === require_backend.AluOp.Cmplt) source = `(${a} < ${b})`;
681
974
  else if (op === require_backend.AluOp.Cmpne) if (require_backend.isFloatDtype(src[0].dtype)) {
682
975
  const x = isGensym(a) ? a : gensym();
@@ -1,4 +1,4 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-tngXtWe4.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-DaqL-MNz.js";
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -414,10 +414,301 @@ function createArgsort(device, type) {
414
414
  const batches = prod(shape.slice(0, -1));
415
415
  return bitonicSortShader(device, dtype, n, batches, true);
416
416
  }
417
/**
 * Generate a triangular solve shader.
 *
 * Solves A @ X.T = B.T for X, where A is upper-triangular. One workgroup
 * handles one (matrix, right-hand side) pair and runs a parallelized
 * back-substitution:
 * 1. Copy b to x
 * 2. For j = n-1 down to 0:
 *    - Thread 0 divides x[j] by a[j,j] (skipped when `params.unitDiagonal`)
 *    - All threads subtract x[j] * a[i,j] from x[i] for i < j in parallel
 *
 * @param device - WebGPU device; only `limits.maxComputeWorkgroupSizeX` is read.
 * @param type - Routine type. `inputDtypes[0]` is the element dtype,
 *   `inputShapes[0]` is A (trailing dims `[n, n]`) and `inputShapes[1]` is B,
 *   indexed as `numRhs` rows of length n per matrix (assumes B's batch dims
 *   match A's — TODO confirm with the caller).
 * @param params - Routine params; a truthy `unitDiagonal` treats a[j,j] as 1.
 * @returns A one-element array holding a single-pass shader with inputs
 *   (a, b) and one output x.
 */
function createTriangularSolve(device, type, params) {
  const dtype = type.inputDtypes[0];
  const aShape = type.inputShapes[0];
  const bShape = type.inputShapes[1];
  const n = aShape[aShape.length - 1];
  const numRhs = bShape[bShape.length - 2];
  const numMatrices = prod(aShape.slice(0, -2));
  const needsF16 = dtype === DType.Float16;
  const ty = dtypeToWgsl(dtype, true);
  const workgroupSize = findPow2(n, device.limits.maxComputeWorkgroupSizeX);
  // With a unit diagonal the pivot is x[j] itself; otherwise divide by a[j,j].
  const pivotAssign = params.unitDiagonal
    ? "x_j = x[bx_base + j];"
    : `x_j = x[bx_base + j] / a[a_base + j * ${n}u + j];`;
  // The back-substitution loop counts jp1 = j + 1 from n down to 1 so that no
  // negative u32 literal (e.g. "-1u" when n = 0) is ever emitted: WGSL has no
  // unary minus on u32 and would reject the shader.
  const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

@group(0) @binding(0) var<storage, read> a: array<${ty}>;
@group(0) @binding(1) var<storage, read> b: array<${ty}>;
@group(0) @binding(2) var<storage, read_write> x: array<${ty}>;

// Shared memory for the current pivot value x[j]
var<workgroup> x_j: ${ty};

@compute @workgroup_size(${workgroupSize})
fn main(
  @builtin(workgroup_id) wg_id: vec3<u32>,
  @builtin(local_invocation_id) local_id: vec3<u32>,
) {
  let wg_idx = wg_id.x + wg_id.y * ${gridOffsetY}u;
  let mat_idx = wg_idx / ${numRhs}u;
  let rhs_idx = wg_idx % ${numRhs}u;

  if (mat_idx >= ${numMatrices}u) {
    return;
  }

  let a_base = mat_idx * ${n * n}u;
  let bx_base = (mat_idx * ${numRhs}u + rhs_idx) * ${n}u;
  let tid = local_id.x;

  // Step 1: Copy b to x (threads collaborate)
  for (var idx = tid; idx < ${n}u; idx += ${workgroupSize}u) {
    x[bx_base + idx] = b[bx_base + idx];
  }
  storageBarrier();

  // Step 2: Back-substitution from j = n-1 down to 0 (jp1 = j + 1)
  for (var jp1 = ${n}u; jp1 > 0u; jp1--) {
    let j = jp1 - 1u;

    // Thread 0 computes x[j] = x[j] / a[j,j]
    if (tid == 0u) {
      ${pivotAssign}
      x[bx_base + j] = x_j;
    }
    workgroupBarrier(); // Sync shared memory x_j

    // All threads subtract x[j] * a[i,j] from x[i] for i < j
    for (var i = tid; i < j; i += ${workgroupSize}u) {
      x[bx_base + i] -= x_j * a[a_base + i * ${n}u + j];
    }
    workgroupBarrier();
    storageBarrier();
  }
}
`.trim();
  const totalWorkgroups = numMatrices * numRhs;
  const grid = calculateGrid(totalWorkgroups);
  return [{
    code,
    numInputs: 2,
    numOutputs: 1,
    hasUniform: false,
    passes: [{ grid }]
  }];
}
501
/**
 * Generate a Cholesky decomposition shader.
 *
 * Computes the lower-triangular matrix L such that A = L * L^T for each
 * matrix in the batch, using the column-by-column Cholesky-Crout algorithm
 * with one workgroup per matrix. Only the lower triangle (including the
 * diagonal) of the input is used; the strict upper triangle of the output is
 * written as zero. A must be symmetric positive-definite: a non-positive
 * pivot is not detected, and the sqrt/division then produce non-finite or
 * indeterminate values rather than an error.
 *
 * For each column j:
 * 1. All threads compute their row's sum in parallel and store to output
 * 2. Thread 0 computes L[j][j] = sqrt(output[j][j]) into workgroup memory
 * 3. All threads divide their output[i][j] by L[j][j] in parallel
 *
 * @param device - WebGPU device; only `limits.maxComputeWorkgroupSizeX` is read.
 * @param type - Routine type; `inputDtypes[0]` is the element dtype and
 *   `inputShapes[0]` has trailing dims `[n, n]`.
 * @returns A one-element array holding a single-pass shader with one input
 *   and one output.
 */
function createCholesky(device, type) {
  const dtype = type.inputDtypes[0];
  const shape = type.inputShapes[0];
  const n = shape[shape.length - 1];
  const batches = prod(shape.slice(0, -2));
  const needsF16 = dtype === DType.Float16;
  const ty = dtypeToWgsl(dtype, true);
  const workgroupSize = findPow2(n, device.limits.maxComputeWorkgroupSizeX);
  const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

@group(0) @binding(0) var<storage, read> input: array<${ty}>;
@group(0) @binding(1) var<storage, read_write> output: array<${ty}>;

// Shared memory for the diagonal element
var<workgroup> L_jj: ${ty};

@compute @workgroup_size(${workgroupSize})
fn main(
  @builtin(workgroup_id) wg_id: vec3<u32>,
  @builtin(local_invocation_id) local_id: vec3<u32>,
) {
  let batch = wg_id.x + wg_id.y * ${gridOffsetY}u;
  if (batch >= ${batches}u) {
    return;
  }

  let base = batch * ${n * n}u;
  let tid = local_id.x;

  // Zero out output and copy lower triangle from input (threads collaborate)
  for (var idx = tid; idx < ${n * n}u; idx += ${workgroupSize}u) {
    let row = idx / ${n}u;
    let col = idx % ${n}u;
    output[base + idx] = select(${ty}(0), input[base + idx], col <= row);
  }
  storageBarrier();

  // Cholesky-Crout algorithm: process column by column
  for (var j = 0u; j < ${n}u; j++) {
    // Step 1: All threads compute sum for their rows i >= j in parallel
    // sum = A[i][j] - sum(L[i][k] * L[j][k] for k < j)
    for (var i = j + tid; i < ${n}u; i += ${workgroupSize}u) {
      var sum = output[base + i * ${n}u + j];
      for (var k = 0u; k < j; k++) {
        sum -= output[base + i * ${n}u + k] * output[base + j * ${n}u + k];
      }
      output[base + i * ${n}u + j] = sum;
    }
    storageBarrier();

    // Step 2: Thread 0 computes L[j][j] = sqrt(output[j][j])
    if (tid == 0u) {
      L_jj = sqrt(output[base + j * ${n}u + j]);
      output[base + j * ${n}u + j] = L_jj;
    }
    workgroupBarrier();

    // Step 3: All threads divide output[i][j] by L[j][j] for i > j
    for (var i = j + 1u + tid; i < ${n}u; i += ${workgroupSize}u) {
      output[base + i * ${n}u + j] /= L_jj;
    }
    // L_jj is rewritten by thread 0 in the next iteration; a workgroup
    // barrier orders those reads before that write, since storage barriers
    // only synchronize the storage address space.
    workgroupBarrier();
    storageBarrier();
  }
}
`.trim();
  const grid = calculateGrid(batches);
  return [{
    code,
    numInputs: 1,
    numOutputs: 1,
    hasUniform: false,
    passes: [{ grid }]
  }];
}
589
/**
 * Generate an LU decomposition shader with partial pivoting.
 *
 * Computes PA = LU where P is a permutation matrix, L is lower triangular
 * with unit diagonal, and U is upper triangular. One workgroup processes one
 * m x n matrix of the batch.
 *
 * For each column j < min(m, n):
 * 1. Find pivot row (max absolute value in column j, rows >= j)
 * 2. Swap rows j and pivot row
 * 3. Compute L[i][j] = A[i][j] / A[j][j] for i > j
 * 4. Update submatrix: A[i][k] -= L[i][j] * A[j][k] for i > j, k > j
 *
 * Outputs:
 * - `lu`: L (strictly below the diagonal, unit diagonal implicit) and U
 *   (on and above the diagonal) packed into one m x n matrix.
 * - `pivots`: min(m, n) entries; entry j is the 0-based row swapped with row j.
 * - `perm`: m entries; the resulting row permutation, starting from identity.
 * A zero pivot (singular column) is not detected; the division by it then
 * yields non-finite or indeterminate values.
 *
 * @param device - WebGPU device; only `limits.maxComputeWorkgroupSizeX` is read.
 * @param type - Routine type; `inputDtypes[0]` is the element dtype and
 *   `inputShapes[0]` has trailing dims `[m, n]`.
 * @returns A one-element array holding a single-pass shader with one input
 *   and three outputs (lu, pivots, perm).
 */
function createLU(device, type) {
  const dtype = type.inputDtypes[0];
  const shape = type.inputShapes[0];
  const m = shape[shape.length - 2];
  const n = shape[shape.length - 1];
  const r = Math.min(m, n);
  const batches = prod(shape.slice(0, -2));
  const needsF16 = dtype === DType.Float16;
  const ty = dtypeToWgsl(dtype, true);
  const workgroupSize = findPow2(Math.max(m, n), device.limits.maxComputeWorkgroupSizeX);
  const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

@group(0) @binding(0) var<storage, read> input: array<${ty}>;
@group(0) @binding(1) var<storage, read_write> lu: array<${ty}>;
@group(0) @binding(2) var<storage, read_write> pivots: array<i32>;
@group(0) @binding(3) var<storage, read_write> perm: array<i32>;

var<workgroup> pivot_row: u32;
var<workgroup> pivot_val: ${ty};

@compute @workgroup_size(${workgroupSize})
fn main(
  @builtin(workgroup_id) wg_id: vec3<u32>,
  @builtin(local_invocation_id) local_id: vec3<u32>,
) {
  let batch = wg_id.x + wg_id.y * ${gridOffsetY}u;
  if (batch >= ${batches}u) {
    return;
  }

  let lu_base = batch * ${m * n}u;
  let piv_base = batch * ${r}u;
  let perm_base = batch * ${m}u;
  let tid = local_id.x;

  // Copy input to lu
  for (var idx = tid; idx < ${m * n}u; idx += ${workgroupSize}u) {
    lu[lu_base + idx] = input[lu_base + idx];
  }
  // Initialize permutation
  for (var idx = tid; idx < ${m}u; idx += ${workgroupSize}u) {
    perm[perm_base + idx] = i32(idx);
  }
  storageBarrier();

  // LU decomposition with partial pivoting
  for (var j = 0u; j < ${r}u; j++) {
    // Step 1: Thread 0 finds pivot (max abs value in column j, rows >= j)
    if (tid == 0u) {
      var max_val = abs(lu[lu_base + j * ${n}u + j]);
      var max_row = j;
      for (var i = j + 1u; i < ${m}u; i++) {
        let val = abs(lu[lu_base + i * ${n}u + j]);
        if (val > max_val) {
          max_val = val;
          max_row = i;
        }
      }
      pivot_row = max_row;
      pivot_val = lu[lu_base + max_row * ${n}u + j];
      pivots[piv_base + j] = i32(max_row);
    }
    workgroupBarrier();

    // Step 2: Swap rows j and pivot_row (threads collaborate)
    let pr = pivot_row;
    if (pr != j) {
      for (var col = tid; col < ${n}u; col += ${workgroupSize}u) {
        let tmp = lu[lu_base + j * ${n}u + col];
        lu[lu_base + j * ${n}u + col] = lu[lu_base + pr * ${n}u + col];
        lu[lu_base + pr * ${n}u + col] = tmp;
      }
      if (tid == 0u) {
        let tmp_p = perm[perm_base + j];
        perm[perm_base + j] = perm[perm_base + pr];
        perm[perm_base + pr] = tmp_p;
      }
    }
    storageBarrier();

    // Step 3: Compute L[i][j] and update submatrix
    // Each thread handles one row i > j
    for (var i = j + 1u + tid; i < ${m}u; i += ${workgroupSize}u) {
      let factor = lu[lu_base + i * ${n}u + j] / pivot_val;
      lu[lu_base + i * ${n}u + j] = factor; // L[i][j]
      for (var k = j + 1u; k < ${n}u; k++) {
        lu[lu_base + i * ${n}u + k] -= factor * lu[lu_base + j * ${n}u + k];
      }
    }
    // pivot_row and pivot_val are rewritten by thread 0 in the next
    // iteration; a workgroup barrier orders the reads above before that
    // write, since storage barriers only synchronize the storage address space.
    workgroupBarrier();
    storageBarrier();
  }
}
`.trim();
  const grid = calculateGrid(batches);
  return [{
    code,
    numInputs: 1,
    numOutputs: 3,
    hasUniform: false,
    passes: [{ grid }]
  }];
}
417
705
/**
 * Build the WebGPU shader(s) for a routine by dispatching on `routine.name`
 * to the matching generator function.
 *
 * @param device - WebGPU device, forwarded unchanged to the generator.
 * @param routine - Routine descriptor; `name` selects the generator, `type`
 *   is forwarded, and `params` is forwarded only for triangular solve.
 * @returns Whatever the selected generator returns.
 * @throws {UnsupportedRoutineError} When the routine has no WebGPU generator.
 */
function createRoutineShader(device, routine) {
  const kind = routine.name;
  if (kind === Routines.Sort) return createSort(device, routine.type);
  if (kind === Routines.Argsort) return createArgsort(device, routine.type);
  if (kind === Routines.TriangularSolve) {
    return createTriangularSolve(device, routine.type, routine.params);
  }
  if (kind === Routines.Cholesky) return createCholesky(device, routine.type);
  if (kind === Routines.LU) return createLU(device, routine.type);
  throw new UnsupportedRoutineError(kind, "webgpu");
}
@@ -675,8 +966,10 @@ function pipelineSource(device, kernel) {
675
966
  else source = `(${a} * ${b})`;
676
967
  else if (op === AluOp.Idiv) source = isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
677
968
  else if (op === AluOp.Mod) source = `(${a} % ${b})`;
678
- else if (op === AluOp.Min) source = `min(${strip1(a)}, ${strip1(b)})`;
679
- else if (op === AluOp.Max) source = `max(${strip1(a)}, ${strip1(b)})`;
969
+ else if (op === AluOp.Min) if (dtype === DType.Bool) source = `(${a} && ${b})`;
970
+ else source = `min(${strip1(a)}, ${strip1(b)})`;
971
+ else if (op === AluOp.Max) if (dtype === DType.Bool) source = `(${a} || ${b})`;
972
+ else source = `max(${strip1(a)}, ${strip1(b)})`;
680
973
  else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
681
974
  else if (op === AluOp.Cmpne) if (isFloatDtype(src[0].dtype)) {
682
975
  const x = isGensym(a) ? a : gensym();
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.1.4",
3
+ "version": "0.1.5",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",