@novastera-oss/llamarn 0.3.1 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (347) hide show
  1. package/README.md +86 -3
  2. package/RNLlamaCpp.podspec +1 -1
  3. package/android/CMakeLists.txt +11 -3
  4. package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
  5. package/android/src/main/cpp/include/llama.h +53 -114
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  13. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  20. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  21. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  22. package/cpp/LlamaCppModel.cpp +2 -10
  23. package/cpp/PureCppImpl.cpp +71 -4
  24. package/cpp/SystemUtils.cpp +3 -7
  25. package/cpp/build-info.cpp +2 -2
  26. package/cpp/llama.cpp/CMakeLists.txt +2 -0
  27. package/cpp/llama.cpp/CODEOWNERS +1 -1
  28. package/cpp/llama.cpp/Makefile +6 -1605
  29. package/cpp/llama.cpp/README.md +5 -1
  30. package/cpp/llama.cpp/common/arg.cpp +230 -51
  31. package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
  32. package/cpp/llama.cpp/common/chat.cpp +539 -8
  33. package/cpp/llama.cpp/common/chat.h +8 -1
  34. package/cpp/llama.cpp/common/common.cpp +60 -15
  35. package/cpp/llama.cpp/common/common.h +64 -15
  36. package/cpp/llama.cpp/common/speculative.cpp +135 -54
  37. package/cpp/llama.cpp/common/speculative.h +8 -1
  38. package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
  39. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
  40. package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
  41. package/cpp/llama.cpp/flake.nix +0 -5
  42. package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
  43. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
  44. package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
  45. package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
  46. package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
  47. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
  51. package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
  52. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
  54. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
  55. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
  57. package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
  73. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
  74. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
  75. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
  76. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
  77. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
  80. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
  82. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
  83. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
  84. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
  85. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  100. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
  101. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
  102. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
  103. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
  104. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
  105. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
  106. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
  107. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
  108. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
  110. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
  111. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
  112. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
  113. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
  114. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
  117. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
  118. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  119. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  120. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
  121. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
  122. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
  123. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  124. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  125. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
  126. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  127. package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  128. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
  131. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
  132. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
  134. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
  135. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
  136. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
  137. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  138. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  141. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  142. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
  143. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
  144. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  145. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
  147. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
  148. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
  149. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  150. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
  152. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  153. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  154. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  161. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  164. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  166. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
  167. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
  168. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
  169. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
  170. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
  173. package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
  174. package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
  175. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
  176. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  177. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
  178. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
  179. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
  180. package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
  181. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
  182. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
  183. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  184. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  185. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  186. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  187. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
  188. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  189. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
  190. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  191. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  192. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
  193. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
  194. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
  195. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
  196. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
  197. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
  198. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
  199. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
  200. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
  201. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  202. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
  203. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
  204. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
  205. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  206. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  207. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
  208. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  209. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
  210. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  211. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  212. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  213. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
  214. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  215. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
  216. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
  217. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
  218. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
  219. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
  220. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
  221. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
  222. package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  223. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
  224. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
  225. package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
  226. package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
  227. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
  228. package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
  229. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
  230. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
  231. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
  232. package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
  233. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
  234. package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
  235. package/cpp/llama.cpp/include/llama.h +53 -114
  236. package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
  237. package/cpp/llama.cpp/models/templates/README.md +2 -1
  238. package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
  239. package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
  240. package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
  241. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
  242. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
  243. package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
  244. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  245. package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
  246. package/cpp/llama.cpp/src/llama-adapter.h +3 -0
  247. package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
  248. package/cpp/llama.cpp/src/llama-arch.h +18 -0
  249. package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
  250. package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
  251. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  252. package/cpp/llama.cpp/src/llama-context.cpp +61 -252
  253. package/cpp/llama.cpp/src/llama-context.h +10 -15
  254. package/cpp/llama.cpp/src/llama-cparams.h +0 -1
  255. package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
  256. package/cpp/llama.cpp/src/llama-graph.h +90 -51
  257. package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
  258. package/cpp/llama.cpp/src/llama-hparams.h +21 -6
  259. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
  260. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
  261. package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
  262. package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
  263. package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
  264. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
  265. package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
  266. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
  267. package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
  268. package/cpp/llama.cpp/src/llama-memory.h +13 -10
  269. package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
  270. package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
  271. package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
  272. package/cpp/llama.cpp/src/llama-model.h +28 -4
  273. package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
  274. package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
  275. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  276. package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
  277. package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
  278. package/cpp/rn-completion.cpp +3 -27
  279. package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
  280. package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
  281. package/ios/include/chat.h +8 -1
  282. package/ios/include/common/minja/chat-template.hpp +16 -7
  283. package/ios/include/common/minja/minja.hpp +47 -12
  284. package/ios/include/common.h +64 -15
  285. package/ios/include/llama.h +53 -114
  286. package/ios/include/speculative.h +8 -1
  287. package/ios/libs/llama.xcframework/Info.plist +18 -18
  288. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  289. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
  290. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  291. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
  292. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
  293. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  294. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  295. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
  296. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  297. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  298. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  299. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  300. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  301. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  302. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  303. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
  304. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
  305. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
  306. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
  307. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
  308. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
  309. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
  310. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  311. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
  312. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
  313. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
  314. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  315. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  316. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  317. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
  318. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  319. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
  320. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
  321. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  322. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  323. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  324. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  325. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  326. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  327. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  328. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  329. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  330. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
  331. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  332. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
  333. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
  334. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  335. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  336. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
  337. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
  338. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  339. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  340. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  341. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  342. package/lib/module/NativeRNLlamaCpp.js.map +1 -1
  343. package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
  344. package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
  345. package/package.json +1 -2
  346. package/src/NativeRNLlamaCpp.ts +7 -0
  347. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56
@@ -1,34 +1,44 @@
1
- #include "ggml-webgpu.h"
1
+ /*
2
+ WebGPU backend implementation.
3
+ Note: Use ClangFormat to format this file.
4
+ */
2
5
 
3
- #include <webgpu/webgpu_cpp.h>
6
+ #include "ggml-webgpu.h"
4
7
 
5
- #include "ggml-impl.h"
6
8
  #include "ggml-backend-impl.h"
7
-
9
+ #include "ggml-impl.h"
8
10
  #include "ggml-wgsl-shaders.hpp"
9
11
 
12
+ #include <webgpu/webgpu_cpp.h>
13
+
14
+ #include <condition_variable>
10
15
  #include <cstring>
11
16
  #include <iostream>
12
17
  #include <mutex>
18
+ #include <string>
13
19
  #include <vector>
14
20
 
15
21
  #ifdef GGML_WEBGPU_DEBUG
16
- #define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
22
+ # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
23
+ # define WEBGPU_DEBUG_BUF_ELEMS 32
17
24
  #else
18
- #define WEBGPU_LOG_DEBUG(msg) ((void) 0)
19
- #endif // GGML_WEBGPU_DEBUG
25
+ # define WEBGPU_LOG_DEBUG(msg) ((void) 0)
26
+ #endif // GGML_WEBGPU_DEBUG
20
27
 
21
28
  /* Constants */
22
29
 
23
- #define WEBGPU_MUL_MAT_WG_SIZE 64
24
- #define WEBGPU_MUL_MAT_PARAMS_SIZE (13 * sizeof(uint32_t)) // M, N, K, batch sizes, broadcasts
25
- #define WEBGPU_CPY_PARAMS_SIZE (15 * sizeof(uint32_t)) // strides and offsets
26
- #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
30
+ #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
31
+ #define WEBGPU_MUL_MAT_WG_SIZE 64
32
+ #define WEBGPU_NUM_PARAM_BUFS 100
33
+ #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
34
+ #define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32
35
+ #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
36
+ #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
27
37
 
28
38
  /* End Constants */
29
39
 
30
40
  // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
31
- static void * const webgpu_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
41
+ static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
32
42
 
33
43
  // Always returns the base offset of a tensor, regardless of views.
34
44
  static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
@@ -40,100 +50,175 @@ static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
40
50
 
41
51
  /* Struct definitions */
42
52
 
