@holoscript/engine 6.0.3 → 6.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/AutoMesher-CK47F6AV.js +17 -0
- package/dist/GPUBuffers-2LHBCD7X.js +9 -0
- package/dist/WebGPUContext-TNEUYU2Y.js +11 -0
- package/dist/animation/index.cjs +38 -38
- package/dist/animation/index.d.cts +1 -1
- package/dist/animation/index.d.ts +1 -1
- package/dist/animation/index.js +1 -1
- package/dist/audio/index.cjs +16 -6
- package/dist/audio/index.d.cts +1 -1
- package/dist/audio/index.d.ts +1 -1
- package/dist/audio/index.js +1 -1
- package/dist/camera/index.cjs +23 -23
- package/dist/camera/index.d.cts +1 -1
- package/dist/camera/index.d.ts +1 -1
- package/dist/camera/index.js +1 -1
- package/dist/character/index.cjs +6 -4
- package/dist/character/index.js +1 -1
- package/dist/choreography/index.cjs +1194 -0
- package/dist/choreography/index.d.cts +687 -0
- package/dist/choreography/index.d.ts +687 -0
- package/dist/choreography/index.js +1156 -0
- package/dist/chunk-2CSNRI2N.js +217 -0
- package/dist/chunk-33T2WINR.js +266 -0
- package/dist/chunk-35R73OFM.js +1257 -0
- package/dist/chunk-4MMDSUNP.js +1256 -0
- package/dist/chunk-5V6HOU72.js +319 -0
- package/dist/chunk-6QOP6PYF.js +1038 -0
- package/dist/chunk-7KMJVHIL.js +8944 -0
- package/dist/chunk-7VPUC62U.js +1106 -0
- package/dist/chunk-A2Y6RCAT.js +1878 -0
- package/dist/chunk-AHM42MK6.js +8944 -0
- package/dist/chunk-BL7IDTHE.js +218 -0
- package/dist/chunk-CITOMSWL.js +10462 -0
- package/dist/chunk-CXDPKW2K.js +8944 -0
- package/dist/chunk-CXZPLD4S.js +223 -0
- package/dist/chunk-CZYJE7IH.js +5169 -0
- package/dist/chunk-D2OP7YC7.js +6325 -0
- package/dist/chunk-EDRVQHUU.js +1544 -0
- package/dist/chunk-EJSLOOW2.js +3589 -0
- package/dist/chunk-F53SFGW5.js +1878 -0
- package/dist/chunk-HCFPELPY.js +919 -0
- package/dist/chunk-HNEE36PY.js +93 -0
- package/dist/chunk-HYXNV36F.js +1256 -0
- package/dist/chunk-IB7KHVFY.js +821 -0
- package/dist/chunk-IBBO7YYG.js +690 -0
- package/dist/chunk-ILIBGINU.js +5470 -0
- package/dist/chunk-IS4MHLKN.js +5479 -0
- package/dist/chunk-JT2PFKWD.js +5479 -0
- package/dist/chunk-K4CUB4NY.js +1038 -0
- package/dist/chunk-KATDQXRJ.js +10462 -0
- package/dist/chunk-KBQE6ZFJ.js +8944 -0
- package/dist/chunk-KBVD5K7E.js +560 -0
- package/dist/chunk-KCDPVQRY.js +4088 -0
- package/dist/chunk-KN4QJPKN.js +8944 -0
- package/dist/chunk-KWJ3ROSI.js +8944 -0
- package/dist/chunk-L45VF6DD.js +919 -0
- package/dist/chunk-LY4T37YK.js +307 -0
- package/dist/chunk-MDN5WZXA.js +1544 -0
- package/dist/chunk-MGCDP6VU.js +928 -0
- package/dist/chunk-NCX7X6G2.js +8681 -0
- package/dist/chunk-OF54BPVD.js +913 -0
- package/dist/chunk-OWSN2Q3Q.js +690 -0
- package/dist/chunk-PRRB5TTA.js +406 -0
- package/dist/chunk-PXWVQF76.js +4086 -0
- package/dist/chunk-PYCOIDT2.js +812 -0
- package/dist/chunk-PZCSADOV.js +928 -0
- package/dist/chunk-Q2XBVS2K.js +1038 -0
- package/dist/chunk-QDZRXWN5.js +1776 -0
- package/dist/chunk-RNWOZ6WQ.js +913 -0
- package/dist/chunk-ROLFT4CJ.js +1693 -0
- package/dist/chunk-SLTJRZ2N.js +266 -0
- package/dist/chunk-SRUS5XSU.js +4088 -0
- package/dist/chunk-TKCA3WZ5.js +5409 -0
- package/dist/chunk-TNRMXYI2.js +1650 -0
- package/dist/chunk-TQB3GJGM.js +9763 -0
- package/dist/chunk-TUFGXG6K.js +510 -0
- package/dist/chunk-U6KMTGQJ.js +632 -0
- package/dist/chunk-VMGJQST6.js +8681 -0
- package/dist/chunk-X4F4TCG4.js +5470 -0
- package/dist/chunk-ZIFROE75.js +1544 -0
- package/dist/chunk-ZIJQYHSQ.js +1204 -0
- package/dist/combat/index.cjs +4 -4
- package/dist/combat/index.d.cts +1 -1
- package/dist/combat/index.d.ts +1 -1
- package/dist/combat/index.js +1 -1
- package/dist/ecs/index.cjs +1 -1
- package/dist/ecs/index.js +1 -1
- package/dist/environment/index.cjs +14 -14
- package/dist/environment/index.d.cts +1 -1
- package/dist/environment/index.d.ts +1 -1
- package/dist/environment/index.js +1 -1
- package/dist/gpu/index.cjs +4810 -0
- package/dist/gpu/index.js +3714 -0
- package/dist/hologram/index.cjs +27 -1
- package/dist/hologram/index.js +1 -1
- package/dist/index-B2PIsAmR.d.cts +2180 -0
- package/dist/index-B2PIsAmR.d.ts +2180 -0
- package/dist/index-BHySEPX7.d.cts +2921 -0
- package/dist/index-BJV21zuy.d.cts +341 -0
- package/dist/index-BJV21zuy.d.ts +341 -0
- package/dist/index-BQutTphC.d.cts +790 -0
- package/dist/index-ByIq2XrS.d.cts +3910 -0
- package/dist/index-BysHjDSO.d.cts +224 -0
- package/dist/index-BysHjDSO.d.ts +224 -0
- package/dist/index-CKwAJGck.d.ts +455 -0
- package/dist/index-CUl3QstQ.d.cts +3006 -0
- package/dist/index-CUl3QstQ.d.ts +3006 -0
- package/dist/index-CmYtNiI-.d.cts +953 -0
- package/dist/index-CmYtNiI-.d.ts +953 -0
- package/dist/index-CnRzWxi_.d.cts +522 -0
- package/dist/index-CnRzWxi_.d.ts +522 -0
- package/dist/index-CwRWbSC7.d.ts +2921 -0
- package/dist/index-CxKIBstO.d.ts +790 -0
- package/dist/index-DJ6-R8vh.d.cts +455 -0
- package/dist/index-DQKisbcI.d.cts +4968 -0
- package/dist/index-DQKisbcI.d.ts +4968 -0
- package/dist/index-DRT2zJez.d.ts +3910 -0
- package/dist/index-DfNLiAka.d.cts +192 -0
- package/dist/index-DfNLiAka.d.ts +192 -0
- package/dist/index-nMvkoRm8.d.cts +405 -0
- package/dist/index-nMvkoRm8.d.ts +405 -0
- package/dist/index-s9yOFU37.d.cts +604 -0
- package/dist/index-s9yOFU37.d.ts +604 -0
- package/dist/index.cjs +22966 -6960
- package/dist/index.d.cts +864 -20
- package/dist/index.d.ts +864 -20
- package/dist/index.js +3062 -48
- package/dist/input/index.cjs +1 -1
- package/dist/input/index.js +1 -1
- package/dist/orbital/index.cjs +3 -3
- package/dist/orbital/index.d.cts +1 -1
- package/dist/orbital/index.d.ts +1 -1
- package/dist/orbital/index.js +1 -1
- package/dist/particles/index.cjs +16 -16
- package/dist/particles/index.d.cts +1 -1
- package/dist/particles/index.d.ts +1 -1
- package/dist/particles/index.js +1 -1
- package/dist/physics/index.cjs +2377 -21
- package/dist/physics/index.d.cts +1 -1
- package/dist/physics/index.d.ts +1 -1
- package/dist/physics/index.js +35 -1
- package/dist/postfx/index.cjs +3491 -0
- package/dist/postfx/index.js +93 -0
- package/dist/procedural/index.cjs +1 -1
- package/dist/procedural/index.js +1 -1
- package/dist/puppeteer-5VF6KDVO.js +52197 -0
- package/dist/puppeteer-IZVZ3SG4.js +52197 -0
- package/dist/rendering/index.cjs +33 -32
- package/dist/rendering/index.d.cts +1 -1
- package/dist/rendering/index.d.ts +1 -1
- package/dist/rendering/index.js +8 -6
- package/dist/runtime/index.cjs +23 -13
- package/dist/runtime/index.d.cts +1 -1
- package/dist/runtime/index.d.ts +1 -1
- package/dist/runtime/index.js +8 -6
- package/dist/runtime/protocols/index.cjs +349 -0
- package/dist/runtime/protocols/index.js +15 -0
- package/dist/scene/index.cjs +8 -8
- package/dist/scene/index.d.cts +1 -1
- package/dist/scene/index.d.ts +1 -1
- package/dist/scene/index.js +1 -1
- package/dist/shader/index.cjs +3087 -0
- package/dist/shader/index.js +3044 -0
- package/dist/simulation/index.cjs +10680 -0
- package/dist/simulation/index.d.cts +3 -0
- package/dist/simulation/index.d.ts +3 -0
- package/dist/simulation/index.js +307 -0
- package/dist/spatial/index.cjs +2443 -0
- package/dist/spatial/index.d.cts +1545 -0
- package/dist/spatial/index.d.ts +1545 -0
- package/dist/spatial/index.js +2400 -0
- package/dist/terrain/index.cjs +1 -1
- package/dist/terrain/index.d.cts +1 -1
- package/dist/terrain/index.d.ts +1 -1
- package/dist/terrain/index.js +1 -1
- package/dist/transformers.node-4NKAPD5U.js +45620 -0
- package/dist/vm/index.cjs +7 -8
- package/dist/vm/index.d.cts +1 -1
- package/dist/vm/index.d.ts +1 -1
- package/dist/vm/index.js +1 -1
- package/dist/vm-bridge/index.cjs +2 -2
- package/dist/vm-bridge/index.d.cts +2 -2
- package/dist/vm-bridge/index.d.ts +2 -2
- package/dist/vm-bridge/index.js +1 -1
- package/dist/vr/index.cjs +6 -6
- package/dist/vr/index.js +1 -1
- package/dist/world/index.cjs +3 -3
- package/dist/world/index.d.cts +1 -1
- package/dist/world/index.d.ts +1 -1
- package/dist/world/index.js +1 -1
- package/package.json +53 -21
- package/LICENSE +0 -21
|
@@ -0,0 +1,510 @@
|
|
|
1
|
+
// wgsl-raw:C:\Users\josep\Documents\GitHub\HoloScript\packages\engine\src\gpu\shaders\cg_kernels.wgsl
// WGSL source for SparseLinearSolver: one module, eight compute entry points.
// spmv_vector masks out-of-range rows instead of exiting early. WGSL requires
// workgroupBarrier() to be reached in uniform control flow, and an early exit
// that depends on global_invocation_id breaks that. The resulting compile
// error would invalidate the whole module and every pipeline built from it.
// vector_width is clamped to at least 1 so a zero width cannot make the
// per-row loop spin forever.
var cg_kernels_default = `/**
 * Conjugate Gradient Kernels - Sparse Linear Algebra on WebGPU
 *
 * Unified bind group layout:
 *   group(0): CSR matrix (SpMV only)
 *   group(1): Vectors (vec_in read, vec_out read_write)
 *   group(2): SolverArgs uniform
 *   group(3): Reduction workspace (dot/final_reduce only)
 *
 * Each entry point references only the groups it needs.
 * With layout:'auto', each pipeline gets a layout derived from
 * only the bindings its entry point actually accesses.
 */

// -- Shared Types -------------------------------------------------------

struct SolverArgs {
  num_rows: u32,
  vector_width: u32,
  n: u32,
  alpha: f32,
};

// -- Group 0: CSR Matrix ------------------------------------------------

@group(0) @binding(0) var<storage, read> csr_val: array<f32>;
@group(0) @binding(1) var<storage, read> csr_col: array<u32>;
@group(0) @binding(2) var<storage, read> csr_row: array<u32>;

// -- Group 1: Vectors ---------------------------------------------------

@group(1) @binding(0) var<storage, read> vec_in: array<f32>;
@group(1) @binding(1) var<storage, read_write> vec_out: array<f32>;

// -- Group 2: Solver Arguments ------------------------------------------

@group(2) @binding(0) var<uniform> args: SolverArgs;

// -- Group 3: Reduction Workspace ---------------------------------------

@group(3) @binding(0) var<storage, read_write> partial_sums: array<f32>;
@group(3) @binding(1) var<storage, read_write> scalar_result: array<f32>;

// =======================================================================
// 1. SpMV - CSR-Vector (multi-thread per row)
//    Assigns vector_width threads per row for irregular TET10 sparsity.
//    vector_width must be a power of two dividing 256; 0 is treated as 1.
//    Threads of out-of-range rows still reach every workgroupBarrier()
//    (barriers must be in uniform control flow); they are masked instead.
//    Uses: groups 0, 1, 2
// =======================================================================

var<workgroup> spmv_shared: array<f32, 256>;

@compute @workgroup_size(256)
fn spmv_vector(
  @builtin(global_invocation_id) global_id: vec3<u32>,
  @builtin(local_invocation_id) local_id: vec3<u32>
) {
  let tid = local_id.x;
  let gid = global_id.x;
  let threads_per_row = max(args.vector_width, 1u);
  let row = gid / threads_per_row;
  let lane = gid % threads_per_row;
  let row_active = row < args.num_rows;

  var sum: f32 = 0.0;
  if (row_active) {
    let row_start = csr_row[row];
    let row_end = csr_row[row + 1u];
    for (var i = row_start + lane; i < row_end; i = i + threads_per_row) {
      sum += csr_val[i] * vec_in[csr_col[i]];
    }
  }

  spmv_shared[tid] = sum;
  workgroupBarrier();

  for (var s = threads_per_row / 2u; s > 0u; s >>= 1u) {
    if (lane < s) {
      spmv_shared[tid] += spmv_shared[tid + s];
    }
    workgroupBarrier();
  }

  if (row_active && lane == 0u) {
    vec_out[row] = spmv_shared[tid];
  }
}

// Legacy scalar SpMV (1 thread per row, for small/regular matrices)
// Uses: groups 0, 1, 2
@compute @workgroup_size(64)
fn spmv(@builtin(global_invocation_id) global_id: vec3<u32>) {
  let row = global_id.x;
  if (row >= args.num_rows) {
    return;
  }

  let row_start = csr_row[row];
  let row_end = csr_row[row + 1];

  var sum: f32 = 0.0;
  for (var i = row_start; i < row_end; i = i + 1u) {
    sum += csr_val[i] * vec_in[csr_col[i]];
  }

  vec_out[row] = sum;
}

// =======================================================================
// 2. SAXPY: vec_out = alpha * vec_in + vec_out
//    Uses: groups 1, 2
// =======================================================================

@compute @workgroup_size(256)
fn saxpy(@builtin(global_invocation_id) global_id: vec3<u32>) {
  let idx = global_id.x;
  if (idx >= args.n) {
    return;
  }
  vec_out[idx] = args.alpha * vec_in[idx] + vec_out[idx];
}

// =======================================================================
// 3. Fused CG Update: p = r + beta * p
//    vec_in = r (read), vec_out = p (read_write), args.alpha = beta
//    Uses: groups 1, 2
// =======================================================================

@compute @workgroup_size(256)
fn p_update(@builtin(global_invocation_id) global_id: vec3<u32>) {
  let idx = global_id.x;
  if (idx >= args.n) {
    return;
  }
  vec_out[idx] = vec_in[idx] + args.alpha * vec_out[idx];
}

// =======================================================================
// 4. Vector Copy: vec_out = vec_in
//    Uses: groups 1, 2
// =======================================================================

@compute @workgroup_size(256)
fn vec_copy(@builtin(global_invocation_id) global_id: vec3<u32>) {
  let idx = global_id.x;
  if (idx >= args.n) {
    return;
  }
  vec_out[idx] = vec_in[idx];
}

// =======================================================================
// 5. Vector Zero: vec_out = 0
//    Uses: groups 1 (binding 1 only), 2
// =======================================================================

@compute @workgroup_size(256)
fn vec_zero(@builtin(global_invocation_id) global_id: vec3<u32>) {
  let idx = global_id.x;
  if (idx >= args.n) {
    return;
  }
  vec_out[idx] = 0.0;
}

// =======================================================================
// 6. Dot Product - Phase 1: per-workgroup partial sums
//    result[wg_id] = sum of vec_in[i] * vec_out[i] for this workgroup
//    Uses: groups 1, 2, 3 (binding 0)
// =======================================================================

var<workgroup> dot_shared: array<f32, 256>;

@compute @workgroup_size(256)
fn dot_product(
  @builtin(global_invocation_id) global_id: vec3<u32>,
  @builtin(local_invocation_id) local_id: vec3<u32>,
  @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
  let idx = global_id.x;
  let tid = local_id.x;

  if (idx < args.n) {
    dot_shared[tid] = vec_in[idx] * vec_out[idx];
  } else {
    dot_shared[tid] = 0.0;
  }

  workgroupBarrier();

  for (var s = 128u; s > 0u; s >>= 1u) {
    if (tid < s) {
      dot_shared[tid] += dot_shared[tid + s];
    }
    workgroupBarrier();
  }

  if (tid == 0u) {
    partial_sums[workgroup_id.x] = dot_shared[0];
  }
}

// =======================================================================
// 7. Final Reduce - Phase 2: sum partial_sums -> scalar_result[0]
//    args.n = number of partial sums to reduce
//    Uses: groups 2, 3 (bindings 0 and 1)
// =======================================================================

var<workgroup> reduce_shared: array<f32, 256>;

@compute @workgroup_size(256)
fn final_reduce(@builtin(local_invocation_id) local_id: vec3<u32>) {
  let tid = local_id.x;
  let count = args.n;

  var acc: f32 = 0.0;
  var i = tid;
  loop {
    if (i >= count) {
      break;
    }
    acc += partial_sums[i];
    i += 256u;
  }
  reduce_shared[tid] = acc;

  workgroupBarrier();

  for (var s = 128u; s > 0u; s >>= 1u) {
    if (tid < s) {
      reduce_shared[tid] += reduce_shared[tid + s];
    }
    workgroupBarrier();
  }

  if (tid == 0u) {
    scalar_result[0] = reduce_shared[0];
  }
}
`;
|
|
3
|
+
|
|
4
|
+
// src/gpu/SparseLinearSolver.ts
|
|
5
|
+
// Threads per workgroup assumed when sizing dispatches (workgroup counts) in
// the solve methods. Must match @workgroup_size(256) on the 256-wide kernels
// in cg_kernels_default (spmv_vector, saxpy, p_update, vec_copy, vec_zero,
// dot_product, final_reduce).
var WG_SIZE = 256;
|
|
6
|
+
var SparseLinearSolver = class {
|
|
7
|
+
constructor(context) {
|
|
8
|
+
this.context = context;
|
|
9
|
+
this.device = context.getDevice();
|
|
10
|
+
}
|
|
11
|
+
context;
|
|
12
|
+
device;
|
|
13
|
+
shaderModule;
|
|
14
|
+
spmvPipeline;
|
|
15
|
+
spmvVectorPipeline;
|
|
16
|
+
saxpyPipeline;
|
|
17
|
+
dotPipeline;
|
|
18
|
+
finalReducePipeline;
|
|
19
|
+
vecCopyPipeline;
|
|
20
|
+
vecZeroPipeline;
|
|
21
|
+
pUpdatePipeline;
|
|
22
|
+
initialized = false;
|
|
23
|
+
/** Compile shaders and create all compute pipelines */
|
|
24
|
+
async initialize() {
|
|
25
|
+
if (this.initialized) return;
|
|
26
|
+
this.shaderModule = this.device.createShaderModule({
|
|
27
|
+
label: "CG Kernels",
|
|
28
|
+
code: cg_kernels_default
|
|
29
|
+
});
|
|
30
|
+
const [spmv, spmvVec, saxpy, dot, finalReduce, vecCopy, vecZero, pUpdate] = await Promise.all([
|
|
31
|
+
this.device.createComputePipelineAsync({
|
|
32
|
+
label: "SpMV Scalar",
|
|
33
|
+
layout: "auto",
|
|
34
|
+
compute: { module: this.shaderModule, entryPoint: "spmv" }
|
|
35
|
+
}),
|
|
36
|
+
this.device.createComputePipelineAsync({
|
|
37
|
+
label: "SpMV Vector",
|
|
38
|
+
layout: "auto",
|
|
39
|
+
compute: { module: this.shaderModule, entryPoint: "spmv_vector" }
|
|
40
|
+
}),
|
|
41
|
+
this.device.createComputePipelineAsync({
|
|
42
|
+
label: "SAXPY",
|
|
43
|
+
layout: "auto",
|
|
44
|
+
compute: { module: this.shaderModule, entryPoint: "saxpy" }
|
|
45
|
+
}),
|
|
46
|
+
this.device.createComputePipelineAsync({
|
|
47
|
+
label: "Dot Product",
|
|
48
|
+
layout: "auto",
|
|
49
|
+
compute: { module: this.shaderModule, entryPoint: "dot_product" }
|
|
50
|
+
}),
|
|
51
|
+
this.device.createComputePipelineAsync({
|
|
52
|
+
label: "Final Reduce",
|
|
53
|
+
layout: "auto",
|
|
54
|
+
compute: { module: this.shaderModule, entryPoint: "final_reduce" }
|
|
55
|
+
}),
|
|
56
|
+
this.device.createComputePipelineAsync({
|
|
57
|
+
label: "Vec Copy",
|
|
58
|
+
layout: "auto",
|
|
59
|
+
compute: { module: this.shaderModule, entryPoint: "vec_copy" }
|
|
60
|
+
}),
|
|
61
|
+
this.device.createComputePipelineAsync({
|
|
62
|
+
label: "Vec Zero",
|
|
63
|
+
layout: "auto",
|
|
64
|
+
compute: { module: this.shaderModule, entryPoint: "vec_zero" }
|
|
65
|
+
}),
|
|
66
|
+
this.device.createComputePipelineAsync({
|
|
67
|
+
label: "P-Update",
|
|
68
|
+
layout: "auto",
|
|
69
|
+
compute: { module: this.shaderModule, entryPoint: "p_update" }
|
|
70
|
+
})
|
|
71
|
+
]);
|
|
72
|
+
this.spmvPipeline = spmv;
|
|
73
|
+
this.spmvVectorPipeline = spmvVec;
|
|
74
|
+
this.saxpyPipeline = saxpy;
|
|
75
|
+
this.dotPipeline = dot;
|
|
76
|
+
this.finalReducePipeline = finalReduce;
|
|
77
|
+
this.vecCopyPipeline = vecCopy;
|
|
78
|
+
this.vecZeroPipeline = vecZero;
|
|
79
|
+
this.pUpdatePipeline = pUpdate;
|
|
80
|
+
this.initialized = true;
|
|
81
|
+
}
|
|
82
|
+
/**
|
|
83
|
+
* Solve Ax = b using Conjugate Gradient on the GPU.
|
|
84
|
+
*
|
|
85
|
+
* Algorithm (Hestenes-Stiefel):
|
|
86
|
+
* r₀ = b - A·x₀
|
|
87
|
+
* p₀ = r₀
|
|
88
|
+
* for k = 0, 1, 2, ...
|
|
89
|
+
* Ap = A·p
|
|
90
|
+
* α = (r·r) / (p·Ap)
|
|
91
|
+
* x = x + α·p
|
|
92
|
+
* r = r - α·Ap
|
|
93
|
+
* if ||r||² < tol: break
|
|
94
|
+
* β = (r_new·r_new) / (r·r)
|
|
95
|
+
* p = r + β·p ← fused kernel
|
|
96
|
+
*/
|
|
97
|
+
async solveCG(A, b, xGuess, options = {}) {
|
|
98
|
+
if (!this.initialized) {
|
|
99
|
+
throw new Error("SparseLinearSolver not initialized. Call initialize() first.");
|
|
100
|
+
}
|
|
101
|
+
const {
|
|
102
|
+
maxIterations = 1e3,
|
|
103
|
+
toleranceSq = 1e-10,
|
|
104
|
+
convergenceCheckInterval = 50,
|
|
105
|
+
onProgress
|
|
106
|
+
} = options;
|
|
107
|
+
const n = A.num_rows;
|
|
108
|
+
const vectorWidth = 16;
|
|
109
|
+
const numWgSpmvVec = Math.ceil(n * vectorWidth / WG_SIZE);
|
|
110
|
+
const numWgVec = Math.ceil(n / WG_SIZE);
|
|
111
|
+
const numWgDot = Math.ceil(n / WG_SIZE);
|
|
112
|
+
const csrVal = this.uploadStorage(A.val, "csr-val");
|
|
113
|
+
const csrCol = this.uploadStorage(A.col_ind, "csr-col");
|
|
114
|
+
const csrRow = this.uploadStorage(A.row_ptr, "csr-row");
|
|
115
|
+
const bufB = this.uploadStorage(b, "vec-b");
|
|
116
|
+
const bufX = this.uploadStorage(xGuess, "vec-x");
|
|
117
|
+
const bufR = this.emptyVec(n, "vec-r");
|
|
118
|
+
const bufP = this.emptyVec(n, "vec-p");
|
|
119
|
+
const bufAp = this.emptyVec(n, "vec-Ap");
|
|
120
|
+
const bufPartials = this.emptyVec(numWgDot, "partial-sums");
|
|
121
|
+
const bufScalar = this.emptyVec(1, "scalar-result");
|
|
122
|
+
const bufStaging = this.device.createBuffer({
|
|
123
|
+
label: "staging",
|
|
124
|
+
size: 4,
|
|
125
|
+
usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST
|
|
126
|
+
});
|
|
127
|
+
const bufArgs = this.device.createBuffer({
|
|
128
|
+
label: "solver-args",
|
|
129
|
+
size: 16,
|
|
130
|
+
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
|
|
131
|
+
});
|
|
132
|
+
const allBuffers = [csrVal, csrCol, csrRow, bufB, bufX, bufR, bufP, bufAp, bufPartials, bufScalar, bufStaging, bufArgs];
|
|
133
|
+
{
|
|
134
|
+
this.writeArgs(bufArgs, n, vectorWidth, n, 0);
|
|
135
|
+
const enc = this.device.createCommandEncoder({ label: "init-spmv" });
|
|
136
|
+
this.dispatchSpmv(enc, csrVal, csrCol, csrRow, bufX, bufAp, bufArgs, numWgSpmvVec, true);
|
|
137
|
+
this.device.queue.submit([enc.finish()]);
|
|
138
|
+
await this.device.queue.onSubmittedWorkDone();
|
|
139
|
+
}
|
|
140
|
+
{
|
|
141
|
+
const enc = this.device.createCommandEncoder({ label: "init-residual" });
|
|
142
|
+
enc.copyBufferToBuffer(bufB, 0, bufR, 0, n * 4);
|
|
143
|
+
this.device.queue.submit([enc.finish()]);
|
|
144
|
+
await this.device.queue.onSubmittedWorkDone();
|
|
145
|
+
}
|
|
146
|
+
{
|
|
147
|
+
this.writeArgs(bufArgs, n, vectorWidth, n, -1);
|
|
148
|
+
const enc = this.device.createCommandEncoder({ label: "init-saxpy" });
|
|
149
|
+
this.dispatchSaxpy(enc, bufAp, bufR, bufArgs, numWgVec);
|
|
150
|
+
this.device.queue.submit([enc.finish()]);
|
|
151
|
+
await this.device.queue.onSubmittedWorkDone();
|
|
152
|
+
}
|
|
153
|
+
{
|
|
154
|
+
this.writeArgs(bufArgs, n, vectorWidth, n, 0);
|
|
155
|
+
const enc = this.device.createCommandEncoder({ label: "init-copy-p" });
|
|
156
|
+
this.dispatchVecCopy(enc, bufR, bufP, bufArgs, numWgVec);
|
|
157
|
+
this.device.queue.submit([enc.finish()]);
|
|
158
|
+
await this.device.queue.onSubmittedWorkDone();
|
|
159
|
+
}
|
|
160
|
+
let rDotR = await this.dotProduct(bufR, bufR, bufPartials, bufScalar, bufStaging, bufArgs, n, numWgDot);
|
|
161
|
+
if (rDotR < toleranceSq) {
|
|
162
|
+
const x = await this.readback(bufX, n);
|
|
163
|
+
this.cleanup(allBuffers);
|
|
164
|
+
return { x, iterations: 0, residualNormSq: rDotR, converged: true };
|
|
165
|
+
}
|
|
166
|
+
let iteration = 0;
|
|
167
|
+
let converged = false;
|
|
168
|
+
for (iteration = 0; iteration < maxIterations; iteration++) {
|
|
169
|
+
{
|
|
170
|
+
this.writeArgs(bufArgs, n, vectorWidth, n, 0);
|
|
171
|
+
const enc = this.device.createCommandEncoder();
|
|
172
|
+
this.dispatchSpmv(enc, csrVal, csrCol, csrRow, bufP, bufAp, bufArgs, numWgSpmvVec, true);
|
|
173
|
+
this.device.queue.submit([enc.finish()]);
|
|
174
|
+
}
|
|
175
|
+
const pAp = await this.dotProduct(bufP, bufAp, bufPartials, bufScalar, bufStaging, bufArgs, n, numWgDot);
|
|
176
|
+
if (Math.abs(pAp) < 1e-30) {
|
|
177
|
+
converged = rDotR < toleranceSq;
|
|
178
|
+
break;
|
|
179
|
+
}
|
|
180
|
+
const alpha = rDotR / pAp;
|
|
181
|
+
{
|
|
182
|
+
this.writeArgs(bufArgs, n, vectorWidth, n, alpha);
|
|
183
|
+
const enc = this.device.createCommandEncoder();
|
|
184
|
+
this.dispatchSaxpy(enc, bufP, bufX, bufArgs, numWgVec);
|
|
185
|
+
this.device.queue.submit([enc.finish()]);
|
|
186
|
+
await this.device.queue.onSubmittedWorkDone();
|
|
187
|
+
}
|
|
188
|
+
{
|
|
189
|
+
this.writeArgs(bufArgs, n, vectorWidth, n, -alpha);
|
|
190
|
+
const enc = this.device.createCommandEncoder();
|
|
191
|
+
this.dispatchSaxpy(enc, bufAp, bufR, bufArgs, numWgVec);
|
|
192
|
+
this.device.queue.submit([enc.finish()]);
|
|
193
|
+
await this.device.queue.onSubmittedWorkDone();
|
|
194
|
+
}
|
|
195
|
+
const rNewDotRNew = await this.dotProduct(bufR, bufR, bufPartials, bufScalar, bufStaging, bufArgs, n, numWgDot);
|
|
196
|
+
if (rNewDotRNew < toleranceSq) {
|
|
197
|
+
rDotR = rNewDotRNew;
|
|
198
|
+
converged = true;
|
|
199
|
+
iteration++;
|
|
200
|
+
onProgress?.(iteration, rNewDotRNew);
|
|
201
|
+
break;
|
|
202
|
+
}
|
|
203
|
+
if (onProgress && iteration % convergenceCheckInterval === 0) {
|
|
204
|
+
onProgress(iteration, rNewDotRNew);
|
|
205
|
+
}
|
|
206
|
+
const beta = rNewDotRNew / rDotR;
|
|
207
|
+
{
|
|
208
|
+
this.writeArgs(bufArgs, n, vectorWidth, n, beta);
|
|
209
|
+
const enc = this.device.createCommandEncoder();
|
|
210
|
+
this.dispatchPUpdate(enc, bufR, bufP, bufArgs, numWgVec);
|
|
211
|
+
this.device.queue.submit([enc.finish()]);
|
|
212
|
+
await this.device.queue.onSubmittedWorkDone();
|
|
213
|
+
}
|
|
214
|
+
rDotR = rNewDotRNew;
|
|
215
|
+
}
|
|
216
|
+
const solution = await this.readback(bufX, n);
|
|
217
|
+
this.cleanup(allBuffers);
|
|
218
|
+
return { x: solution, iterations: iteration, residualNormSq: rDotR, converged };
|
|
219
|
+
}
|
|
220
|
+
/**
|
|
221
|
+
* solveCGDirect — Direct GPU-to-GPU Conjugate Gradient solve.
|
|
222
|
+
*
|
|
223
|
+
* Same as solveCG but avoids CPU readback of the solution vector.
|
|
224
|
+
* Returns the live GPUBuffer containing the result.
|
|
225
|
+
*
|
|
226
|
+
* @warning Caller is responsible for destroying the returned xBuffer.
|
|
227
|
+
*/
|
|
228
|
+
async solveCGDirect(A, b, x0, options = {}) {
|
|
229
|
+
const n = A.num_rows;
|
|
230
|
+
const maxIterations = options.maxIterations ?? 1e3;
|
|
231
|
+
const toleranceSq = options.toleranceSq ?? 1e-10;
|
|
232
|
+
const xExtraUsage = options.xExtraUsage ?? 0;
|
|
233
|
+
const valBuffer = this.uploadStorage(A.val, "val");
|
|
234
|
+
const colIndBuffer = this.uploadStorage(new Uint32Array(A.col_ind), "col_ind");
|
|
235
|
+
const rowPtrBuffer = this.uploadStorage(new Uint32Array(A.row_ptr), "row_ptr");
|
|
236
|
+
const bBuffer = this.uploadStorage(b, "b");
|
|
237
|
+
const xBuffer = this.uploadStorage(x0, "x", xExtraUsage);
|
|
238
|
+
const rBuffer = this.emptyVec(n, "r");
|
|
239
|
+
const pBuffer = this.emptyVec(n, "p");
|
|
240
|
+
const ApBuffer = this.emptyVec(n, "Ap");
|
|
241
|
+
const rDotRBuffer = this.emptyVec(1, "rDotR");
|
|
242
|
+
const rDotRStagingBuffer = this.device.createBuffer({
|
|
243
|
+
size: 4,
|
|
244
|
+
usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST
|
|
245
|
+
});
|
|
246
|
+
const numWgVec = Math.ceil(n / WG_SIZE);
|
|
247
|
+
const numWgDot = Math.ceil(n / WG_SIZE);
|
|
248
|
+
const bufArgs = this.device.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST });
|
|
249
|
+
const partials = this.emptyVec(numWgDot, "partials");
|
|
250
|
+
{
|
|
251
|
+
const enc = this.device.createCommandEncoder();
|
|
252
|
+
this.dispatchVecCopy(enc, bBuffer, rBuffer, bufArgs, numWgVec);
|
|
253
|
+
this.dispatchSpmv(enc, valBuffer, colIndBuffer, rowPtrBuffer, xBuffer, ApBuffer, bufArgs, numWgVec, true);
|
|
254
|
+
this.device.queue.submit([enc.finish()]);
|
|
255
|
+
}
|
|
256
|
+
{
|
|
257
|
+
this.writeArgs(bufArgs, n, 0, n, -1);
|
|
258
|
+
const enc = this.device.createCommandEncoder();
|
|
259
|
+
this.dispatchSaxpy(enc, ApBuffer, rBuffer, bufArgs, numWgVec);
|
|
260
|
+
this.device.queue.submit([enc.finish()]);
|
|
261
|
+
}
|
|
262
|
+
{
|
|
263
|
+
const enc = this.device.createCommandEncoder();
|
|
264
|
+
this.dispatchVecCopy(enc, rBuffer, pBuffer, bufArgs, numWgVec);
|
|
265
|
+
this.device.queue.submit([enc.finish()]);
|
|
266
|
+
}
|
|
267
|
+
let iteration = 0;
|
|
268
|
+
let converged = false;
|
|
269
|
+
let rDotR = await this.dotProduct(rBuffer, rBuffer, partials, rDotRBuffer, rDotRStagingBuffer, bufArgs, n, numWgDot);
|
|
270
|
+
for (iteration = 0; iteration < maxIterations; iteration++) {
|
|
271
|
+
if (rDotR < toleranceSq) {
|
|
272
|
+
converged = true;
|
|
273
|
+
break;
|
|
274
|
+
}
|
|
275
|
+
{
|
|
276
|
+
const enc = this.device.createCommandEncoder();
|
|
277
|
+
this.dispatchSpmv(enc, valBuffer, colIndBuffer, rowPtrBuffer, pBuffer, ApBuffer, bufArgs, numWgVec, true);
|
|
278
|
+
this.device.queue.submit([enc.finish()]);
|
|
279
|
+
}
|
|
280
|
+
const pAp = await this.dotProduct(pBuffer, ApBuffer, partials, rDotRBuffer, rDotRStagingBuffer, bufArgs, n, numWgDot);
|
|
281
|
+
const alpha = rDotR / (pAp + 1e-20);
|
|
282
|
+
{
|
|
283
|
+
this.writeArgs(bufArgs, n, 0, n, alpha);
|
|
284
|
+
const enc = this.device.createCommandEncoder();
|
|
285
|
+
this.dispatchSaxpy(enc, pBuffer, xBuffer, bufArgs, numWgVec);
|
|
286
|
+
this.device.queue.submit([enc.finish()]);
|
|
287
|
+
}
|
|
288
|
+
{
|
|
289
|
+
this.writeArgs(bufArgs, n, 0, n, -alpha);
|
|
290
|
+
const enc = this.device.createCommandEncoder();
|
|
291
|
+
this.dispatchSaxpy(enc, ApBuffer, rBuffer, bufArgs, numWgVec);
|
|
292
|
+
this.device.queue.submit([enc.finish()]);
|
|
293
|
+
}
|
|
294
|
+
const oldRDotR = rDotR;
|
|
295
|
+
rDotR = await this.dotProduct(rBuffer, rBuffer, partials, rDotRBuffer, rDotRStagingBuffer, bufArgs, n, numWgDot);
|
|
296
|
+
const beta = rDotR / (oldRDotR + 1e-20);
|
|
297
|
+
{
|
|
298
|
+
this.writeArgs(bufArgs, n, 0, n, beta);
|
|
299
|
+
const enc = this.device.createCommandEncoder();
|
|
300
|
+
this.dispatchPUpdate(enc, rBuffer, pBuffer, bufArgs, numWgVec);
|
|
301
|
+
this.device.queue.submit([enc.finish()]);
|
|
302
|
+
}
|
|
303
|
+
}
|
|
304
|
+
this.cleanup([
|
|
305
|
+
valBuffer,
|
|
306
|
+
colIndBuffer,
|
|
307
|
+
rowPtrBuffer,
|
|
308
|
+
bBuffer,
|
|
309
|
+
rBuffer,
|
|
310
|
+
pBuffer,
|
|
311
|
+
ApBuffer,
|
|
312
|
+
rDotRBuffer,
|
|
313
|
+
rDotRStagingBuffer
|
|
314
|
+
]);
|
|
315
|
+
return { xBuffer, iterations: iteration, residualNormSq: rDotR, converged };
|
|
316
|
+
}
|
|
317
|
+
// ═══════════════════════════════════════════════════════════════════
|
|
318
|
+
// Dispatch helpers — each sets the bind groups its entry point needs
|
|
319
|
+
// ═══════════════════════════════════════════════════════════════════
|
|
320
|
+
/** SpMV: groups 0 (CSR), 1 (vecs), 2 (args) */
|
|
321
|
+
dispatchSpmv(enc, val, col, row, x, y, args, numWgs, useVector) {
|
|
322
|
+
const pipeline = useVector ? this.spmvVectorPipeline : this.spmvPipeline;
|
|
323
|
+
const pass = enc.beginComputePass({ label: "spmv" });
|
|
324
|
+
pass.setPipeline(pipeline);
|
|
325
|
+
pass.setBindGroup(0, this.device.createBindGroup({
|
|
326
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
327
|
+
entries: [
|
|
328
|
+
{ binding: 0, resource: { buffer: val } },
|
|
329
|
+
{ binding: 1, resource: { buffer: col } },
|
|
330
|
+
{ binding: 2, resource: { buffer: row } }
|
|
331
|
+
]
|
|
332
|
+
}));
|
|
333
|
+
pass.setBindGroup(1, this.device.createBindGroup({
|
|
334
|
+
layout: pipeline.getBindGroupLayout(1),
|
|
335
|
+
entries: [
|
|
336
|
+
{ binding: 0, resource: { buffer: x } },
|
|
337
|
+
{ binding: 1, resource: { buffer: y } }
|
|
338
|
+
]
|
|
339
|
+
}));
|
|
340
|
+
pass.setBindGroup(2, this.device.createBindGroup({
|
|
341
|
+
layout: pipeline.getBindGroupLayout(2),
|
|
342
|
+
entries: [{ binding: 0, resource: { buffer: args } }]
|
|
343
|
+
}));
|
|
344
|
+
pass.dispatchWorkgroups(numWgs);
|
|
345
|
+
pass.end();
|
|
346
|
+
}
|
|
347
|
+
/** SAXPY: groups 1 (vecs), 2 (args) */
|
|
348
|
+
dispatchSaxpy(enc, x, y, args, numWgs) {
|
|
349
|
+
const pass = enc.beginComputePass({ label: "saxpy" });
|
|
350
|
+
pass.setPipeline(this.saxpyPipeline);
|
|
351
|
+
pass.setBindGroup(1, this.device.createBindGroup({
|
|
352
|
+
layout: this.saxpyPipeline.getBindGroupLayout(1),
|
|
353
|
+
entries: [
|
|
354
|
+
{ binding: 0, resource: { buffer: x } },
|
|
355
|
+
{ binding: 1, resource: { buffer: y } }
|
|
356
|
+
]
|
|
357
|
+
}));
|
|
358
|
+
pass.setBindGroup(2, this.device.createBindGroup({
|
|
359
|
+
layout: this.saxpyPipeline.getBindGroupLayout(2),
|
|
360
|
+
entries: [{ binding: 0, resource: { buffer: args } }]
|
|
361
|
+
}));
|
|
362
|
+
pass.dispatchWorkgroups(numWgs);
|
|
363
|
+
pass.end();
|
|
364
|
+
}
|
|
365
|
+
/** Fused p = r + beta*p: groups 1 (vecs), 2 (args) */
|
|
366
|
+
dispatchPUpdate(enc, r, p, args, numWgs) {
|
|
367
|
+
const pass = enc.beginComputePass({ label: "p-update" });
|
|
368
|
+
pass.setPipeline(this.pUpdatePipeline);
|
|
369
|
+
pass.setBindGroup(1, this.device.createBindGroup({
|
|
370
|
+
layout: this.pUpdatePipeline.getBindGroupLayout(1),
|
|
371
|
+
entries: [
|
|
372
|
+
{ binding: 0, resource: { buffer: r } },
|
|
373
|
+
{ binding: 1, resource: { buffer: p } }
|
|
374
|
+
]
|
|
375
|
+
}));
|
|
376
|
+
pass.setBindGroup(2, this.device.createBindGroup({
|
|
377
|
+
layout: this.pUpdatePipeline.getBindGroupLayout(2),
|
|
378
|
+
entries: [{ binding: 0, resource: { buffer: args } }]
|
|
379
|
+
}));
|
|
380
|
+
pass.dispatchWorkgroups(numWgs);
|
|
381
|
+
pass.end();
|
|
382
|
+
}
|
|
383
|
+
/** Vec copy: groups 1 (vecs), 2 (args) */
|
|
384
|
+
dispatchVecCopy(enc, src, dst, args, numWgs) {
|
|
385
|
+
const pass = enc.beginComputePass({ label: "vec-copy" });
|
|
386
|
+
pass.setPipeline(this.vecCopyPipeline);
|
|
387
|
+
pass.setBindGroup(1, this.device.createBindGroup({
|
|
388
|
+
layout: this.vecCopyPipeline.getBindGroupLayout(1),
|
|
389
|
+
entries: [
|
|
390
|
+
{ binding: 0, resource: { buffer: src } },
|
|
391
|
+
{ binding: 1, resource: { buffer: dst } }
|
|
392
|
+
]
|
|
393
|
+
}));
|
|
394
|
+
pass.setBindGroup(2, this.device.createBindGroup({
|
|
395
|
+
layout: this.vecCopyPipeline.getBindGroupLayout(2),
|
|
396
|
+
entries: [{ binding: 0, resource: { buffer: args } }]
|
|
397
|
+
}));
|
|
398
|
+
pass.dispatchWorkgroups(numWgs);
|
|
399
|
+
pass.end();
|
|
400
|
+
}
|
|
401
|
+
/**
|
|
402
|
+
* Full dot product: v1·v2
|
|
403
|
+
* Phase 1: dot_product kernel → partial_sums (per-workgroup)
|
|
404
|
+
* Phase 2: final_reduce → scalar_result[0]
|
|
405
|
+
* Readback: staging mapAsync → CPU f32
|
|
406
|
+
*/
|
|
407
|
+
async dotProduct(v1, v2, partials, scalar, staging, args, n, numWgDot) {
|
|
408
|
+
{
|
|
409
|
+
this.writeArgs(args, n, 0, n, 0);
|
|
410
|
+
const enc = this.device.createCommandEncoder({ label: "dot-phase1" });
|
|
411
|
+
const pass = enc.beginComputePass();
|
|
412
|
+
pass.setPipeline(this.dotPipeline);
|
|
413
|
+
pass.setBindGroup(1, this.device.createBindGroup({
|
|
414
|
+
layout: this.dotPipeline.getBindGroupLayout(1),
|
|
415
|
+
entries: [
|
|
416
|
+
{ binding: 0, resource: { buffer: v1 } },
|
|
417
|
+
{ binding: 1, resource: { buffer: v2 } }
|
|
418
|
+
]
|
|
419
|
+
}));
|
|
420
|
+
pass.setBindGroup(2, this.device.createBindGroup({
|
|
421
|
+
layout: this.dotPipeline.getBindGroupLayout(2),
|
|
422
|
+
entries: [{ binding: 0, resource: { buffer: args } }]
|
|
423
|
+
}));
|
|
424
|
+
pass.setBindGroup(3, this.device.createBindGroup({
|
|
425
|
+
layout: this.dotPipeline.getBindGroupLayout(3),
|
|
426
|
+
entries: [{ binding: 0, resource: { buffer: partials } }]
|
|
427
|
+
}));
|
|
428
|
+
pass.dispatchWorkgroups(numWgDot);
|
|
429
|
+
pass.end();
|
|
430
|
+
this.device.queue.submit([enc.finish()]);
|
|
431
|
+
}
|
|
432
|
+
{
|
|
433
|
+
this.writeArgs(args, numWgDot, 0, numWgDot, 0);
|
|
434
|
+
const enc = this.device.createCommandEncoder({ label: "dot-phase2" });
|
|
435
|
+
const pass = enc.beginComputePass();
|
|
436
|
+
pass.setPipeline(this.finalReducePipeline);
|
|
437
|
+
pass.setBindGroup(2, this.device.createBindGroup({
|
|
438
|
+
layout: this.finalReducePipeline.getBindGroupLayout(2),
|
|
439
|
+
entries: [{ binding: 0, resource: { buffer: args } }]
|
|
440
|
+
}));
|
|
441
|
+
pass.setBindGroup(3, this.device.createBindGroup({
|
|
442
|
+
layout: this.finalReducePipeline.getBindGroupLayout(3),
|
|
443
|
+
entries: [
|
|
444
|
+
{ binding: 0, resource: { buffer: partials } },
|
|
445
|
+
{ binding: 1, resource: { buffer: scalar } }
|
|
446
|
+
]
|
|
447
|
+
}));
|
|
448
|
+
pass.dispatchWorkgroups(1);
|
|
449
|
+
pass.end();
|
|
450
|
+
enc.copyBufferToBuffer(scalar, 0, staging, 0, 4);
|
|
451
|
+
this.device.queue.submit([enc.finish()]);
|
|
452
|
+
}
|
|
453
|
+
await staging.mapAsync(GPUMapMode.READ);
|
|
454
|
+
const value = new Float32Array(staging.getMappedRange())[0];
|
|
455
|
+
staging.unmap();
|
|
456
|
+
return value;
|
|
457
|
+
}
|
|
458
|
+
// ═══════════════════════════════════════════════════════════════════
|
|
459
|
+
// Buffer helpers
|
|
460
|
+
// ═══════════════════════════════════════════════════════════════════
|
|
461
|
+
writeArgs(buf, numRows, vectorWidth, n, alpha) {
|
|
462
|
+
const data = new ArrayBuffer(16);
|
|
463
|
+
new Uint32Array(data, 0, 3).set([numRows, vectorWidth, n]);
|
|
464
|
+
new Float32Array(data, 12, 1).set([alpha]);
|
|
465
|
+
this.device.queue.writeBuffer(buf, 0, data);
|
|
466
|
+
}
|
|
467
|
+
uploadStorage(data, label, extraUsage = 0) {
|
|
468
|
+
const buf = this.device.createBuffer({
|
|
469
|
+
label,
|
|
470
|
+
size: data.byteLength,
|
|
471
|
+
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | extraUsage,
|
|
472
|
+
mappedAtCreation: true
|
|
473
|
+
});
|
|
474
|
+
if (data instanceof Float32Array) new Float32Array(buf.getMappedRange()).set(data);
|
|
475
|
+
else new Uint32Array(buf.getMappedRange()).set(data);
|
|
476
|
+
buf.unmap();
|
|
477
|
+
return buf;
|
|
478
|
+
}
|
|
479
|
+
emptyVec(n, label, extraUsage = 0) {
|
|
480
|
+
return this.device.createBuffer({
|
|
481
|
+
label,
|
|
482
|
+
size: Math.max(4, n * 4),
|
|
483
|
+
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | extraUsage
|
|
484
|
+
});
|
|
485
|
+
}
|
|
486
|
+
async readback(buf, n) {
|
|
487
|
+
const staging = this.device.createBuffer({
|
|
488
|
+
size: n * 4,
|
|
489
|
+
usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST
|
|
490
|
+
});
|
|
491
|
+
const enc = this.device.createCommandEncoder();
|
|
492
|
+
enc.copyBufferToBuffer(buf, 0, staging, 0, n * 4);
|
|
493
|
+
this.device.queue.submit([enc.finish()]);
|
|
494
|
+
await staging.mapAsync(GPUMapMode.READ);
|
|
495
|
+
const result = new Float32Array(staging.getMappedRange()).slice();
|
|
496
|
+
staging.unmap();
|
|
497
|
+
staging.destroy();
|
|
498
|
+
return result;
|
|
499
|
+
}
|
|
500
|
+
cleanup(buffers) {
|
|
501
|
+
for (const b of buffers) b.destroy();
|
|
502
|
+
}
|
|
503
|
+
destroy() {
|
|
504
|
+
this.initialized = false;
|
|
505
|
+
}
|
|
506
|
+
};
|
|
507
|
+
|
|
508
|
+
export {
|
|
509
|
+
SparseLinearSolver
|
|
510
|
+
};
|