53
+ // Forward reference
54
+ static void ggml_webgpu_create_buffer(wgpu::Device & device,
55
+ wgpu::Buffer & buffer,
56
+ size_t size,
57
+ wgpu::BufferUsage usage,
58
+ const char * label);
59
+
60
+ struct webgpu_pool_bufs {
61
+ wgpu::Buffer host_buf;
62
+ wgpu::Buffer dev_buf;
63
+ };
64
+
65
+ // Holds a pool of parameter buffers for WebGPU operations
66
+ struct webgpu_buf_pool {
67
+ std::vector<webgpu_pool_bufs> free;
68
+
69
+ std::mutex mutex;
70
+
71
+ std::condition_variable cv;
72
+
73
+ void init(wgpu::Device device,
74
+ int num_bufs,
75
+ size_t buf_size,
76
+ wgpu::BufferUsage dev_buf_usage,
77
+ wgpu::BufferUsage host_buf_usage) {
78
+ for (int i = 0; i < num_bufs; i++) {
79
+ wgpu::Buffer host_buf;
80
+ wgpu::Buffer dev_buf;
81
+ ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
82
+ ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
83
+ free.push_back({ host_buf, dev_buf });
84
+ }
85
+ }
86
+
87
+ webgpu_pool_bufs alloc_bufs() {
88
+ std::unique_lock<std::mutex> lock(mutex);
89
+ cv.wait(lock, [this] { return !free.empty(); });
90
+ webgpu_pool_bufs bufs = free.back();
91
+ free.pop_back();
92
+ return bufs;
93
+ }
94
+
95
+ void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
96
+ std::lock_guard<std::mutex> lock(mutex);
97
+ free.insert(free.end(), bufs.begin(), bufs.end());
98
+ cv.notify_all();
99
+ }
100
+
101
+ void cleanup() {
102
+ std::lock_guard<std::mutex> lock(mutex);
103
+ for (auto & bufs : free) {
104
+ bufs.host_buf.Destroy();
105
+ bufs.dev_buf.Destroy();
106
+ }
107
+ free.clear();
108
+ }
109
+ };
110
+
43
111
  // All the base objects needed to run operations on a WebGPU device
44
112
  struct webgpu_context_struct {
45
113
  wgpu::Instance instance;
46
- wgpu::Adapter adapter;
47
- wgpu::Device device;
48
- wgpu::Queue queue;
49
- wgpu::Limits limits;
50
- wgpu::SupportedFeatures features;
114
+ wgpu::Adapter adapter;
115
+ wgpu::Device device;
116
+ wgpu::Queue queue;
117
+ wgpu::Limits limits;
51
118
 
52
- std::mutex mutex;
53
- bool device_initialized = false;
119
+ std::recursive_mutex mutex;
120
+
121
+ webgpu_buf_pool param_buf_pool;
122
+ webgpu_buf_pool set_rows_error_buf_pool;
54
123
 
55
- // pipelines and parameter buffers
56
- // TODO: reuse params buffers for different pipelines when possible
57
124
  wgpu::ComputePipeline memset_pipeline;
58
- wgpu::Buffer memset_params_dev_buf;
59
- wgpu::Buffer memset_params_host_buf;
60
- wgpu::ComputePipeline mul_mat_pipeline;
61
- wgpu::Buffer mul_mat_params_dev_buf;
62
- wgpu::Buffer mul_mat_params_host_buf;
125
+ wgpu::ComputePipeline mul_mat_pipeline[30][2];
126
+ wgpu::ComputePipeline set_rows_pipeline;
63
127
  wgpu::ComputePipeline cpy_pipeline;
64
- wgpu::Buffer cpy_params_dev_buf;
65
- wgpu::Buffer cpy_params_host_buf;
66
128
 
67
129
  size_t memset_bytes_per_thread;
68
130
 
69
131
  // Staging buffer for reading data from the GPU
70
132
  wgpu::Buffer get_tensor_staging_buf;
133
+
134
+ // Command buffers which need to be submitted
135
+ std::vector<wgpu::CommandBuffer> staged_command_bufs;
136
+
137
+ // Parameter buffers associated with the staged command buffers
138
+ std::vector<webgpu_pool_bufs> staged_param_bufs;
139
+ // Buffers associated with set_rows operations, used to store potential errors
140
+ std::vector<webgpu_pool_bufs> staged_set_row_error_bufs;
141
+
142
+ std::vector<wgpu::FutureWaitInfo> callback_futures;
143
+
144
+ #ifdef GGML_WEBGPU_DEBUG
145
+ wgpu::Buffer debug_host_buf;
146
+ wgpu::Buffer debug_dev_buf;
147
+ #endif
71
148
  };
72
149
 
73
150
  typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
74
151
 
75
152
  struct ggml_backend_webgpu_reg_context {
76
153
  webgpu_context webgpu_ctx;
77
-
78
- size_t device_count;
79
- const char * name;
154
+ size_t device_count;
155
+ const char * name;
80
156
  };
81
157
 
82
158
  struct ggml_backend_webgpu_device_context {
83
159
  webgpu_context webgpu_ctx;
84
-
85
- std::string device_name;
86
- std::string device_desc;
160
+ std::string device_name;
161
+ std::string device_desc;
87
162
  };
88
163
 
89
164
  struct ggml_backend_webgpu_context {
90
165
  webgpu_context webgpu_ctx;
91
-
92
- std::string name;
166
+ std::string name;
93
167
  };
94
168
 
95
169
  struct ggml_backend_webgpu_buffer_context {
96
170
  webgpu_context webgpu_ctx;
97
-
98
- wgpu::Buffer buffer;
171
+ wgpu::Buffer buffer;
99
172
 
100
173
  ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) :
101
- webgpu_ctx(ctx), buffer(buf) {
102
- }
174
+ webgpu_ctx(std::move(ctx)),
175
+ buffer(std::move(buf)) {}
103
176
  };
104
177
 
105
178
  /* End struct definitions */
106
179
 
107
180
  /* WebGPU object initializations */
108
181
 
109
- static void ggml_webgpu_create_pipeline(wgpu::Device &device, wgpu::ComputePipeline &pipeline, const char * shader_code, const char * label, const std::vector<wgpu::ConstantEntry> &constants = {}) {
182
+ static void ggml_webgpu_create_pipeline(wgpu::Device & device,
183
+ wgpu::ComputePipeline & pipeline,
184
+ const char * shader_code,
185
+ const char * label,
186
+ const std::vector<wgpu::ConstantEntry> & constants = {}) {
110
187
  WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()");
188
+
111
189
  wgpu::ShaderSourceWGSL shader_source;
112
190
  shader_source.code = shader_code;
191
+
113
192
  wgpu::ShaderModuleDescriptor shader_desc;
114
193
  shader_desc.nextInChain = &shader_source;
194
+
115
195
  wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
116
196
 
117
197
  wgpu::ComputePipelineDescriptor pipeline_desc;
118
- pipeline_desc.label = label;
119
- pipeline_desc.compute.module = shader_module;
120
- pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
121
- pipeline_desc.layout = nullptr; // nullptr means auto layout
198
+ pipeline_desc.label = label;
199
+ pipeline_desc.compute.module = shader_module;
200
+ pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
201
+ pipeline_desc.layout = nullptr; // nullptr means auto layout
122
202
  if (constants.size() > 0) {
123
- pipeline_desc.compute.constants = constants.data();
203
+ pipeline_desc.compute.constants = constants.data();
124
204
  pipeline_desc.compute.constantCount = constants.size();
125
205
  }
126
206
  pipeline = device.CreateComputePipeline(&pipeline_desc);
127
207
  }
128
208
 
129
- static void ggml_webgpu_create_buffer(wgpu::Device &device, wgpu::Buffer &buffer, size_t size, wgpu::BufferUsage usage, const char* label) {
209
+ static void ggml_webgpu_create_buffer(wgpu::Device & device,
210
+ wgpu::Buffer & buffer,
211
+ size_t size,
212
+ wgpu::BufferUsage usage,
213
+ const char * label) {
130
214
  WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()");
131
215
 
132
216
  wgpu::BufferDescriptor buffer_desc;
133
- buffer_desc.size = size;
134
- buffer_desc.usage = usage;
135
- buffer_desc.label = label;
217
+ buffer_desc.size = size;
218
+ buffer_desc.usage = usage;
219
+ buffer_desc.label = label;
136
220
  buffer_desc.mappedAtCreation = false;
221
+
137
222
  // TODO: error handling
138
223
  buffer = device.CreateBuffer(&buffer_desc);
139
224
  }
@@ -142,75 +227,197 @@ static void ggml_webgpu_create_buffer(wgpu::Device &device, wgpu::Buffer &buffer
142
227
 
143
228
  /** WebGPU Actions */
144
229
 
145
- static void ggml_backend_webgpu_map_buffer(webgpu_context ctx, wgpu::Buffer buffer, wgpu::MapMode mode, size_t offset, size_t size) {
146
- ctx->instance.WaitAny(buffer.MapAsync(
147
- mode, offset, size, wgpu::CallbackMode::WaitAnyOnly,
148
- [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
149
- if (status != wgpu::MapAsyncStatus::Success) {
150
- GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n", message.data);
230
+ // Wait for the queue to finish processing all submitted work
231
+ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
232
+ std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
233
+ if (ctx->callback_futures.empty()) {
234
+ // no existing callbacks, wait on queue submission
235
+ ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
236
+ wgpu::CallbackMode::AllowSpontaneous,
237
+ [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
238
+ if (status != wgpu::QueueWorkDoneStatus::Success) {
239
+ GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
240
+ }
241
+ }),
242
+ UINT64_MAX);
243
+ } else {
244
+ // existing callbacks, wait on them
245
+ ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
246
+ ctx->callback_futures.clear();
247
+ }
248
+ }
249
+
250
+ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
251
+ std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
252
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_submit_queue()");
253
+ if (ctx->staged_command_bufs.empty()) {
254
+ // Nothing to submit
255
+ return;
256
+ }
257
+ ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());
258
+
259
+ // If there are SET_ROWS operations in this submission, copy their error buffers to the host.
260
+ if (ctx->staged_set_row_error_bufs.size() > 0) {
261
+ wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
262
+ for (auto & error_bufs : ctx->staged_set_row_error_bufs) {
263
+ // Copy the error buffer to the host buffer
264
+ encoder.CopyBufferToBuffer(error_bufs.dev_buf, 0, error_bufs.host_buf, 0, error_bufs.host_buf.GetSize());
265
+ }
266
+ wgpu::CommandBuffer commands = encoder.Finish();
267
+ ctx->queue.Submit(1, &commands);
268
+ }
269
+
270
+ ctx->staged_command_bufs.clear();
271
+ std::vector<webgpu_pool_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
272
+ std::vector<webgpu_pool_bufs> staged_set_row_error_bufs = std::move(ctx->staged_set_row_error_bufs);
273
+
274
+ // Free the staged parameter buffers once the submission completes
275
+ wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
276
+ wgpu::CallbackMode::AllowSpontaneous,
277
+ [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
278
+ if (status != wgpu::QueueWorkDoneStatus::Success) {
279
+ GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
151
280
  }
152
- }),
153
- UINT64_MAX
154
- );
155
- }
156
-
157
- static void ggml_backend_webgpu_buffer_memset(webgpu_context ctx, wgpu::Buffer buf, uint32_t value, size_t offset, size_t size) {
158
- std::lock_guard<std::mutex> lock(ctx->mutex);
159
- wgpu::Device device = ctx->device;
160
-
161
- // map the host parameters buffer
162
- ggml_backend_webgpu_map_buffer(ctx, ctx->memset_params_host_buf, wgpu::MapMode::Write, 0, ctx->memset_params_host_buf.GetSize());
163
- uint32_t * params = (uint32_t *) ctx->memset_params_host_buf.GetMappedRange();
164
-
165
- params[0] = (uint32_t)offset;
166
- params[1] = (uint32_t)size;
167
- params[2] = value;
168
- ctx->memset_params_host_buf.Unmap();
169
-
170
- wgpu::BindGroupEntry entries[2];
171
- entries[0].binding = 0; // binding for the buffer to memset
172
- entries[0].buffer = buf;
173
- entries[0].offset = 0;
174
- entries[0].size = buf.GetSize();
175
- entries[1].binding = 1; // binding for the parameters
176
- entries[1].buffer = ctx->memset_params_dev_buf;
177
- entries[1].offset = 0;
178
- entries[1].size = ctx->memset_params_dev_buf.GetSize();
281
+ // Free the staged buffers
282
+ ctx->param_buf_pool.free_bufs(staged_param_bufs);
283
+ });
284
+ ctx->callback_futures.push_back({ p_f });
285
+
286
+ // Check for errrors in SET_ROWS operations
287
+ for (auto & error_bufs : staged_set_row_error_bufs) {
288
+ wgpu::Future f = error_bufs.host_buf.MapAsync(
289
+ wgpu::MapMode::Read,
290
+ 0,
291
+ error_bufs.host_buf.GetSize(),
292
+ wgpu::CallbackMode::AllowSpontaneous,
293
+ [ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
294
+ if (status != wgpu::MapAsyncStatus::Success) {
295
+ GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
296
+ } else {
297
+ const uint32_t * error_data = (const uint32_t *) error_bufs.host_buf.GetConstMappedRange();
298
+ if (*error_data) {
299
+ GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
300
+ }
301
+ // We can't unmap in here due to WebGPU reentrancy limitations.
302
+ ctx->set_rows_error_buf_pool.free_bufs({ error_bufs });
303
+ }
304
+ });
305
+ ctx->callback_futures.push_back({ f });
306
+ }
307
+ }
308
+
309
+ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
310
+ wgpu::Buffer & buffer,
311
+ wgpu::MapMode mode,
312
+ size_t offset,
313
+ size_t size) {
314
+ ctx->instance.WaitAny(buffer.MapAsync(mode,
315
+ offset,
316
+ size,
317
+ wgpu::CallbackMode::AllowSpontaneous,
318
+ [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
319
+ if (status != wgpu::MapAsyncStatus::Success) {
320
+ GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n",
321
+ message.data);
322
+ }
323
+ }),
324
+ UINT64_MAX);
325
+ }
326
+
327
#ifdef GGML_WEBGPU_DEBUG
// Shader debugging aid, since WebGPU shaders cannot print directly.
// To use:
//  1. Bind ctx->debug_dev_buf in the setup of the shader under investigation.
//  2. Add debug statements to the shader that write into that buffer.
//  3. Call this function after encoding/submitting the commands.
// It flushes the staged work and copies the debug buffer into ctx->debug_host_buf. It then maps
// the host copy and prints the first WEBGPU_DEBUG_BUF_ELEMS uint32 values.
static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
    ggml_backend_webgpu_submit_queue(ctx);

    wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
    encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
    wgpu::CommandBuffer copy_cmd = encoder.Finish();
    ctx->queue.Submit(1, &copy_cmd);

    ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
    const uint32_t * values = static_cast<const uint32_t *>(ctx->debug_host_buf.GetConstMappedRange());

    std::cout << "debug data:";
    for (size_t idx = 0; idx != WEBGPU_DEBUG_BUF_ELEMS; ++idx) {
        std::cout << " " << idx << ": " << values[idx];
    }
    std::cout << "\n";

    ctx->debug_host_buf.Unmap();
}
#endif
348
+
349
+ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & ctx,
350
+ wgpu::ComputePipeline & pipeline,
351
+ std::vector<uint32_t> params,
352
+ std::vector<wgpu::BindGroupEntry> bind_group_entries,
353
+ uint32_t wg_x,
354
+ bool submit_and_wait = false) {
355
+ webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
356
+
357
+ ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
358
+ uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
359
+ for (size_t i = 0; i < params.size(); i++) {
360
+ _params[i] = params[i];
361
+ };
362
+
363
+ params_bufs.host_buf.Unmap();
364
+
365
+ uint32_t params_bufs_binding_num = bind_group_entries.size();
366
+ bind_group_entries.push_back({ .binding = params_bufs_binding_num,
367
+ .buffer = params_bufs.dev_buf,
368
+ .offset = 0,
369
+ .size = params_bufs.dev_buf.GetSize() });
179
370
 
180
371
  wgpu::BindGroupDescriptor bind_group_desc;
181
- bind_group_desc.layout = ctx->memset_pipeline.GetBindGroupLayout(0);
182
- bind_group_desc.entryCount = 2;
183
- bind_group_desc.label = "ggml_memset";
184
- bind_group_desc.entries = entries;
185
- wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
372
+ bind_group_desc.layout = pipeline.GetBindGroupLayout(0);
373
+ bind_group_desc.entryCount = bind_group_entries.size();
374
+ bind_group_desc.entries = bind_group_entries.data();
375
+ wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);
186
376
 
187
- wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
188
- encoder.CopyBufferToBuffer(
189
- ctx->memset_params_host_buf, 0,
190
- ctx->memset_params_dev_buf, 0,
191
- ctx->memset_params_dev_buf.GetSize()
192
- );
377
+ wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
378
+ encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
193
379
  wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
194
- pass.SetPipeline(ctx->memset_pipeline);
380
+ pass.SetPipeline(pipeline);
195
381
  pass.SetBindGroup(0, bind_group);
196
- size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread;
197
- pass.DispatchWorkgroups(((size + 3) + bytes_per_wg - 1) / bytes_per_wg, 1, 1);
382
+ pass.DispatchWorkgroups(wg_x, 1, 1);
198
383
  pass.End();
199
384
  wgpu::CommandBuffer commands = encoder.Finish();
200
-
201
- ctx->queue.Submit(1, &commands);
385
+ if (submit_and_wait) {
386
+ // Submit and wait immediately
387
+ ctx->queue.Submit(1, &commands);
388
+ ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
389
+ wgpu::CallbackMode::AllowSpontaneous,
390
+ [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
391
+ if (status != wgpu::QueueWorkDoneStatus::Success) {
392
+ GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
393
+ }
394
+ ctx->param_buf_pool.free_bufs({ params_bufs });
395
+ }),
396
+ UINT64_MAX);
397
+ } else {
398
+ // Lock the context mutex when pushing to the staging vectors.
399
+ std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
400
+ // Enqueue commands and only submit if we have enough staged commands
401
+ ctx->staged_command_bufs.push_back(commands);
402
+ ctx->staged_param_bufs.push_back(params_bufs);
403
+ if (ctx->staged_command_bufs.size() == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
404
+ ggml_backend_webgpu_submit_queue(ctx);
405
+ }
406
+ }
202
407
  }
203
408
 
204
- static void ggml_backend_webgpu_wait_on_submission(webgpu_context ctx) {
205
- // Wait for the queue to finish processing all commands
206
- ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::WaitAnyOnly,
207
- [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
208
- if (status != wgpu::QueueWorkDoneStatus::Success) {
209
- GGML_LOG_ERROR("ggml_webgpu: Failed to wait on queue: %s\n", message.data);
210
- }
211
- }),
212
- UINT64_MAX
213
- );
409
+ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
410
+ wgpu::Buffer & buf,
411
+ uint32_t value,
412
+ size_t offset,
413
+ size_t size) {
414
+ std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
415
+ std::vector<wgpu::BindGroupEntry> entries = {
416
+ { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
417
+ };
418
+ size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread;
419
+ uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg;
420
+ ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true);
214
421
  }
215
422
 
216
423
  /** End WebGPU Actions */
@@ -218,218 +425,227 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context ctx) {
218
425
  /** GGML Backend Interface */
219
426
 
220
427
  static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
221
- ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context;
428
+ ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
222
429
  return ctx->name.c_str();
223
430
  }
224
431
 
225
432
  static void ggml_backend_webgpu_free(ggml_backend_t backend) {
226
- ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context;
433
+ ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
227
434
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
228
435
 
229
436
  // TODO: cleanup
230
437
  GGML_UNUSED(ctx);
231
438
  }
232
439
 
440
+ static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
441
+ return webgpu_tensor_offset(tensor) + tensor->view_offs;
442
+ }
443
+
444
+ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
445
+ ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
446
+ return ctx->buffer;
447
+ }
448
+
449
+ static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, ggml_tensor * t) {
450
+ size_t offset = ggml_webgpu_tensor_offset(t);
451
+ return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
452
+ }
453
+
454
+ static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor * t) {
455
+ size_t offset = ggml_webgpu_tensor_offset(t);
456
+ return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
457
+ }
458
+
459
+ static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
460
+ return (ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t) + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
461
+ ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
462
+ }
463
+
464
+ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
465
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
466
+
467
+ std::vector<uint32_t> params = { ne,
468
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
469
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
470
+ // Convert byte-strides to element-strides
471
+ (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
472
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
473
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
474
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
475
+ (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
476
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
477
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
478
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
479
+ // Logical shape — same for both tensors even if permuted
480
+ (uint32_t) src->ne[0],
481
+ (uint32_t) src->ne[1],
482
+ (uint32_t) src->ne[2],
483
+ (uint32_t) src->ne[3] };
484
+
485
+ std::vector<wgpu::BindGroupEntry> entries = {
486
+ { .binding = 0,
487
+ .buffer = ggml_webgpu_tensor_buf(src),
488
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
489
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
490
+ { .binding = 1,
491
+ .buffer = ggml_webgpu_tensor_buf(dst),
492
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
493
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
494
+ };
495
+
496
+ size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
497
+ uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size;
498
+ ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x);
499
+ }
500
+
501
+ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
502
+ // For set rows specifically, we need to check if src and idx are empty tensors.
503
+ if (ggml_is_empty(src) || ggml_is_empty(idx)) {
504
+ return;
505
+ }
506
+
507
+ webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
508
+ if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
509
+ error_bufs.host_buf.Unmap();
510
+ }
511
+
512
+ std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
513
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
514
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
515
+ // Convert byte-strides to element-strides
516
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
517
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
518
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
519
+ (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
520
+ (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
521
+ (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
522
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
523
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
524
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
525
+ // Shape of src
526
+ (uint32_t) src->ne[0],
527
+ (uint32_t) src->ne[1],
528
+ (uint32_t) src->ne[2],
529
+ (uint32_t) src->ne[3],
530
+ // Shape of idx
531
+ (uint32_t) (idx->ne[1]),
532
+ (uint32_t) (idx->ne[2]) };
533
+
534
+ std::vector<wgpu::BindGroupEntry> entries = {
535
+ { .binding = 0,
536
+ .buffer = ggml_webgpu_tensor_buf(src),
537
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
538
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
539
+ { .binding = 1,
540
+ .buffer = ggml_webgpu_tensor_buf(idx),
541
+ .offset = ggml_webgpu_tensor_align_offset(ctx, idx),
542
+ .size = ggml_webgpu_tensor_binding_size(ctx, idx) },
543
+ { .binding = 2,
544
+ .buffer = ggml_webgpu_tensor_buf(dst),
545
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
546
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) },
547
+ { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
548
+ };
549
+
550
+ size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
551
+ uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
552
+
553
+ std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
554
+ ctx->staged_set_row_error_bufs.push_back(error_bufs);
555
+
556
+ ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
557
+ }
558
+
559
+ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
560
+ std::vector<uint32_t> params = {
561
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
562
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
563
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
564
+ (uint32_t) dst->ne[1], // number of rows in result (M)
565
+ (uint32_t) dst->ne[0], // number of columns in result (N)
566
+ (uint32_t) src0->ne[0], // number of columns in src0/src1 (K)
567
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1
568
+ (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1
569
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 2
570
+ (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 2
571
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 3
572
+ (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 3
573
+ (uint32_t) src0->ne[2], // batch size in dimension 2
574
+ (uint32_t) src0->ne[3], // batch size in dimension 3
575
+ (uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2
576
+ (uint32_t) (src1->ne[3] / src0->ne[3]) // broadcast in dimension 3
577
+ };
578
+
579
+ std::vector<wgpu::BindGroupEntry> entries = {
580
+ { .binding = 0,
581
+ .buffer = ggml_webgpu_tensor_buf(src0),
582
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
583
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
584
+ { .binding = 1,
585
+ .buffer = ggml_webgpu_tensor_buf(src1),
586
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
587
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
588
+ { .binding = 2,
589
+ .buffer = ggml_webgpu_tensor_buf(dst),
590
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
591
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) },
592
+ };
593
+
594
+ uint32_t wg_x =
595
+ (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
596
+ ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x);
597
+ }
598
+
233
599
  // Returns true if node has enqueued work into the queue, false otherwise
234
- static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
600
+ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
235
601
  if (ggml_is_empty(node)) {
236
602
  return false;
237
603
  }
238
-
239
604
  WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
240
605
 
606
+ ggml_tensor * src0 = node->src[0];
607
+ ggml_tensor * src1 = node->src[1];
241
608
 
242
609
  switch (node->op) {
243
- // no-ops
610
+ // no-ops
244
611
  case GGML_OP_NONE:
245
612
  case GGML_OP_VIEW:
246
613
  case GGML_OP_PERMUTE:
247
614
  return false;
248
-
249
- case GGML_OP_CPY: {
250
- std::lock_guard<std::mutex> lock(ctx->mutex);
251
- const ggml_tensor * src = node->src[0];
252
- ggml_backend_webgpu_buffer_context * src_ctx = (ggml_backend_webgpu_buffer_context *) src->buffer->context;
253
- size_t src_offset = webgpu_tensor_offset(src) + src->view_offs;
254
- // assumes power of 2 offset alignment
255
- size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
256
- // align to minimum offset alignment
257
- src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
258
- ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context;
259
- size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs;
260
- size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
261
- dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
262
-
263
- wgpu::Device device = ctx->device;
264
- ggml_backend_webgpu_map_buffer(ctx, ctx->cpy_params_host_buf,
265
- wgpu::MapMode::Write, 0, ctx->cpy_params_host_buf.GetSize());
266
- uint32_t * params = (uint32_t *) ctx->cpy_params_host_buf.GetMappedRange();
267
- uint32_t ne = (uint32_t)ggml_nelements(node);
268
- params[0] = ne;
269
- params[1] = src_misalignment/ggml_type_size(src->type);
270
- params[2] = dst_misalignment/ggml_type_size(node->type);
271
-
272
- // Convert byte-strides to element-strides
273
- params[3] = (uint32_t)src->nb[0]/ggml_type_size(src->type);
274
- params[4] = (uint32_t)src->nb[1]/ggml_type_size(src->type);
275
- params[5] = (uint32_t)src->nb[2]/ggml_type_size(src->type);
276
- params[6] = (uint32_t)src->nb[3]/ggml_type_size(src->type);
277
- params[7] = (uint32_t)node->nb[0]/ggml_type_size(node->type);
278
- params[8] = (uint32_t)node->nb[1]/ggml_type_size(node->type);
279
- params[9] = (uint32_t)node->nb[2]/ggml_type_size(node->type);
280
- params[10] = (uint32_t)node->nb[3]/ggml_type_size(node->type);
281
- // Logical shape — same for both tensors even if permuted
282
- params[11] = (uint32_t)(src->ne[0]);
283
- params[12] = (uint32_t)(src->ne[1]);
284
- params[13] = (uint32_t)(src->ne[2]);
285
- params[14] = (uint32_t)(src->ne[3]);
286
-
287
- ctx->cpy_params_host_buf.Unmap();
288
-
289
- wgpu::BindGroupEntry entries[3];
290
- entries[0].binding = 0;
291
- entries[0].buffer = src_ctx->buffer;
292
- entries[0].offset = src_offset;
293
- entries[0].size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
294
-
295
- entries[1].binding = 1;
296
- entries[1].buffer = dst_ctx->buffer;
297
- entries[1].offset = dst_offset;
298
- entries[1].size = (ggml_nbytes(node) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
299
-
300
- entries[2].binding = 2;
301
- entries[2].buffer = ctx->cpy_params_dev_buf;
302
- entries[2].offset = 0;
303
- entries[2].size = ctx->cpy_params_dev_buf.GetSize();
304
-
305
- wgpu::BindGroupDescriptor bind_group_desc;
306
- bind_group_desc.layout = ctx->cpy_pipeline.GetBindGroupLayout(0);
307
- bind_group_desc.label = "ggml_op_cpy";
308
- bind_group_desc.entryCount = 3;
309
- bind_group_desc.entries = entries;
310
- wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
311
-
312
- wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
313
- encoder.CopyBufferToBuffer(
314
- ctx->cpy_params_host_buf, 0,
315
- ctx->cpy_params_dev_buf, 0,
316
- ctx->cpy_params_dev_buf.GetSize()
317
- );
318
- wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
319
- pass.SetPipeline(ctx->cpy_pipeline);
320
- pass.SetBindGroup(0, bind_group);
321
- size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
322
- pass.DispatchWorkgroups((ne + max_wg_size - 1) / max_wg_size);
323
- pass.End();
324
- wgpu::CommandBuffer commands = encoder.Finish();
325
-
326
- // TODO, don't submit here, batch submissions
327
- ctx->queue.Submit(1, &commands);
328
- // TODO, don't wait on submission here
329
- ggml_backend_webgpu_wait_on_submission(ctx);
330
- return true;
331
- }
332
-
615
+ case GGML_OP_CPY:
616
+ {
617
+ ggml_webgpu_cpy(ctx, src0, node);
618
+ break;
619
+ }
620
+ case GGML_OP_SET_ROWS:
621
+ {
622
+ ggml_webgpu_set_rows(ctx, src0, src1, node);
623
+ break;
624
+ }
333
625
  case GGML_OP_MUL_MAT:
334
- {
335
- const ggml_tensor * src0 = node->src[0];
336
- ggml_backend_webgpu_buffer_context * src0_ctx = (ggml_backend_webgpu_buffer_context *) src0->buffer->context;
337
- size_t src0_offset = webgpu_tensor_offset(src0) + src0->view_offs;
338
- const ggml_tensor * src1 = node->src[1];
339
- ggml_backend_webgpu_buffer_context * src1_ctx = (ggml_backend_webgpu_buffer_context *) src1->buffer->context;
340
- size_t src1_offset = webgpu_tensor_offset(src1) + src1->view_offs;
341
- ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context;
342
-
343
- size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs;
344
-
345
- wgpu::Device device = ctx->device;
346
-
347
- // map the host parameters buffer
348
- ggml_backend_webgpu_map_buffer(ctx, ctx->mul_mat_params_host_buf,
349
- wgpu::MapMode::Write, 0, ctx->mul_mat_params_host_buf.GetSize());
350
- uint32_t * params = (uint32_t *) ctx->mul_mat_params_host_buf.GetMappedRange();
351
-
352
- params[0] = (uint32_t)node->ne[1]; // number of rows in result (M)
353
- params[1] = (uint32_t)node->ne[0]; // number of columns in result (N)
354
- params[2] = (uint32_t)src0->ne[0]; // number of columns in src0/src1 (K)
355
-
356
- params[3] = (uint32_t)src0->nb[1]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 1
357
- params[4] = (uint32_t)src1->nb[1]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 1
358
- params[5] = (uint32_t)src0->nb[2]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 2
359
- params[6] = (uint32_t)src1->nb[2]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 2
360
- params[7] = (uint32_t)src0->nb[3]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 3
361
- params[8] = (uint32_t)src1->nb[3]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 3
362
-
363
- params[9] = (uint32_t)src0->ne[2]; // batch size in dimension 2
364
- params[10] = (uint32_t)src0->ne[3]; // batch size in dimension 3
365
- params[11] = (uint32_t)(src1->ne[2]/src0->ne[2]); // broadcast in dimension 2
366
- params[12] = (uint32_t)(src1->ne[3]/src0->ne[3]); // broadcast in dimension 3
367
-
368
- ctx->mul_mat_params_host_buf.Unmap();
369
-
370
- wgpu::BindGroupEntry entries[4];
371
- entries[0].binding = 0;
372
- entries[0].buffer = src0_ctx->buffer;
373
- entries[0].offset = src0_offset;
374
- entries[0].size = ggml_nbytes(src0);
375
-
376
- entries[1].binding = 1;
377
- entries[1].buffer = src1_ctx->buffer;
378
- entries[1].offset = src1_offset;
379
- entries[1].size = ggml_nbytes(src1);
380
-
381
- entries[2].binding = 2;
382
- entries[2].buffer = dst_ctx->buffer;
383
- entries[2].offset = dst_offset;
384
- entries[2].size = ggml_nbytes(node);
385
-
386
- entries[3].binding = 3;
387
- entries[3].buffer = ctx->mul_mat_params_dev_buf;
388
- entries[3].offset = 0;
389
- entries[3].size = ctx->mul_mat_params_dev_buf.GetSize();
390
-
391
- wgpu::BindGroupDescriptor bind_group_desc;
392
- bind_group_desc.layout = ctx->mul_mat_pipeline.GetBindGroupLayout(0);
393
- bind_group_desc.entryCount = 4;
394
- bind_group_desc.label = "ggml_op_mul_mat";
395
- bind_group_desc.entries = entries;
396
- wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
397
-
398
- wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
399
- encoder.CopyBufferToBuffer(
400
- ctx->mul_mat_params_host_buf, 0,
401
- ctx->mul_mat_params_dev_buf, 0,
402
- ctx->mul_mat_params_dev_buf.GetSize()
403
- );
404
- wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
405
- pass.SetPipeline(ctx->mul_mat_pipeline);
406
- pass.SetBindGroup(0, bind_group);
407
- pass.DispatchWorkgroups((node->ne[0] * node->ne[1] * node->ne[2] * node->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE);
408
- pass.End();
409
- wgpu::CommandBuffer commands = encoder.Finish();
410
-
411
- // TODO, don't submit here, batch submissions
412
- ctx->queue.Submit(1, &commands);
413
- // TODO, don't wait on submission here
414
- ggml_backend_webgpu_wait_on_submission(ctx);
415
- return true;
416
- }
417
-
626
+ {
627
+ ggml_webgpu_mul_mat(ctx, src0, src1, node);
628
+ break;
629
+ }
418
630
  default:
419
631
  return false;
420
632
  }
633
+ return true;
421
634
  }
422
635
 
423
636
  static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
424
637
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");
425
638
 
426
639
  ggml_backend_webgpu_context * backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
427
- webgpu_context ctx = backend_ctx->webgpu_ctx;
640
+ webgpu_context ctx = backend_ctx->webgpu_ctx;
428
641
 
429
642
  for (int i = 0; i < cgraph->n_nodes; i++) {
430
643
  ggml_webgpu_encode_node(ctx, cgraph->nodes[i]);
431
644
  }
432
645
 
646
+ ggml_backend_webgpu_submit_queue(ctx);
647
+ ggml_backend_webgpu_wait_on_submission(ctx);
648
+
433
649
  return GGML_STATUS_SUCCESS;
434
650
  }
435
651
 
@@ -465,49 +681,72 @@ static void * ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer)
465
681
  return webgpu_ptr_base;
466
682
  }
467
683
 
468
- static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
684
+ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer,
685
+ ggml_tensor * tensor,
686
+ uint8_t value,
687
+ size_t offset,
688
+ size_t size) {
469
689
  if (size == 0) {
470
690
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
471
691
  return;
472
692
  }
473
693
 
474
- WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")");
694
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", "
695
+ << offset << ", " << size << ")");
475
696
 
476
697
  ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
698
+
477
699
  size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
700
+
478
701
  // This is a trick to set all bytes of a u32 to the same 1 byte value.
479
- uint32_t val32 = (uint32_t)value * 0x01010101;
702
+ uint32_t val32 = (uint32_t) value * 0x01010101;
480
703
  ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
481
704
  }
482
705
 
483
- static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
484
- WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
485
- ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
486
- webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
706
+ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
707
+ ggml_tensor * tensor,
708
+ const void * data,
709
+ size_t offset,
710
+ size_t size) {
711
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", "
712
+ << offset << ", " << size << ")");
713
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
714
+ webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
487
715
 
488
716
  size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
489
717
 
490
- webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size/4)*4);
718
+ webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
491
719
 
492
720
  if (size % 4 != 0) {
493
721
  // If size is not a multiple of 4, we need to memset the remaining bytes
494
722
  size_t remaining_size = size % 4;
723
+
495
724
  // pack the remaining bytes into a uint32_t
496
725
  uint32_t val32 = 0;
726
+
497
727
  for (size_t i = 0; i < remaining_size; i++) {
498
- ((uint8_t *)&val32)[i] = ((const uint8_t *)data)[size - remaining_size + i];
728
+ ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
499
729
  }
500
730
  // memset the remaining bytes
501
- ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
731
+ ggml_backend_webgpu_buffer_memset(
732
+ webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
733
+ } else {
734
+ // wait for WriteBuffer to complete
735
+ ggml_backend_webgpu_wait_on_submission(webgpu_ctx);
502
736
  }
503
737
  }
504
738
 
505
- static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
506
- WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
739
+ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
740
+ const ggml_tensor * tensor,
741
+ void * data,
742
+ size_t offset,
743
+ size_t size) {
744
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", "
745
+ << offset << ", " << size << ")");
507
746
 
508
- ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
509
- webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
510
- wgpu::Device device = webgpu_ctx->device;
747
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
748
+ webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
749
+ wgpu::Device device = webgpu_ctx->device;
511
750
 
512
751
  size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
513
752
 
@@ -517,22 +756,25 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
517
756
  final_size = size + (4 - (size % 4));
518
757
  }
519
758
 
520
- std::lock_guard<std::mutex> lock(webgpu_ctx->mutex);
759
+ std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);
521
760
 
522
- if (webgpu_ctx->get_tensor_staging_buf == nullptr ||
523
- webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
761
+ if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
524
762
  // Create a new staging buffer if it doesn't exist or is too small
525
763
  if (webgpu_ctx->get_tensor_staging_buf) {
526
764
  webgpu_ctx->get_tensor_staging_buf.Destroy();
527
765
  }
528
- ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size,
529
- wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
766
+ ggml_webgpu_create_buffer(device,
767
+ webgpu_ctx->get_tensor_staging_buf,
768
+ final_size,
769
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
770
+ "get_tensor_staging_buf");
530
771
  }
531
772
 
532
773
  // Copy the data from the buffer to the staging buffer
533
774
  wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
534
775
  encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size);
535
776
  wgpu::CommandBuffer commands = encoder.Finish();
777
+
536
778
  // Submit the command buffer to the queue
537
779
  webgpu_ctx->queue.Submit(1, &commands);
538
780
 
@@ -548,7 +790,6 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
548
790
 
549
791
  static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
550
792
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
551
-
552
793
  ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
553
794
  ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size);
554
795
  }
@@ -556,13 +797,13 @@ static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8
556
797
  static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
557
798
  /* .free_buffer = */ ggml_backend_webgpu_buffer_free_buffer,
558
799
  /* .get_base = */ ggml_backend_webgpu_buffer_get_base,
559
- /* .init_tensor = */ NULL, // TODO: optional, needed?
800
+ /* .init_tensor = */ NULL, // TODO: optional, needed?
560
801
  /* .memset_tensor = */ ggml_backend_webgpu_buffer_memset_tensor,
561
802
  /* .set_tensor = */ ggml_backend_webgpu_buffer_set_tensor,
562
803
  /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor,
563
- /* .cpy_tensor = */ NULL, // TODO: optional, implement this
804
+ /* .cpy_tensor = */ NULL, // TODO: optional, implement this
564
805
  /* .clear = */ ggml_backend_webgpu_buffer_clear,
565
- /* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor
806
+ /* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor
566
807
  };
567
808
 
568
809
  /* End GGML Backend Buffer Interface */
@@ -574,13 +815,17 @@ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer
574
815
  return ctx->device_name.c_str();
575
816
  }
576
817
 
577
- static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
818
+ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
819
+ size_t size) {
578
820
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")");
579
821
  ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
580
822
 
581
823
  wgpu::Buffer buf;
582
- ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, size,
583
- wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, "allocated_buffer");
824
+ ggml_webgpu_create_buffer(ctx->webgpu_ctx->device,
825
+ buf,
826
+ (size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1),
827
+ wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
828
+ "allocated_buffer");
584
829
 
585
830
  ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf);
586
831
 
@@ -615,8 +860,8 @@ static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_
615
860
  static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
616
861
  ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
617
862
  // TODO: what do we actually want to return here? maxBufferSize might not be the full available memory.
618
- *free = ctx->webgpu_ctx->limits.maxBufferSize;
619
- *total = ctx->webgpu_ctx->limits.maxBufferSize;
863
+ *free = ctx->webgpu_ctx->limits.maxBufferSize;
864
+ *total = ctx->webgpu_ctx->limits.maxBufferSize;
620
865
  }
621
866
 
622
867
  static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
@@ -639,98 +884,140 @@ static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct
639
884
 
640
885
  static ggml_guid_t ggml_backend_webgpu_guid(void) {
641
886
  static const char * guid_str = "__ggml_webgpu :)";
642
- return reinterpret_cast<ggml_guid_t>((void *)guid_str);
887
+ return reinterpret_cast<ggml_guid_t>((void *) guid_str);
643
888
  }
644
889
 
645
- static void ggml_webgpu_init_memset_pipeline(webgpu_context webgpu_ctx) {
890
+ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
646
891
  // we use the maximum workgroup size for the memset pipeline
647
892
  size_t max_wg_size = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
648
893
  size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
649
894
  // Size the bytes_per_thread so that the largest buffer size can be handled
650
- webgpu_ctx->memset_bytes_per_thread = (webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads;
895
+ webgpu_ctx->memset_bytes_per_thread =
896
+ (webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads;
651
897
  std::vector<wgpu::ConstantEntry> constants(2);
652
- constants[0].key = "wg_size";
898
+ constants[0].key = "wg_size";
653
899
  constants[0].value = max_wg_size;
654
- constants[1].key = "bytes_per_thread";
900
+ constants[1].key = "bytes_per_thread";
655
901
  constants[1].value = webgpu_ctx->memset_bytes_per_thread;
656
902
  ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, "memset", constants);
657
- ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_dev_buf,
658
- 3 * sizeof(uint32_t), // 3 parameters: buffer size, offset, value
659
- wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "memset_params_dev_buf");
660
- ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_host_buf,
661
- 3 * sizeof(uint32_t), wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "memset_params_host_buf");
662
903
  }
663
904
 
664
- static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context webgpu_ctx) {
665
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat");
666
- ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_dev_buf, WEBGPU_MUL_MAT_PARAMS_SIZE,
667
- wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "mul_mat_params_dev_buf");
668
- ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_host_buf, WEBGPU_MUL_MAT_PARAMS_SIZE,
669
- wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "mul_mat_params_host_buf");
905
+ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
906
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
907
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
908
+ wgsl_mul_mat_f32_f32,
909
+ "mul_mat_f32_f32");
910
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
911
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
912
+ wgsl_mul_mat_f16_f16,
913
+ "mul_mat_f16_f16");
914
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
915
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
916
+ wgsl_mul_mat_f16_f32,
917
+ "mul_mat_f16_f32");
918
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
919
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
920
+ wgsl_mul_mat_q4_0_f32,
921
+ "mul_mat_q4_0_f32");
922
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
923
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
924
+ wgsl_mul_mat_q4_1_f32,
925
+ "mul_mat_q4_1_f32");
926
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
927
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32],
928
+ wgsl_mul_mat_q5_0_f32,
929
+ "mul_mat_q5_0_f32");
930
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
931
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32],
932
+ wgsl_mul_mat_q5_1_f32,
933
+ "mul_mat_q5_1_f32");
934
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
935
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32],
936
+ wgsl_mul_mat_q8_0_f32,
937
+ "mul_mat_q8_0_f32");
938
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
939
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32],
940
+ wgsl_mul_mat_q2_k_f32,
941
+ "mul_mat_q2_k_f32");
942
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
943
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32],
944
+ wgsl_mul_mat_q3_k_f32,
945
+ "mul_mat_q3_k_f32");
946
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
947
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32],
948
+ wgsl_mul_mat_q4_k_f32,
949
+ "mul_mat_q4_k_f32");
950
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
951
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32],
952
+ wgsl_mul_mat_q5_k_f32,
953
+ "mul_mat_q5_k_f32");
954
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
955
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32],
956
+ wgsl_mul_mat_q6_k_f32,
957
+ "mul_mat_q6_k_f32");
958
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
959
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32],
960
+ wgsl_mul_mat_iq2_xxs_f32,
961
+ "mul_mat_iq2_xxs_f32");
962
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
963
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32],
964
+ wgsl_mul_mat_iq2_xs_f32,
965
+ "mul_mat_iq2_xs_f32");
966
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
967
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32],
968
+ wgsl_mul_mat_iq2_s_f32,
969
+ "mul_mat_iq2_s_f32");
970
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
971
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32],
972
+ wgsl_mul_mat_iq3_xxs_f32,
973
+ "mul_mat_iq3_xxs_f32");
974
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
975
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32],
976
+ wgsl_mul_mat_iq3_s_f32,
977
+ "mul_mat_iq3_s_f32");
978
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
979
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32],
980
+ wgsl_mul_mat_iq1_s_f32,
981
+ "mul_mat_iq1_s_f32");
982
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
983
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32],
984
+ wgsl_mul_mat_iq1_m_f32,
985
+ "mul_mat_iq1_m_f32");
986
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
987
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32],
988
+ wgsl_mul_mat_iq4_nl_f32,
989
+ "mul_mat_iq4_nl_f32");
990
+ ggml_webgpu_create_pipeline(webgpu_ctx->device,
991
+ webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
992
+ wgsl_mul_mat_iq4_xs_f32,
993
+ "mul_mat_iq4_xs_f32");
670
994
  }
671
995
 
672
- static void ggml_webgpu_init_cpy_pipeline(webgpu_context webgpu_ctx) {
996
+ static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
673
997
  std::vector<wgpu::ConstantEntry> constants(1);
674
- constants[0].key = "wg_size";
998
+ constants[0].key = "wg_size";
675
999
  constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
1000
+ ggml_webgpu_create_pipeline(
1001
+ webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", constants);
1002
+ }
676
1003
 
1004
+ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
1005
+ std::vector<wgpu::ConstantEntry> constants(1);
1006
+ constants[0].key = "wg_size";
1007
+ constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
677
1008
  ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", constants);
678
- ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->cpy_params_dev_buf, WEBGPU_CPY_PARAMS_SIZE,
679
- wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "cpy_params_dev_buf");
680
- ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->cpy_params_host_buf, WEBGPU_CPY_PARAMS_SIZE,
681
- wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "cpy_params_host_buf");
682
1009
  }
683
1010
 
684
- // TODO: Make thread safe if multiple devices are used
685
1011
  static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
686
1012
  GGML_UNUSED(params);
687
1013
 
688
1014
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()");
689
1015
 
690
- ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
691
- webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
692
-
693
- std::lock_guard<std::mutex> lock(webgpu_ctx->mutex);
694
-
695
- if (!webgpu_ctx->device_initialized) {
696
- // Initialize device
697
- wgpu::DeviceDescriptor dev_desc;
698
- dev_desc.requiredLimits = &webgpu_ctx->limits;
699
- dev_desc.requiredFeatures = webgpu_ctx->features.features;
700
- dev_desc.requiredFeatureCount = webgpu_ctx->features.featureCount;
701
- dev_desc.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous,
702
- [](const wgpu::Device& device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
703
- GGML_UNUSED(device);
704
- GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
705
- });
706
- dev_desc.SetUncapturedErrorCallback(
707
- [](const wgpu::Device& device, wgpu::ErrorType reason, wgpu::StringView message) {
708
- GGML_UNUSED(device);
709
- GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
710
- });
711
- webgpu_ctx->instance.WaitAny(webgpu_ctx->adapter.RequestDevice(&dev_desc, wgpu::CallbackMode::WaitAnyOnly,
712
- [webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
713
- if (status != wgpu::RequestDeviceStatus::Success) {
714
- GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data);
715
- return;
716
- }
717
- webgpu_ctx->device = device;
718
- }),
719
- UINT64_MAX
720
- );
721
- GGML_ASSERT(webgpu_ctx->device != nullptr);
722
-
723
- // Initialize (compute) queue
724
- webgpu_ctx->queue = webgpu_ctx->device.GetQueue();
725
-
726
- ggml_webgpu_init_memset_pipeline(webgpu_ctx);
727
- ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
728
- ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
729
- webgpu_ctx->device_initialized = true;
730
- }
1016
+ ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
1017
+ webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
731
1018
 
732
1019
  static ggml_backend_webgpu_context backend_ctx;
733
- backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
1020
+ backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
734
1021
  backend_ctx.webgpu_ctx = webgpu_ctx;
735
1022
 
736
1023
  // See GGML Backend Interface section
@@ -748,14 +1035,15 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
748
1035
  // See GGML Backend Buffer Type Interface section
749
1036
  static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
750
1037
  /* .iface = */ {
751
- /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
752
- /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
753
- /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment,
754
- /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size,
755
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
756
- /* .is_host = */ NULL, // defaults to false
1038
+ /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
1039
+ /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
1040
+ /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment,
1041
+ /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size,
1042
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
1043
+ /* .is_host = */ NULL, // defaults to false
757
1044
  },
758
- /* .device = */ dev,
1045
+ /* .device = */
1046
+ dev,
759
1047
  /* .context = */ NULL,
760
1048
  };
761
1049
 
@@ -764,7 +1052,7 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
764
1052
 
765
1053
  static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
766
1054
  GGML_UNUSED(dev);
767
- return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;
1055
+ return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;
768
1056
  }
769
1057
 
770
1058
  static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
@@ -776,9 +1064,44 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
776
1064
  case GGML_OP_PERMUTE:
777
1065
  return true;
778
1066
  case GGML_OP_CPY:
1067
+ case GGML_OP_SET_ROWS:
779
1068
  return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
780
1069
  case GGML_OP_MUL_MAT:
781
- return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
1070
+ {
1071
+ switch (op->src[1]->type) {
1072
+ case GGML_TYPE_F16:
1073
+ return op->src[0]->type == GGML_TYPE_F16;
1074
+ case GGML_TYPE_F32:
1075
+ switch (op->src[0]->type) {
1076
+ case GGML_TYPE_F32:
1077
+ case GGML_TYPE_F16:
1078
+ case GGML_TYPE_Q4_0:
1079
+ case GGML_TYPE_Q4_1:
1080
+ case GGML_TYPE_Q5_0:
1081
+ case GGML_TYPE_Q5_1:
1082
+ case GGML_TYPE_Q8_0:
1083
+ case GGML_TYPE_Q2_K:
1084
+ case GGML_TYPE_Q3_K:
1085
+ case GGML_TYPE_Q4_K:
1086
+ case GGML_TYPE_Q5_K:
1087
+ case GGML_TYPE_Q6_K:
1088
+ case GGML_TYPE_IQ2_XXS:
1089
+ case GGML_TYPE_IQ2_XS:
1090
+ case GGML_TYPE_IQ2_S:
1091
+ case GGML_TYPE_IQ3_XXS:
1092
+ case GGML_TYPE_IQ3_S:
1093
+ case GGML_TYPE_IQ1_S:
1094
+ case GGML_TYPE_IQ1_M:
1095
+ case GGML_TYPE_IQ4_NL:
1096
+ case GGML_TYPE_IQ4_XS:
1097
+ return true;
1098
+ default:
1099
+ return false;
1100
+ }
1101
+ default:
1102
+ return false;
1103
+ }
1104
+ }
782
1105
  default:
783
1106
  return false;
784
1107
  }
@@ -827,30 +1150,105 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
827
1150
  webgpu_context ctx = reg_ctx->webgpu_ctx;
828
1151
 
829
1152
  wgpu::RequestAdapterOptions options = {};
830
- auto callback = [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char *message, void *userdata) {
831
- if (status != wgpu::RequestAdapterStatus::Success) {
832
- GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
833
- return;
834
- }
835
- *static_cast<wgpu::Adapter *>(userdata) = adapter;
836
- };
837
- void *userdata = &ctx->adapter;
838
- ctx->instance.WaitAny(ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::WaitAnyOnly, callback, userdata), UINT64_MAX);
1153
+ auto callback =
1154
+ [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message, void * userdata) {
1155
+ if (status != wgpu::RequestAdapterStatus::Success) {
1156
+ GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
1157
+ return;
1158
+ }
1159
+ *static_cast<wgpu::Adapter *>(userdata) = std::move(adapter);
1160
+ };
1161
+ void * userdata = &ctx->adapter;
1162
+ ctx->instance.WaitAny(
1163
+ ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous, callback, userdata), UINT64_MAX);
839
1164
  GGML_ASSERT(ctx->adapter != nullptr);
840
1165
 
841
1166
  ctx->adapter.GetLimits(&ctx->limits);
842
- ctx->adapter.GetFeatures(&ctx->features);
843
1167
 
844
1168
  wgpu::AdapterInfo info{};
845
1169
  ctx->adapter.GetInfo(&info);
846
1170
 
1171
+ // Initialize device
1172
+ std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
1173
+ wgpu::FeatureName::ImplicitDeviceSynchronization };
1174
+ wgpu::DeviceDescriptor dev_desc;
1175
+ dev_desc.requiredLimits = &ctx->limits;
1176
+ dev_desc.requiredFeatures = required_features.data();
1177
+ dev_desc.requiredFeatureCount = required_features.size();
1178
+ dev_desc.SetDeviceLostCallback(
1179
+ wgpu::CallbackMode::AllowSpontaneous,
1180
+ [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
1181
+ GGML_UNUSED(device);
1182
+ GGML_LOG_ERROR(
1183
+ "ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), std::string(message).c_str());
1184
+ });
1185
+ dev_desc.SetUncapturedErrorCallback(
1186
+ [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
1187
+ GGML_UNUSED(device);
1188
+ GGML_LOG_ERROR(
1189
+ "ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), std::string(message).c_str());
1190
+ });
1191
+ ctx->instance.WaitAny(ctx->adapter.RequestDevice(
1192
+ &dev_desc,
1193
+ wgpu::CallbackMode::AllowSpontaneous,
1194
+ [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
1195
+ if (status != wgpu::RequestDeviceStatus::Success) {
1196
+ GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
1197
+ return;
1198
+ }
1199
+ ctx->device = std::move(device);
1200
+ }),
1201
+ UINT64_MAX);
1202
+ GGML_ASSERT(ctx->device != nullptr);
1203
+
1204
+ // Initialize (compute) queue
1205
+ ctx->queue = ctx->device.GetQueue();
1206
+
1207
+ // Create buffer pool for shader parameters
1208
+ ctx->param_buf_pool.init(ctx->device,
1209
+ WEBGPU_NUM_PARAM_BUFS,
1210
+ WEBGPU_PARAMS_BUF_SIZE_BYTES,
1211
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
1212
+ wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
1213
+ ctx->set_rows_error_buf_pool.init(ctx->device,
1214
+ WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
1215
+ WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
1216
+ wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
1217
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
1218
+
1219
+ ggml_webgpu_init_memset_pipeline(ctx);
1220
+ ggml_webgpu_init_mul_mat_pipeline(ctx);
1221
+ ggml_webgpu_init_set_rows_pipeline(ctx);
1222
+ ggml_webgpu_init_cpy_pipeline(ctx);
1223
+
1224
+ #ifdef GGML_WEBGPU_DEBUG
1225
+ // Initialize debug buffers
1226
+ ggml_webgpu_create_buffer(ctx->device,
1227
+ ctx->debug_host_buf,
1228
+ WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
1229
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
1230
+ "debug_host_buf");
1231
+ ggml_webgpu_create_buffer(ctx->device,
1232
+ ctx->debug_dev_buf,
1233
+ WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
1234
+ wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
1235
+ "debug_dev_buf");
1236
+ #endif
1237
+
847
1238
  static ggml_backend_webgpu_device_context device_ctx;
848
- device_ctx.webgpu_ctx = ctx;
1239
+ device_ctx.webgpu_ctx = ctx;
849
1240
  device_ctx.device_name = GGML_WEBGPU_NAME;
850
- device_ctx.device_desc = std::string(info.description.data);
851
-
852
- GGML_LOG_INFO("ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | device_desc: %s\n",
853
- info.vendorID, info.vendor.data, info.architecture.data, info.deviceID, info.device.data, info.description.data);
1241
+ device_ctx.device_desc = info.description;
1242
+
1243
+ GGML_LOG_INFO(
1244
+ "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
1245
+ "device_desc: %s\n",
1246
+ info.vendorID,
1247
+ std::string(info.vendor).c_str(),
1248
+ std::string(info.architecture).c_str(),
1249
+ info.deviceID,
1250
+ std::string(info.device).c_str(),
1251
+ std::string(info.description).c_str());
854
1252
 
855
1253
  // See GGML Backend Device Interface section
856
1254
  static ggml_backend_device device = {
@@ -861,7 +1259,6 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
861
1259
  return &device;
862
1260
  }
863
1261
 
864
-
865
1262
  static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
866
1263
  /* .get_name = */ ggml_backend_webgpu_reg_get_name,
867
1264
  /* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count,
@@ -871,23 +1268,21 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
871
1268
 
872
1269
  /* End GGML Backend Registration Interface */
873
1270
 
874
- // TODO: Does this need to be thread safe? Is it only called once?
875
1271
  ggml_backend_reg_t ggml_backend_webgpu_reg() {
876
1272
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
877
1273
 
878
1274
  webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
879
- webgpu_ctx->device_initialized = false;
880
1275
 
881
1276
  static ggml_backend_webgpu_reg_context ctx;
882
- ctx.webgpu_ctx = webgpu_ctx;
883
- ctx.name = GGML_WEBGPU_NAME;
1277
+ ctx.webgpu_ctx = webgpu_ctx;
1278
+ ctx.name = GGML_WEBGPU_NAME;
884
1279
  ctx.device_count = 1;
885
1280
 
886
- wgpu::InstanceDescriptor instance_descriptor{};
887
- std::vector<wgpu::InstanceFeatureName> instance_features = {wgpu::InstanceFeatureName::TimedWaitAny};
888
- instance_descriptor.requiredFeatures = instance_features.data();
889
- instance_descriptor.requiredFeatureCount = instance_features.size();
890
- webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
1281
+ wgpu::InstanceDescriptor instance_descriptor{};
1282
+ std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
1283
+ instance_descriptor.requiredFeatures = instance_features.data();
1284
+ instance_descriptor.requiredFeatureCount = instance_features.size();
1285
+ webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
891
1286
  GGML_ASSERT(webgpu_ctx->instance != nullptr);
892
1287
 
893
1288
  static ggml_backend_reg reg = {