@novastera-oss/llamarn 0.3.1 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +86 -3
- package/RNLlamaCpp.podspec +1 -1
- package/android/CMakeLists.txt +11 -3
- package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
- package/android/src/main/cpp/include/llama.h +53 -114
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +2 -10
- package/cpp/PureCppImpl.cpp +71 -4
- package/cpp/SystemUtils.cpp +3 -7
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -1
- package/cpp/llama.cpp/Makefile +6 -1605
- package/cpp/llama.cpp/README.md +5 -1
- package/cpp/llama.cpp/common/arg.cpp +230 -51
- package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
- package/cpp/llama.cpp/common/chat.cpp +539 -8
- package/cpp/llama.cpp/common/chat.h +8 -1
- package/cpp/llama.cpp/common/common.cpp +60 -15
- package/cpp/llama.cpp/common/common.h +64 -15
- package/cpp/llama.cpp/common/speculative.cpp +135 -54
- package/cpp/llama.cpp/common/speculative.h +8 -1
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
- package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
- package/cpp/llama.cpp/flake.nix +0 -5
- package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
- package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
- package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
- package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
- package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
- package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
- package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
- package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
- package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
- package/cpp/llama.cpp/include/llama.h +53 -114
- package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
- package/cpp/llama.cpp/models/templates/README.md +2 -1
- package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
- package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
- package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
- package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
- package/cpp/llama.cpp/src/llama-adapter.h +3 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
- package/cpp/llama.cpp/src/llama-arch.h +18 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
- package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +61 -252
- package/cpp/llama.cpp/src/llama-context.h +10 -15
- package/cpp/llama.cpp/src/llama-cparams.h +0 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
- package/cpp/llama.cpp/src/llama-graph.h +90 -51
- package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
- package/cpp/llama.cpp/src/llama-hparams.h +21 -6
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
- package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
- package/cpp/llama.cpp/src/llama-memory.h +13 -10
- package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
- package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
- package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
- package/cpp/llama.cpp/src/llama-model.h +28 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
- package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
- package/cpp/llama.cpp/src/llama-vocab.h +1 -0
- package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
- package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
- package/cpp/rn-completion.cpp +3 -27
- package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
- package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
- package/ios/include/chat.h +8 -1
- package/ios/include/common/minja/chat-template.hpp +16 -7
- package/ios/include/common/minja/minja.hpp +47 -12
- package/ios/include/common.h +64 -15
- package/ios/include/llama.h +53 -114
- package/ios/include/speculative.h +8 -1
- package/ios/libs/llama.xcframework/Info.plist +18 -18
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/lib/module/NativeRNLlamaCpp.js.map +1 -1
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
- package/package.json +1 -2
- package/src/NativeRNLlamaCpp.ts +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56
|
@@ -1,34 +1,44 @@
|
|
|
1
|
-
|
|
1
|
+
/*
|
|
2
|
+
WebGPU backend implementation.
|
|
3
|
+
Note: Use ClangFormat to format this file.
|
|
4
|
+
*/
|
|
2
5
|
|
|
3
|
-
#include
|
|
6
|
+
#include "ggml-webgpu.h"
|
|
4
7
|
|
|
5
|
-
#include "ggml-impl.h"
|
|
6
8
|
#include "ggml-backend-impl.h"
|
|
7
|
-
|
|
9
|
+
#include "ggml-impl.h"
|
|
8
10
|
#include "ggml-wgsl-shaders.hpp"
|
|
9
11
|
|
|
12
|
+
#include <webgpu/webgpu_cpp.h>
|
|
13
|
+
|
|
14
|
+
#include <condition_variable>
|
|
10
15
|
#include <cstring>
|
|
11
16
|
#include <iostream>
|
|
12
17
|
#include <mutex>
|
|
18
|
+
#include <string>
|
|
13
19
|
#include <vector>
|
|
14
20
|
|
|
15
21
|
#ifdef GGML_WEBGPU_DEBUG
|
|
16
|
-
#define WEBGPU_LOG_DEBUG(msg)
|
|
22
|
+
# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
|
|
23
|
+
# define WEBGPU_DEBUG_BUF_ELEMS 32
|
|
17
24
|
#else
|
|
18
|
-
#define WEBGPU_LOG_DEBUG(msg) ((void) 0)
|
|
19
|
-
#endif
|
|
25
|
+
# define WEBGPU_LOG_DEBUG(msg) ((void) 0)
|
|
26
|
+
#endif // GGML_WEBGPU_DEBUG
|
|
20
27
|
|
|
21
28
|
/* Constants */
|
|
22
29
|
|
|
23
|
-
#define
|
|
24
|
-
#define
|
|
25
|
-
#define
|
|
26
|
-
#define
|
|
30
|
+
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
|
|
31
|
+
#define WEBGPU_MUL_MAT_WG_SIZE 64
|
|
32
|
+
#define WEBGPU_NUM_PARAM_BUFS 100
|
|
33
|
+
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
|
|
34
|
+
#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32
|
|
35
|
+
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
|
|
36
|
+
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
|
|
27
37
|
|
|
28
38
|
/* End Constants */
|
|
29
39
|
|
|
30
40
|
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
|
|
31
|
-
static void * const webgpu_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
|
|
41
|
+
static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
|
|
32
42
|
|
|
33
43
|
// Always returns the base offset of a tensor, regardless of views.
|
|
34
44
|
static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
|
|
@@ -40,100 +50,175 @@ static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
|
|
|
40
50
|
|
|
41
51
|
/* Struct definitions */
|
|
42
52
|
|
|
53
|
+
// Forward reference
|
|
54
|
+
static void ggml_webgpu_create_buffer(wgpu::Device & device,
|
|
55
|
+
wgpu::Buffer & buffer,
|
|
56
|
+
size_t size,
|
|
57
|
+
wgpu::BufferUsage usage,
|
|
58
|
+
const char * label);
|
|
59
|
+
|
|
60
|
+
struct webgpu_pool_bufs {
|
|
61
|
+
wgpu::Buffer host_buf;
|
|
62
|
+
wgpu::Buffer dev_buf;
|
|
63
|
+
};
|
|
64
|
+
|
|
65
|
+
// Holds a pool of parameter buffers for WebGPU operations
|
|
66
|
+
struct webgpu_buf_pool {
|
|
67
|
+
std::vector<webgpu_pool_bufs> free;
|
|
68
|
+
|
|
69
|
+
std::mutex mutex;
|
|
70
|
+
|
|
71
|
+
std::condition_variable cv;
|
|
72
|
+
|
|
73
|
+
void init(wgpu::Device device,
|
|
74
|
+
int num_bufs,
|
|
75
|
+
size_t buf_size,
|
|
76
|
+
wgpu::BufferUsage dev_buf_usage,
|
|
77
|
+
wgpu::BufferUsage host_buf_usage) {
|
|
78
|
+
for (int i = 0; i < num_bufs; i++) {
|
|
79
|
+
wgpu::Buffer host_buf;
|
|
80
|
+
wgpu::Buffer dev_buf;
|
|
81
|
+
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
|
|
82
|
+
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
|
|
83
|
+
free.push_back({ host_buf, dev_buf });
|
|
84
|
+
}
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
webgpu_pool_bufs alloc_bufs() {
|
|
88
|
+
std::unique_lock<std::mutex> lock(mutex);
|
|
89
|
+
cv.wait(lock, [this] { return !free.empty(); });
|
|
90
|
+
webgpu_pool_bufs bufs = free.back();
|
|
91
|
+
free.pop_back();
|
|
92
|
+
return bufs;
|
|
93
|
+
}
|
|
94
|
+
|
|
95
|
+
void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
|
|
96
|
+
std::lock_guard<std::mutex> lock(mutex);
|
|
97
|
+
free.insert(free.end(), bufs.begin(), bufs.end());
|
|
98
|
+
cv.notify_all();
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
void cleanup() {
|
|
102
|
+
std::lock_guard<std::mutex> lock(mutex);
|
|
103
|
+
for (auto & bufs : free) {
|
|
104
|
+
bufs.host_buf.Destroy();
|
|
105
|
+
bufs.dev_buf.Destroy();
|
|
106
|
+
}
|
|
107
|
+
free.clear();
|
|
108
|
+
}
|
|
109
|
+
};
|
|
110
|
+
|
|
43
111
|
// All the base objects needed to run operations on a WebGPU device
|
|
44
112
|
struct webgpu_context_struct {
|
|
45
113
|
wgpu::Instance instance;
|
|
46
|
-
wgpu::Adapter
|
|
47
|
-
wgpu::Device
|
|
48
|
-
wgpu::Queue
|
|
49
|
-
wgpu::Limits
|
|
50
|
-
wgpu::SupportedFeatures features;
|
|
114
|
+
wgpu::Adapter adapter;
|
|
115
|
+
wgpu::Device device;
|
|
116
|
+
wgpu::Queue queue;
|
|
117
|
+
wgpu::Limits limits;
|
|
51
118
|
|
|
52
|
-
std::
|
|
53
|
-
|
|
119
|
+
std::recursive_mutex mutex;
|
|
120
|
+
|
|
121
|
+
webgpu_buf_pool param_buf_pool;
|
|
122
|
+
webgpu_buf_pool set_rows_error_buf_pool;
|
|
54
123
|
|
|
55
|
-
// pipelines and parameter buffers
|
|
56
|
-
// TODO: reuse params buffers for different pipelines when possible
|
|
57
124
|
wgpu::ComputePipeline memset_pipeline;
|
|
58
|
-
wgpu::
|
|
59
|
-
wgpu::
|
|
60
|
-
wgpu::ComputePipeline mul_mat_pipeline;
|
|
61
|
-
wgpu::Buffer mul_mat_params_dev_buf;
|
|
62
|
-
wgpu::Buffer mul_mat_params_host_buf;
|
|
125
|
+
wgpu::ComputePipeline mul_mat_pipeline[30][2];
|
|
126
|
+
wgpu::ComputePipeline set_rows_pipeline;
|
|
63
127
|
wgpu::ComputePipeline cpy_pipeline;
|
|
64
|
-
wgpu::Buffer cpy_params_dev_buf;
|
|
65
|
-
wgpu::Buffer cpy_params_host_buf;
|
|
66
128
|
|
|
67
129
|
size_t memset_bytes_per_thread;
|
|
68
130
|
|
|
69
131
|
// Staging buffer for reading data from the GPU
|
|
70
132
|
wgpu::Buffer get_tensor_staging_buf;
|
|
133
|
+
|
|
134
|
+
// Command buffers which need to be submitted
|
|
135
|
+
std::vector<wgpu::CommandBuffer> staged_command_bufs;
|
|
136
|
+
|
|
137
|
+
// Parameter buffers associated with the staged command buffers
|
|
138
|
+
std::vector<webgpu_pool_bufs> staged_param_bufs;
|
|
139
|
+
// Buffers associated with set_rows operations, used to store potential errors
|
|
140
|
+
std::vector<webgpu_pool_bufs> staged_set_row_error_bufs;
|
|
141
|
+
|
|
142
|
+
std::vector<wgpu::FutureWaitInfo> callback_futures;
|
|
143
|
+
|
|
144
|
+
#ifdef GGML_WEBGPU_DEBUG
|
|
145
|
+
wgpu::Buffer debug_host_buf;
|
|
146
|
+
wgpu::Buffer debug_dev_buf;
|
|
147
|
+
#endif
|
|
71
148
|
};
|
|
72
149
|
|
|
73
150
|
typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
|
|
74
151
|
|
|
75
152
|
struct ggml_backend_webgpu_reg_context {
|
|
76
153
|
webgpu_context webgpu_ctx;
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
const char * name;
|
|
154
|
+
size_t device_count;
|
|
155
|
+
const char * name;
|
|
80
156
|
};
|
|
81
157
|
|
|
82
158
|
struct ggml_backend_webgpu_device_context {
|
|
83
159
|
webgpu_context webgpu_ctx;
|
|
84
|
-
|
|
85
|
-
std::string
|
|
86
|
-
std::string device_desc;
|
|
160
|
+
std::string device_name;
|
|
161
|
+
std::string device_desc;
|
|
87
162
|
};
|
|
88
163
|
|
|
89
164
|
struct ggml_backend_webgpu_context {
|
|
90
165
|
webgpu_context webgpu_ctx;
|
|
91
|
-
|
|
92
|
-
std::string name;
|
|
166
|
+
std::string name;
|
|
93
167
|
};
|
|
94
168
|
|
|
95
169
|
struct ggml_backend_webgpu_buffer_context {
|
|
96
170
|
webgpu_context webgpu_ctx;
|
|
97
|
-
|
|
98
|
-
wgpu::Buffer buffer;
|
|
171
|
+
wgpu::Buffer buffer;
|
|
99
172
|
|
|
100
173
|
ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) :
|
|
101
|
-
webgpu_ctx(ctx),
|
|
102
|
-
|
|
174
|
+
webgpu_ctx(std::move(ctx)),
|
|
175
|
+
buffer(std::move(buf)) {}
|
|
103
176
|
};
|
|
104
177
|
|
|
105
178
|
/* End struct definitions */
|
|
106
179
|
|
|
107
180
|
/* WebGPU object initializations */
|
|
108
181
|
|
|
109
|
-
static void ggml_webgpu_create_pipeline(wgpu::Device &device,
|
|
182
|
+
static void ggml_webgpu_create_pipeline(wgpu::Device & device,
|
|
183
|
+
wgpu::ComputePipeline & pipeline,
|
|
184
|
+
const char * shader_code,
|
|
185
|
+
const char * label,
|
|
186
|
+
const std::vector<wgpu::ConstantEntry> & constants = {}) {
|
|
110
187
|
WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()");
|
|
188
|
+
|
|
111
189
|
wgpu::ShaderSourceWGSL shader_source;
|
|
112
190
|
shader_source.code = shader_code;
|
|
191
|
+
|
|
113
192
|
wgpu::ShaderModuleDescriptor shader_desc;
|
|
114
193
|
shader_desc.nextInChain = &shader_source;
|
|
194
|
+
|
|
115
195
|
wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
|
|
116
196
|
|
|
117
197
|
wgpu::ComputePipelineDescriptor pipeline_desc;
|
|
118
|
-
pipeline_desc.label
|
|
119
|
-
pipeline_desc.compute.module
|
|
120
|
-
pipeline_desc.compute.entryPoint = "main";
|
|
121
|
-
pipeline_desc.layout
|
|
198
|
+
pipeline_desc.label = label;
|
|
199
|
+
pipeline_desc.compute.module = shader_module;
|
|
200
|
+
pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
|
|
201
|
+
pipeline_desc.layout = nullptr; // nullptr means auto layout
|
|
122
202
|
if (constants.size() > 0) {
|
|
123
|
-
pipeline_desc.compute.constants
|
|
203
|
+
pipeline_desc.compute.constants = constants.data();
|
|
124
204
|
pipeline_desc.compute.constantCount = constants.size();
|
|
125
205
|
}
|
|
126
206
|
pipeline = device.CreateComputePipeline(&pipeline_desc);
|
|
127
207
|
}
|
|
128
208
|
|
|
129
|
-
static void ggml_webgpu_create_buffer(wgpu::Device &device,
|
|
209
|
+
static void ggml_webgpu_create_buffer(wgpu::Device & device,
|
|
210
|
+
wgpu::Buffer & buffer,
|
|
211
|
+
size_t size,
|
|
212
|
+
wgpu::BufferUsage usage,
|
|
213
|
+
const char * label) {
|
|
130
214
|
WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()");
|
|
131
215
|
|
|
132
216
|
wgpu::BufferDescriptor buffer_desc;
|
|
133
|
-
buffer_desc.size
|
|
134
|
-
buffer_desc.usage
|
|
135
|
-
buffer_desc.label
|
|
217
|
+
buffer_desc.size = size;
|
|
218
|
+
buffer_desc.usage = usage;
|
|
219
|
+
buffer_desc.label = label;
|
|
136
220
|
buffer_desc.mappedAtCreation = false;
|
|
221
|
+
|
|
137
222
|
// TODO: error handling
|
|
138
223
|
buffer = device.CreateBuffer(&buffer_desc);
|
|
139
224
|
}
|
|
@@ -142,75 +227,197 @@ static void ggml_webgpu_create_buffer(wgpu::Device &device, wgpu::Buffer &buffer
|
|
|
142
227
|
|
|
143
228
|
/** WebGPU Actions */
|
|
144
229
|
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
230
|
+
// Wait for the queue to finish processing all submitted work
|
|
231
|
+
static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
|
|
232
|
+
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
|
|
233
|
+
if (ctx->callback_futures.empty()) {
|
|
234
|
+
// no existing callbacks, wait on queue submission
|
|
235
|
+
ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
|
|
236
|
+
wgpu::CallbackMode::AllowSpontaneous,
|
|
237
|
+
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
|
238
|
+
if (status != wgpu::QueueWorkDoneStatus::Success) {
|
|
239
|
+
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
|
|
240
|
+
}
|
|
241
|
+
}),
|
|
242
|
+
UINT64_MAX);
|
|
243
|
+
} else {
|
|
244
|
+
// existing callbacks, wait on them
|
|
245
|
+
ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
|
|
246
|
+
ctx->callback_futures.clear();
|
|
247
|
+
}
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
|
|
251
|
+
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
|
|
252
|
+
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_submit_queue()");
|
|
253
|
+
if (ctx->staged_command_bufs.empty()) {
|
|
254
|
+
// Nothing to submit
|
|
255
|
+
return;
|
|
256
|
+
}
|
|
257
|
+
ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());
|
|
258
|
+
|
|
259
|
+
// If there are SET_ROWS operations in this submission, copy their error buffers to the host.
|
|
260
|
+
if (ctx->staged_set_row_error_bufs.size() > 0) {
|
|
261
|
+
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
|
|
262
|
+
for (auto & error_bufs : ctx->staged_set_row_error_bufs) {
|
|
263
|
+
// Copy the error buffer to the host buffer
|
|
264
|
+
encoder.CopyBufferToBuffer(error_bufs.dev_buf, 0, error_bufs.host_buf, 0, error_bufs.host_buf.GetSize());
|
|
265
|
+
}
|
|
266
|
+
wgpu::CommandBuffer commands = encoder.Finish();
|
|
267
|
+
ctx->queue.Submit(1, &commands);
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
ctx->staged_command_bufs.clear();
|
|
271
|
+
std::vector<webgpu_pool_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
|
|
272
|
+
std::vector<webgpu_pool_bufs> staged_set_row_error_bufs = std::move(ctx->staged_set_row_error_bufs);
|
|
273
|
+
|
|
274
|
+
// Free the staged parameter buffers once the submission completes
|
|
275
|
+
wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
|
|
276
|
+
wgpu::CallbackMode::AllowSpontaneous,
|
|
277
|
+
[ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
|
278
|
+
if (status != wgpu::QueueWorkDoneStatus::Success) {
|
|
279
|
+
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
|
|
151
280
|
}
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
}
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
281
|
+
// Free the staged buffers
|
|
282
|
+
ctx->param_buf_pool.free_bufs(staged_param_bufs);
|
|
283
|
+
});
|
|
284
|
+
ctx->callback_futures.push_back({ p_f });
|
|
285
|
+
|
|
286
|
+
// Check for errrors in SET_ROWS operations
|
|
287
|
+
for (auto & error_bufs : staged_set_row_error_bufs) {
|
|
288
|
+
wgpu::Future f = error_bufs.host_buf.MapAsync(
|
|
289
|
+
wgpu::MapMode::Read,
|
|
290
|
+
0,
|
|
291
|
+
error_bufs.host_buf.GetSize(),
|
|
292
|
+
wgpu::CallbackMode::AllowSpontaneous,
|
|
293
|
+
[ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
|
294
|
+
if (status != wgpu::MapAsyncStatus::Success) {
|
|
295
|
+
GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
|
|
296
|
+
} else {
|
|
297
|
+
const uint32_t * error_data = (const uint32_t *) error_bufs.host_buf.GetConstMappedRange();
|
|
298
|
+
if (*error_data) {
|
|
299
|
+
GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
|
|
300
|
+
}
|
|
301
|
+
// We can't unmap in here due to WebGPU reentrancy limitations.
|
|
302
|
+
ctx->set_rows_error_buf_pool.free_bufs({ error_bufs });
|
|
303
|
+
}
|
|
304
|
+
});
|
|
305
|
+
ctx->callback_futures.push_back({ f });
|
|
306
|
+
}
|
|
307
|
+
}
|
|
308
|
+
|
|
309
|
+
static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
|
|
310
|
+
wgpu::Buffer & buffer,
|
|
311
|
+
wgpu::MapMode mode,
|
|
312
|
+
size_t offset,
|
|
313
|
+
size_t size) {
|
|
314
|
+
ctx->instance.WaitAny(buffer.MapAsync(mode,
|
|
315
|
+
offset,
|
|
316
|
+
size,
|
|
317
|
+
wgpu::CallbackMode::AllowSpontaneous,
|
|
318
|
+
[](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
|
319
|
+
if (status != wgpu::MapAsyncStatus::Success) {
|
|
320
|
+
GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n",
|
|
321
|
+
message.data);
|
|
322
|
+
}
|
|
323
|
+
}),
|
|
324
|
+
UINT64_MAX);
|
|
325
|
+
}
|
|
326
|
+
|
|
327
|
+
#ifdef GGML_WEBGPU_DEBUG
|
|
328
|
+
// This function adds debugging information to shaders, as WebGPU does not support printing directly.
|
|
329
|
+
// To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
|
|
330
|
+
// debug statements in the shader, and then call this function after encoding the commands and submitting them.
|
|
331
|
+
static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
|
|
332
|
+
ggml_backend_webgpu_submit_queue(ctx);
|
|
333
|
+
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
|
|
334
|
+
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
|
|
335
|
+
wgpu::CommandBuffer commands = encoder.Finish();
|
|
336
|
+
ctx->queue.Submit(1, &commands);
|
|
337
|
+
|
|
338
|
+
ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
|
|
339
|
+
const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange();
|
|
340
|
+
std::cout << "debug data:";
|
|
341
|
+
for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) {
|
|
342
|
+
std::cout << " " << i << ": " << debug_data[i];
|
|
343
|
+
}
|
|
344
|
+
std::cout << "\n";
|
|
345
|
+
ctx->debug_host_buf.Unmap();
|
|
346
|
+
}
|
|
347
|
+
#endif
|
|
348
|
+
|
|
349
|
+
static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & ctx,
|
|
350
|
+
wgpu::ComputePipeline & pipeline,
|
|
351
|
+
std::vector<uint32_t> params,
|
|
352
|
+
std::vector<wgpu::BindGroupEntry> bind_group_entries,
|
|
353
|
+
uint32_t wg_x,
|
|
354
|
+
bool submit_and_wait = false) {
|
|
355
|
+
webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
|
|
356
|
+
|
|
357
|
+
ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
|
|
358
|
+
uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
|
|
359
|
+
for (size_t i = 0; i < params.size(); i++) {
|
|
360
|
+
_params[i] = params[i];
|
|
361
|
+
};
|
|
362
|
+
|
|
363
|
+
params_bufs.host_buf.Unmap();
|
|
364
|
+
|
|
365
|
+
uint32_t params_bufs_binding_num = bind_group_entries.size();
|
|
366
|
+
bind_group_entries.push_back({ .binding = params_bufs_binding_num,
|
|
367
|
+
.buffer = params_bufs.dev_buf,
|
|
368
|
+
.offset = 0,
|
|
369
|
+
.size = params_bufs.dev_buf.GetSize() });
|
|
179
370
|
|
|
180
371
|
wgpu::BindGroupDescriptor bind_group_desc;
|
|
181
|
-
bind_group_desc.layout
|
|
182
|
-
bind_group_desc.entryCount =
|
|
183
|
-
bind_group_desc.
|
|
184
|
-
|
|
185
|
-
wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
|
|
372
|
+
bind_group_desc.layout = pipeline.GetBindGroupLayout(0);
|
|
373
|
+
bind_group_desc.entryCount = bind_group_entries.size();
|
|
374
|
+
bind_group_desc.entries = bind_group_entries.data();
|
|
375
|
+
wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);
|
|
186
376
|
|
|
187
|
-
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
|
188
|
-
encoder.CopyBufferToBuffer(
|
|
189
|
-
ctx->memset_params_host_buf, 0,
|
|
190
|
-
ctx->memset_params_dev_buf, 0,
|
|
191
|
-
ctx->memset_params_dev_buf.GetSize()
|
|
192
|
-
);
|
|
377
|
+
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
|
|
378
|
+
encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
|
|
193
379
|
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
|
194
|
-
pass.SetPipeline(
|
|
380
|
+
pass.SetPipeline(pipeline);
|
|
195
381
|
pass.SetBindGroup(0, bind_group);
|
|
196
|
-
|
|
197
|
-
pass.DispatchWorkgroups(((size + 3) + bytes_per_wg - 1) / bytes_per_wg, 1, 1);
|
|
382
|
+
pass.DispatchWorkgroups(wg_x, 1, 1);
|
|
198
383
|
pass.End();
|
|
199
384
|
wgpu::CommandBuffer commands = encoder.Finish();
|
|
200
|
-
|
|
201
|
-
|
|
385
|
+
if (submit_and_wait) {
|
|
386
|
+
// Submit and wait immediately
|
|
387
|
+
ctx->queue.Submit(1, &commands);
|
|
388
|
+
ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
|
|
389
|
+
wgpu::CallbackMode::AllowSpontaneous,
|
|
390
|
+
[ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
|
391
|
+
if (status != wgpu::QueueWorkDoneStatus::Success) {
|
|
392
|
+
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
|
|
393
|
+
}
|
|
394
|
+
ctx->param_buf_pool.free_bufs({ params_bufs });
|
|
395
|
+
}),
|
|
396
|
+
UINT64_MAX);
|
|
397
|
+
} else {
|
|
398
|
+
// Lock the context mutex when pushing to the staging vectors.
|
|
399
|
+
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
|
|
400
|
+
// Enqueue commands and only submit if we have enough staged commands
|
|
401
|
+
ctx->staged_command_bufs.push_back(commands);
|
|
402
|
+
ctx->staged_param_bufs.push_back(params_bufs);
|
|
403
|
+
if (ctx->staged_command_bufs.size() == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
|
|
404
|
+
ggml_backend_webgpu_submit_queue(ctx);
|
|
405
|
+
}
|
|
406
|
+
}
|
|
202
407
|
}
|
|
203
408
|
|
|
204
|
-
static void
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
}
|
|
212
|
-
|
|
213
|
-
|
|
409
|
+
static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
|
|
410
|
+
wgpu::Buffer & buf,
|
|
411
|
+
uint32_t value,
|
|
412
|
+
size_t offset,
|
|
413
|
+
size_t size) {
|
|
414
|
+
std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
|
|
415
|
+
std::vector<wgpu::BindGroupEntry> entries = {
|
|
416
|
+
{ .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
|
|
417
|
+
};
|
|
418
|
+
size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread;
|
|
419
|
+
uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg;
|
|
420
|
+
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true);
|
|
214
421
|
}
|
|
215
422
|
|
|
216
423
|
/** End WebGPU Actions */
|
|
@@ -218,218 +425,227 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context ctx) {
|
|
|
218
425
|
/** GGML Backend Interface */
|
|
219
426
|
|
|
220
427
|
static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
|
|
221
|
-
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context;
|
|
428
|
+
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
|
|
222
429
|
return ctx->name.c_str();
|
|
223
430
|
}
|
|
224
431
|
|
|
225
432
|
static void ggml_backend_webgpu_free(ggml_backend_t backend) {
|
|
226
|
-
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context;
|
|
433
|
+
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
|
|
227
434
|
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
|
|
228
435
|
|
|
229
436
|
// TODO: cleanup
|
|
230
437
|
GGML_UNUSED(ctx);
|
|
231
438
|
}
|
|
232
439
|
|
|
440
|
+
static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
|
|
441
|
+
return webgpu_tensor_offset(tensor) + tensor->view_offs;
|
|
442
|
+
}
|
|
443
|
+
|
|
444
|
+
static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
|
|
445
|
+
ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
|
|
446
|
+
return ctx->buffer;
|
|
447
|
+
}
|
|
448
|
+
|
|
449
|
+
static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, ggml_tensor * t) {
|
|
450
|
+
size_t offset = ggml_webgpu_tensor_offset(t);
|
|
451
|
+
return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
|
|
452
|
+
}
|
|
453
|
+
|
|
454
|
+
static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor * t) {
|
|
455
|
+
size_t offset = ggml_webgpu_tensor_offset(t);
|
|
456
|
+
return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
|
|
457
|
+
}
|
|
458
|
+
|
|
459
|
+
static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
|
|
460
|
+
return (ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t) + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
|
|
461
|
+
~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
|
|
462
|
+
}
|
|
463
|
+
|
|
464
|
+
static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
|
465
|
+
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
|
466
|
+
|
|
467
|
+
std::vector<uint32_t> params = { ne,
|
|
468
|
+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
|
469
|
+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
|
470
|
+
// Convert byte-strides to element-strides
|
|
471
|
+
(uint32_t) (src->nb[0] / ggml_type_size(src->type)),
|
|
472
|
+
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
|
|
473
|
+
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
|
|
474
|
+
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
|
|
475
|
+
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
|
|
476
|
+
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
|
477
|
+
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
|
478
|
+
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
|
479
|
+
// Logical shape — same for both tensors even if permuted
|
|
480
|
+
(uint32_t) src->ne[0],
|
|
481
|
+
(uint32_t) src->ne[1],
|
|
482
|
+
(uint32_t) src->ne[2],
|
|
483
|
+
(uint32_t) src->ne[3] };
|
|
484
|
+
|
|
485
|
+
std::vector<wgpu::BindGroupEntry> entries = {
|
|
486
|
+
{ .binding = 0,
|
|
487
|
+
.buffer = ggml_webgpu_tensor_buf(src),
|
|
488
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
|
|
489
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, src) },
|
|
490
|
+
{ .binding = 1,
|
|
491
|
+
.buffer = ggml_webgpu_tensor_buf(dst),
|
|
492
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
|
493
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
|
494
|
+
};
|
|
495
|
+
|
|
496
|
+
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
|
|
497
|
+
uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size;
|
|
498
|
+
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x);
|
|
499
|
+
}
|
|
500
|
+
|
|
501
|
+
static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
|
|
502
|
+
// For set rows specifically, we need to check if src and idx are empty tensors.
|
|
503
|
+
if (ggml_is_empty(src) || ggml_is_empty(idx)) {
|
|
504
|
+
return;
|
|
505
|
+
}
|
|
506
|
+
|
|
507
|
+
webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
|
|
508
|
+
if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
|
|
509
|
+
error_bufs.host_buf.Unmap();
|
|
510
|
+
}
|
|
511
|
+
|
|
512
|
+
std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
|
513
|
+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
|
|
514
|
+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
|
515
|
+
// Convert byte-strides to element-strides
|
|
516
|
+
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
|
|
517
|
+
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
|
|
518
|
+
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
|
|
519
|
+
(uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
|
|
520
|
+
(uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
|
|
521
|
+
(uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
|
|
522
|
+
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
|
523
|
+
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
|
524
|
+
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
|
525
|
+
// Shape of src
|
|
526
|
+
(uint32_t) src->ne[0],
|
|
527
|
+
(uint32_t) src->ne[1],
|
|
528
|
+
(uint32_t) src->ne[2],
|
|
529
|
+
(uint32_t) src->ne[3],
|
|
530
|
+
// Shape of idx
|
|
531
|
+
(uint32_t) (idx->ne[1]),
|
|
532
|
+
(uint32_t) (idx->ne[2]) };
|
|
533
|
+
|
|
534
|
+
std::vector<wgpu::BindGroupEntry> entries = {
|
|
535
|
+
{ .binding = 0,
|
|
536
|
+
.buffer = ggml_webgpu_tensor_buf(src),
|
|
537
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
|
|
538
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, src) },
|
|
539
|
+
{ .binding = 1,
|
|
540
|
+
.buffer = ggml_webgpu_tensor_buf(idx),
|
|
541
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, idx),
|
|
542
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, idx) },
|
|
543
|
+
{ .binding = 2,
|
|
544
|
+
.buffer = ggml_webgpu_tensor_buf(dst),
|
|
545
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
|
546
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
|
|
547
|
+
{ .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
|
|
548
|
+
};
|
|
549
|
+
|
|
550
|
+
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
|
|
551
|
+
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
|
|
552
|
+
|
|
553
|
+
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
|
|
554
|
+
ctx->staged_set_row_error_bufs.push_back(error_bufs);
|
|
555
|
+
|
|
556
|
+
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
|
|
557
|
+
}
|
|
558
|
+
|
|
559
|
+
static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
|
|
560
|
+
std::vector<uint32_t> params = {
|
|
561
|
+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
|
562
|
+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
|
563
|
+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
|
564
|
+
(uint32_t) dst->ne[1], // number of rows in result (M)
|
|
565
|
+
(uint32_t) dst->ne[0], // number of columns in result (N)
|
|
566
|
+
(uint32_t) src0->ne[0], // number of columns in src0/src1 (K)
|
|
567
|
+
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1
|
|
568
|
+
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1
|
|
569
|
+
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 2
|
|
570
|
+
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 2
|
|
571
|
+
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 3
|
|
572
|
+
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 3
|
|
573
|
+
(uint32_t) src0->ne[2], // batch size in dimension 2
|
|
574
|
+
(uint32_t) src0->ne[3], // batch size in dimension 3
|
|
575
|
+
(uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2
|
|
576
|
+
(uint32_t) (src1->ne[3] / src0->ne[3]) // broadcast in dimension 3
|
|
577
|
+
};
|
|
578
|
+
|
|
579
|
+
std::vector<wgpu::BindGroupEntry> entries = {
|
|
580
|
+
{ .binding = 0,
|
|
581
|
+
.buffer = ggml_webgpu_tensor_buf(src0),
|
|
582
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
|
583
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
|
584
|
+
{ .binding = 1,
|
|
585
|
+
.buffer = ggml_webgpu_tensor_buf(src1),
|
|
586
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
|
587
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
|
|
588
|
+
{ .binding = 2,
|
|
589
|
+
.buffer = ggml_webgpu_tensor_buf(dst),
|
|
590
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
|
591
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
|
|
592
|
+
};
|
|
593
|
+
|
|
594
|
+
uint32_t wg_x =
|
|
595
|
+
(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
|
|
596
|
+
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x);
|
|
597
|
+
}
|
|
598
|
+
|
|
233
599
|
// Returns true if node has enqueued work into the queue, false otherwise
|
|
234
|
-
static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
|
|
600
|
+
static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
|
|
235
601
|
if (ggml_is_empty(node)) {
|
|
236
602
|
return false;
|
|
237
603
|
}
|
|
238
|
-
|
|
239
604
|
WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
|
|
240
605
|
|
|
606
|
+
ggml_tensor * src0 = node->src[0];
|
|
607
|
+
ggml_tensor * src1 = node->src[1];
|
|
241
608
|
|
|
242
609
|
switch (node->op) {
|
|
243
|
-
|
|
610
|
+
// no-ops
|
|
244
611
|
case GGML_OP_NONE:
|
|
245
612
|
case GGML_OP_VIEW:
|
|
246
613
|
case GGML_OP_PERMUTE:
|
|
247
614
|
return false;
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context;
|
|
259
|
-
size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs;
|
|
260
|
-
size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
|
|
261
|
-
dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
|
|
262
|
-
|
|
263
|
-
wgpu::Device device = ctx->device;
|
|
264
|
-
ggml_backend_webgpu_map_buffer(ctx, ctx->cpy_params_host_buf,
|
|
265
|
-
wgpu::MapMode::Write, 0, ctx->cpy_params_host_buf.GetSize());
|
|
266
|
-
uint32_t * params = (uint32_t *) ctx->cpy_params_host_buf.GetMappedRange();
|
|
267
|
-
uint32_t ne = (uint32_t)ggml_nelements(node);
|
|
268
|
-
params[0] = ne;
|
|
269
|
-
params[1] = src_misalignment/ggml_type_size(src->type);
|
|
270
|
-
params[2] = dst_misalignment/ggml_type_size(node->type);
|
|
271
|
-
|
|
272
|
-
// Convert byte-strides to element-strides
|
|
273
|
-
params[3] = (uint32_t)src->nb[0]/ggml_type_size(src->type);
|
|
274
|
-
params[4] = (uint32_t)src->nb[1]/ggml_type_size(src->type);
|
|
275
|
-
params[5] = (uint32_t)src->nb[2]/ggml_type_size(src->type);
|
|
276
|
-
params[6] = (uint32_t)src->nb[3]/ggml_type_size(src->type);
|
|
277
|
-
params[7] = (uint32_t)node->nb[0]/ggml_type_size(node->type);
|
|
278
|
-
params[8] = (uint32_t)node->nb[1]/ggml_type_size(node->type);
|
|
279
|
-
params[9] = (uint32_t)node->nb[2]/ggml_type_size(node->type);
|
|
280
|
-
params[10] = (uint32_t)node->nb[3]/ggml_type_size(node->type);
|
|
281
|
-
// Logical shape — same for both tensors even if permuted
|
|
282
|
-
params[11] = (uint32_t)(src->ne[0]);
|
|
283
|
-
params[12] = (uint32_t)(src->ne[1]);
|
|
284
|
-
params[13] = (uint32_t)(src->ne[2]);
|
|
285
|
-
params[14] = (uint32_t)(src->ne[3]);
|
|
286
|
-
|
|
287
|
-
ctx->cpy_params_host_buf.Unmap();
|
|
288
|
-
|
|
289
|
-
wgpu::BindGroupEntry entries[3];
|
|
290
|
-
entries[0].binding = 0;
|
|
291
|
-
entries[0].buffer = src_ctx->buffer;
|
|
292
|
-
entries[0].offset = src_offset;
|
|
293
|
-
entries[0].size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
|
|
294
|
-
|
|
295
|
-
entries[1].binding = 1;
|
|
296
|
-
entries[1].buffer = dst_ctx->buffer;
|
|
297
|
-
entries[1].offset = dst_offset;
|
|
298
|
-
entries[1].size = (ggml_nbytes(node) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
|
|
299
|
-
|
|
300
|
-
entries[2].binding = 2;
|
|
301
|
-
entries[2].buffer = ctx->cpy_params_dev_buf;
|
|
302
|
-
entries[2].offset = 0;
|
|
303
|
-
entries[2].size = ctx->cpy_params_dev_buf.GetSize();
|
|
304
|
-
|
|
305
|
-
wgpu::BindGroupDescriptor bind_group_desc;
|
|
306
|
-
bind_group_desc.layout = ctx->cpy_pipeline.GetBindGroupLayout(0);
|
|
307
|
-
bind_group_desc.label = "ggml_op_cpy";
|
|
308
|
-
bind_group_desc.entryCount = 3;
|
|
309
|
-
bind_group_desc.entries = entries;
|
|
310
|
-
wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
|
|
311
|
-
|
|
312
|
-
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
|
313
|
-
encoder.CopyBufferToBuffer(
|
|
314
|
-
ctx->cpy_params_host_buf, 0,
|
|
315
|
-
ctx->cpy_params_dev_buf, 0,
|
|
316
|
-
ctx->cpy_params_dev_buf.GetSize()
|
|
317
|
-
);
|
|
318
|
-
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
|
319
|
-
pass.SetPipeline(ctx->cpy_pipeline);
|
|
320
|
-
pass.SetBindGroup(0, bind_group);
|
|
321
|
-
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
|
|
322
|
-
pass.DispatchWorkgroups((ne + max_wg_size - 1) / max_wg_size);
|
|
323
|
-
pass.End();
|
|
324
|
-
wgpu::CommandBuffer commands = encoder.Finish();
|
|
325
|
-
|
|
326
|
-
// TODO, don't submit here, batch submissions
|
|
327
|
-
ctx->queue.Submit(1, &commands);
|
|
328
|
-
// TODO, don't wait on submission here
|
|
329
|
-
ggml_backend_webgpu_wait_on_submission(ctx);
|
|
330
|
-
return true;
|
|
331
|
-
}
|
|
332
|
-
|
|
615
|
+
case GGML_OP_CPY:
|
|
616
|
+
{
|
|
617
|
+
ggml_webgpu_cpy(ctx, src0, node);
|
|
618
|
+
break;
|
|
619
|
+
}
|
|
620
|
+
case GGML_OP_SET_ROWS:
|
|
621
|
+
{
|
|
622
|
+
ggml_webgpu_set_rows(ctx, src0, src1, node);
|
|
623
|
+
break;
|
|
624
|
+
}
|
|
333
625
|
case GGML_OP_MUL_MAT:
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
const ggml_tensor * src1 = node->src[1];
|
|
339
|
-
ggml_backend_webgpu_buffer_context * src1_ctx = (ggml_backend_webgpu_buffer_context *) src1->buffer->context;
|
|
340
|
-
size_t src1_offset = webgpu_tensor_offset(src1) + src1->view_offs;
|
|
341
|
-
ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context;
|
|
342
|
-
|
|
343
|
-
size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs;
|
|
344
|
-
|
|
345
|
-
wgpu::Device device = ctx->device;
|
|
346
|
-
|
|
347
|
-
// map the host parameters buffer
|
|
348
|
-
ggml_backend_webgpu_map_buffer(ctx, ctx->mul_mat_params_host_buf,
|
|
349
|
-
wgpu::MapMode::Write, 0, ctx->mul_mat_params_host_buf.GetSize());
|
|
350
|
-
uint32_t * params = (uint32_t *) ctx->mul_mat_params_host_buf.GetMappedRange();
|
|
351
|
-
|
|
352
|
-
params[0] = (uint32_t)node->ne[1]; // number of rows in result (M)
|
|
353
|
-
params[1] = (uint32_t)node->ne[0]; // number of columns in result (N)
|
|
354
|
-
params[2] = (uint32_t)src0->ne[0]; // number of columns in src0/src1 (K)
|
|
355
|
-
|
|
356
|
-
params[3] = (uint32_t)src0->nb[1]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 1
|
|
357
|
-
params[4] = (uint32_t)src1->nb[1]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 1
|
|
358
|
-
params[5] = (uint32_t)src0->nb[2]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 2
|
|
359
|
-
params[6] = (uint32_t)src1->nb[2]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 2
|
|
360
|
-
params[7] = (uint32_t)src0->nb[3]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 3
|
|
361
|
-
params[8] = (uint32_t)src1->nb[3]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 3
|
|
362
|
-
|
|
363
|
-
params[9] = (uint32_t)src0->ne[2]; // batch size in dimension 2
|
|
364
|
-
params[10] = (uint32_t)src0->ne[3]; // batch size in dimension 3
|
|
365
|
-
params[11] = (uint32_t)(src1->ne[2]/src0->ne[2]); // broadcast in dimension 2
|
|
366
|
-
params[12] = (uint32_t)(src1->ne[3]/src0->ne[3]); // broadcast in dimension 3
|
|
367
|
-
|
|
368
|
-
ctx->mul_mat_params_host_buf.Unmap();
|
|
369
|
-
|
|
370
|
-
wgpu::BindGroupEntry entries[4];
|
|
371
|
-
entries[0].binding = 0;
|
|
372
|
-
entries[0].buffer = src0_ctx->buffer;
|
|
373
|
-
entries[0].offset = src0_offset;
|
|
374
|
-
entries[0].size = ggml_nbytes(src0);
|
|
375
|
-
|
|
376
|
-
entries[1].binding = 1;
|
|
377
|
-
entries[1].buffer = src1_ctx->buffer;
|
|
378
|
-
entries[1].offset = src1_offset;
|
|
379
|
-
entries[1].size = ggml_nbytes(src1);
|
|
380
|
-
|
|
381
|
-
entries[2].binding = 2;
|
|
382
|
-
entries[2].buffer = dst_ctx->buffer;
|
|
383
|
-
entries[2].offset = dst_offset;
|
|
384
|
-
entries[2].size = ggml_nbytes(node);
|
|
385
|
-
|
|
386
|
-
entries[3].binding = 3;
|
|
387
|
-
entries[3].buffer = ctx->mul_mat_params_dev_buf;
|
|
388
|
-
entries[3].offset = 0;
|
|
389
|
-
entries[3].size = ctx->mul_mat_params_dev_buf.GetSize();
|
|
390
|
-
|
|
391
|
-
wgpu::BindGroupDescriptor bind_group_desc;
|
|
392
|
-
bind_group_desc.layout = ctx->mul_mat_pipeline.GetBindGroupLayout(0);
|
|
393
|
-
bind_group_desc.entryCount = 4;
|
|
394
|
-
bind_group_desc.label = "ggml_op_mul_mat";
|
|
395
|
-
bind_group_desc.entries = entries;
|
|
396
|
-
wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
|
|
397
|
-
|
|
398
|
-
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
|
399
|
-
encoder.CopyBufferToBuffer(
|
|
400
|
-
ctx->mul_mat_params_host_buf, 0,
|
|
401
|
-
ctx->mul_mat_params_dev_buf, 0,
|
|
402
|
-
ctx->mul_mat_params_dev_buf.GetSize()
|
|
403
|
-
);
|
|
404
|
-
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
|
405
|
-
pass.SetPipeline(ctx->mul_mat_pipeline);
|
|
406
|
-
pass.SetBindGroup(0, bind_group);
|
|
407
|
-
pass.DispatchWorkgroups((node->ne[0] * node->ne[1] * node->ne[2] * node->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE);
|
|
408
|
-
pass.End();
|
|
409
|
-
wgpu::CommandBuffer commands = encoder.Finish();
|
|
410
|
-
|
|
411
|
-
// TODO, don't submit here, batch submissions
|
|
412
|
-
ctx->queue.Submit(1, &commands);
|
|
413
|
-
// TODO, don't wait on submission here
|
|
414
|
-
ggml_backend_webgpu_wait_on_submission(ctx);
|
|
415
|
-
return true;
|
|
416
|
-
}
|
|
417
|
-
|
|
626
|
+
{
|
|
627
|
+
ggml_webgpu_mul_mat(ctx, src0, src1, node);
|
|
628
|
+
break;
|
|
629
|
+
}
|
|
418
630
|
default:
|
|
419
631
|
return false;
|
|
420
632
|
}
|
|
633
|
+
return true;
|
|
421
634
|
}
|
|
422
635
|
|
|
423
636
|
static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
|
424
637
|
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");
|
|
425
638
|
|
|
426
639
|
ggml_backend_webgpu_context * backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
|
|
427
|
-
webgpu_context
|
|
640
|
+
webgpu_context ctx = backend_ctx->webgpu_ctx;
|
|
428
641
|
|
|
429
642
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
430
643
|
ggml_webgpu_encode_node(ctx, cgraph->nodes[i]);
|
|
431
644
|
}
|
|
432
645
|
|
|
646
|
+
ggml_backend_webgpu_submit_queue(ctx);
|
|
647
|
+
ggml_backend_webgpu_wait_on_submission(ctx);
|
|
648
|
+
|
|
433
649
|
return GGML_STATUS_SUCCESS;
|
|
434
650
|
}
|
|
435
651
|
|
|
@@ -465,49 +681,72 @@ static void * ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer)
|
|
|
465
681
|
return webgpu_ptr_base;
|
|
466
682
|
}
|
|
467
683
|
|
|
468
|
-
static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer,
|
|
684
|
+
static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer,
|
|
685
|
+
ggml_tensor * tensor,
|
|
686
|
+
uint8_t value,
|
|
687
|
+
size_t offset,
|
|
688
|
+
size_t size) {
|
|
469
689
|
if (size == 0) {
|
|
470
690
|
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
|
|
471
691
|
return;
|
|
472
692
|
}
|
|
473
693
|
|
|
474
|
-
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", "
|
|
694
|
+
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", "
|
|
695
|
+
<< offset << ", " << size << ")");
|
|
475
696
|
|
|
476
697
|
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
|
|
698
|
+
|
|
477
699
|
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
|
|
700
|
+
|
|
478
701
|
// This is a trick to set all bytes of a u32 to the same 1 byte value.
|
|
479
|
-
uint32_t val32 = (uint32_t)value * 0x01010101;
|
|
702
|
+
uint32_t val32 = (uint32_t) value * 0x01010101;
|
|
480
703
|
ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
|
|
481
704
|
}
|
|
482
705
|
|
|
483
|
-
static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
706
|
+
static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
|
707
|
+
ggml_tensor * tensor,
|
|
708
|
+
const void * data,
|
|
709
|
+
size_t offset,
|
|
710
|
+
size_t size) {
|
|
711
|
+
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", "
|
|
712
|
+
<< offset << ", " << size << ")");
|
|
713
|
+
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
|
|
714
|
+
webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
|
|
487
715
|
|
|
488
716
|
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
|
|
489
717
|
|
|
490
|
-
webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size/4)*4);
|
|
718
|
+
webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
|
|
491
719
|
|
|
492
720
|
if (size % 4 != 0) {
|
|
493
721
|
// If size is not a multiple of 4, we need to memset the remaining bytes
|
|
494
722
|
size_t remaining_size = size % 4;
|
|
723
|
+
|
|
495
724
|
// pack the remaining bytes into a uint32_t
|
|
496
725
|
uint32_t val32 = 0;
|
|
726
|
+
|
|
497
727
|
for (size_t i = 0; i < remaining_size; i++) {
|
|
498
|
-
((uint8_t *)&val32)[i] = ((const uint8_t *)data)[size - remaining_size + i];
|
|
728
|
+
((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
|
|
499
729
|
}
|
|
500
730
|
// memset the remaining bytes
|
|
501
|
-
ggml_backend_webgpu_buffer_memset(
|
|
731
|
+
ggml_backend_webgpu_buffer_memset(
|
|
732
|
+
webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
|
|
733
|
+
} else {
|
|
734
|
+
// wait for WriteBuffer to complete
|
|
735
|
+
ggml_backend_webgpu_wait_on_submission(webgpu_ctx);
|
|
502
736
|
}
|
|
503
737
|
}
|
|
504
738
|
|
|
505
|
-
static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
|
506
|
-
|
|
739
|
+
static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
|
740
|
+
const ggml_tensor * tensor,
|
|
741
|
+
void * data,
|
|
742
|
+
size_t offset,
|
|
743
|
+
size_t size) {
|
|
744
|
+
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", "
|
|
745
|
+
<< offset << ", " << size << ")");
|
|
507
746
|
|
|
508
|
-
ggml_backend_webgpu_buffer_context * buf_ctx
|
|
509
|
-
webgpu_context
|
|
510
|
-
wgpu::Device
|
|
747
|
+
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
|
|
748
|
+
webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
|
|
749
|
+
wgpu::Device device = webgpu_ctx->device;
|
|
511
750
|
|
|
512
751
|
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
|
|
513
752
|
|
|
@@ -517,22 +756,25 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
|
|
517
756
|
final_size = size + (4 - (size % 4));
|
|
518
757
|
}
|
|
519
758
|
|
|
520
|
-
std::lock_guard<std::
|
|
759
|
+
std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);
|
|
521
760
|
|
|
522
|
-
if (webgpu_ctx->get_tensor_staging_buf == nullptr ||
|
|
523
|
-
webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
|
|
761
|
+
if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
|
|
524
762
|
// Create a new staging buffer if it doesn't exist or is too small
|
|
525
763
|
if (webgpu_ctx->get_tensor_staging_buf) {
|
|
526
764
|
webgpu_ctx->get_tensor_staging_buf.Destroy();
|
|
527
765
|
}
|
|
528
|
-
ggml_webgpu_create_buffer(device,
|
|
529
|
-
|
|
766
|
+
ggml_webgpu_create_buffer(device,
|
|
767
|
+
webgpu_ctx->get_tensor_staging_buf,
|
|
768
|
+
final_size,
|
|
769
|
+
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
|
|
770
|
+
"get_tensor_staging_buf");
|
|
530
771
|
}
|
|
531
772
|
|
|
532
773
|
// Copy the data from the buffer to the staging buffer
|
|
533
774
|
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
|
534
775
|
encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size);
|
|
535
776
|
wgpu::CommandBuffer commands = encoder.Finish();
|
|
777
|
+
|
|
536
778
|
// Submit the command buffer to the queue
|
|
537
779
|
webgpu_ctx->queue.Submit(1, &commands);
|
|
538
780
|
|
|
@@ -548,7 +790,6 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
|
|
548
790
|
|
|
549
791
|
static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
|
550
792
|
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
|
|
551
|
-
|
|
552
793
|
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
|
|
553
794
|
ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size);
|
|
554
795
|
}
|
|
@@ -556,13 +797,13 @@ static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8
|
|
|
556
797
|
static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
|
|
557
798
|
/* .free_buffer = */ ggml_backend_webgpu_buffer_free_buffer,
|
|
558
799
|
/* .get_base = */ ggml_backend_webgpu_buffer_get_base,
|
|
559
|
-
/* .init_tensor = */ NULL,
|
|
800
|
+
/* .init_tensor = */ NULL, // TODO: optional, needed?
|
|
560
801
|
/* .memset_tensor = */ ggml_backend_webgpu_buffer_memset_tensor,
|
|
561
802
|
/* .set_tensor = */ ggml_backend_webgpu_buffer_set_tensor,
|
|
562
803
|
/* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor,
|
|
563
|
-
/* .cpy_tensor = */ NULL,
|
|
804
|
+
/* .cpy_tensor = */ NULL, // TODO: optional, implement this
|
|
564
805
|
/* .clear = */ ggml_backend_webgpu_buffer_clear,
|
|
565
|
-
/* .reset = */ NULL,
|
|
806
|
+
/* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor
|
|
566
807
|
};
|
|
567
808
|
|
|
568
809
|
/* End GGML Backend Buffer Interface */
|
|
@@ -574,13 +815,17 @@ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer
|
|
|
574
815
|
return ctx->device_name.c_str();
|
|
575
816
|
}
|
|
576
817
|
|
|
577
|
-
static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
|
|
818
|
+
static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
|
|
819
|
+
size_t size) {
|
|
578
820
|
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")");
|
|
579
821
|
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
|
|
580
822
|
|
|
581
823
|
wgpu::Buffer buf;
|
|
582
|
-
ggml_webgpu_create_buffer(ctx->webgpu_ctx->device,
|
|
583
|
-
|
|
824
|
+
ggml_webgpu_create_buffer(ctx->webgpu_ctx->device,
|
|
825
|
+
buf,
|
|
826
|
+
(size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1),
|
|
827
|
+
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
|
|
828
|
+
"allocated_buffer");
|
|
584
829
|
|
|
585
830
|
ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf);
|
|
586
831
|
|
|
@@ -615,8 +860,8 @@ static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_
|
|
|
615
860
|
static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
|
616
861
|
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
|
|
617
862
|
// TODO: what do we actually want to return here? maxBufferSize might not be the full available memory.
|
|
618
|
-
*free
|
|
619
|
-
*total
|
|
863
|
+
*free = ctx->webgpu_ctx->limits.maxBufferSize;
|
|
864
|
+
*total = ctx->webgpu_ctx->limits.maxBufferSize;
|
|
620
865
|
}
|
|
621
866
|
|
|
622
867
|
static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
|
|
@@ -639,98 +884,140 @@ static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct
|
|
|
639
884
|
|
|
640
885
|
static ggml_guid_t ggml_backend_webgpu_guid(void) {
|
|
641
886
|
static const char * guid_str = "__ggml_webgpu :)";
|
|
642
|
-
return reinterpret_cast<ggml_guid_t>((void *)guid_str);
|
|
887
|
+
return reinterpret_cast<ggml_guid_t>((void *) guid_str);
|
|
643
888
|
}
|
|
644
889
|
|
|
645
|
-
static void ggml_webgpu_init_memset_pipeline(webgpu_context webgpu_ctx) {
|
|
890
|
+
static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
|
|
646
891
|
// we use the maximum workgroup size for the memset pipeline
|
|
647
892
|
size_t max_wg_size = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
|
|
648
893
|
size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
|
|
649
894
|
// Size the bytes_per_thread so that the largest buffer size can be handled
|
|
650
|
-
webgpu_ctx->memset_bytes_per_thread =
|
|
895
|
+
webgpu_ctx->memset_bytes_per_thread =
|
|
896
|
+
(webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads;
|
|
651
897
|
std::vector<wgpu::ConstantEntry> constants(2);
|
|
652
|
-
constants[0].key
|
|
898
|
+
constants[0].key = "wg_size";
|
|
653
899
|
constants[0].value = max_wg_size;
|
|
654
|
-
constants[1].key
|
|
900
|
+
constants[1].key = "bytes_per_thread";
|
|
655
901
|
constants[1].value = webgpu_ctx->memset_bytes_per_thread;
|
|
656
902
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, "memset", constants);
|
|
657
|
-
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_dev_buf,
|
|
658
|
-
3 * sizeof(uint32_t), // 3 parameters: buffer size, offset, value
|
|
659
|
-
wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "memset_params_dev_buf");
|
|
660
|
-
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_host_buf,
|
|
661
|
-
3 * sizeof(uint32_t), wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "memset_params_host_buf");
|
|
662
903
|
}
|
|
663
904
|
|
|
664
|
-
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context webgpu_ctx) {
|
|
665
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
905
|
+
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
|
|
906
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
907
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
|
|
908
|
+
wgsl_mul_mat_f32_f32,
|
|
909
|
+
"mul_mat_f32_f32");
|
|
910
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
911
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
|
|
912
|
+
wgsl_mul_mat_f16_f16,
|
|
913
|
+
"mul_mat_f16_f16");
|
|
914
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
915
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
|
|
916
|
+
wgsl_mul_mat_f16_f32,
|
|
917
|
+
"mul_mat_f16_f32");
|
|
918
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
919
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
|
|
920
|
+
wgsl_mul_mat_q4_0_f32,
|
|
921
|
+
"mul_mat_q4_0_f32");
|
|
922
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
923
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
|
|
924
|
+
wgsl_mul_mat_q4_1_f32,
|
|
925
|
+
"mul_mat_q4_1_f32");
|
|
926
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
927
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32],
|
|
928
|
+
wgsl_mul_mat_q5_0_f32,
|
|
929
|
+
"mul_mat_q5_0_f32");
|
|
930
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
931
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32],
|
|
932
|
+
wgsl_mul_mat_q5_1_f32,
|
|
933
|
+
"mul_mat_q5_1_f32");
|
|
934
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
935
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32],
|
|
936
|
+
wgsl_mul_mat_q8_0_f32,
|
|
937
|
+
"mul_mat_q8_0_f32");
|
|
938
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
939
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32],
|
|
940
|
+
wgsl_mul_mat_q2_k_f32,
|
|
941
|
+
"mul_mat_q2_k_f32");
|
|
942
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
943
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32],
|
|
944
|
+
wgsl_mul_mat_q3_k_f32,
|
|
945
|
+
"mul_mat_q3_k_f32");
|
|
946
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
947
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32],
|
|
948
|
+
wgsl_mul_mat_q4_k_f32,
|
|
949
|
+
"mul_mat_q4_k_f32");
|
|
950
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
951
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32],
|
|
952
|
+
wgsl_mul_mat_q5_k_f32,
|
|
953
|
+
"mul_mat_q5_k_f32");
|
|
954
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
955
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32],
|
|
956
|
+
wgsl_mul_mat_q6_k_f32,
|
|
957
|
+
"mul_mat_q6_k_f32");
|
|
958
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
959
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32],
|
|
960
|
+
wgsl_mul_mat_iq2_xxs_f32,
|
|
961
|
+
"mul_mat_iq2_xxs_f32");
|
|
962
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
963
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32],
|
|
964
|
+
wgsl_mul_mat_iq2_xs_f32,
|
|
965
|
+
"mul_mat_iq2_xs_f32");
|
|
966
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
967
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32],
|
|
968
|
+
wgsl_mul_mat_iq2_s_f32,
|
|
969
|
+
"mul_mat_iq2_s_f32");
|
|
970
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
971
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32],
|
|
972
|
+
wgsl_mul_mat_iq3_xxs_f32,
|
|
973
|
+
"mul_mat_iq3_xxs_f32");
|
|
974
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
975
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32],
|
|
976
|
+
wgsl_mul_mat_iq3_s_f32,
|
|
977
|
+
"mul_mat_iq3_s_f32");
|
|
978
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
979
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32],
|
|
980
|
+
wgsl_mul_mat_iq1_s_f32,
|
|
981
|
+
"mul_mat_iq1_s_f32");
|
|
982
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
983
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32],
|
|
984
|
+
wgsl_mul_mat_iq1_m_f32,
|
|
985
|
+
"mul_mat_iq1_m_f32");
|
|
986
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
987
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32],
|
|
988
|
+
wgsl_mul_mat_iq4_nl_f32,
|
|
989
|
+
"mul_mat_iq4_nl_f32");
|
|
990
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
|
991
|
+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
|
|
992
|
+
wgsl_mul_mat_iq4_xs_f32,
|
|
993
|
+
"mul_mat_iq4_xs_f32");
|
|
670
994
|
}
|
|
671
995
|
|
|
672
|
-
static void
|
|
996
|
+
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
|
|
673
997
|
std::vector<wgpu::ConstantEntry> constants(1);
|
|
674
|
-
constants[0].key
|
|
998
|
+
constants[0].key = "wg_size";
|
|
675
999
|
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
|
|
1000
|
+
ggml_webgpu_create_pipeline(
|
|
1001
|
+
webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", constants);
|
|
1002
|
+
}
|
|
676
1003
|
|
|
1004
|
+
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
|
|
1005
|
+
std::vector<wgpu::ConstantEntry> constants(1);
|
|
1006
|
+
constants[0].key = "wg_size";
|
|
1007
|
+
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
|
|
677
1008
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", constants);
|
|
678
|
-
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->cpy_params_dev_buf, WEBGPU_CPY_PARAMS_SIZE,
|
|
679
|
-
wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "cpy_params_dev_buf");
|
|
680
|
-
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->cpy_params_host_buf, WEBGPU_CPY_PARAMS_SIZE,
|
|
681
|
-
wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "cpy_params_host_buf");
|
|
682
1009
|
}
|
|
683
1010
|
|
|
684
|
-
// TODO: Make thread safe if multiple devices are used
|
|
685
1011
|
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
|
|
686
1012
|
GGML_UNUSED(params);
|
|
687
1013
|
|
|
688
1014
|
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()");
|
|
689
1015
|
|
|
690
|
-
ggml_backend_webgpu_device_context * dev_ctx
|
|
691
|
-
webgpu_context
|
|
692
|
-
|
|
693
|
-
std::lock_guard<std::mutex> lock(webgpu_ctx->mutex);
|
|
694
|
-
|
|
695
|
-
if (!webgpu_ctx->device_initialized) {
|
|
696
|
-
// Initialize device
|
|
697
|
-
wgpu::DeviceDescriptor dev_desc;
|
|
698
|
-
dev_desc.requiredLimits = &webgpu_ctx->limits;
|
|
699
|
-
dev_desc.requiredFeatures = webgpu_ctx->features.features;
|
|
700
|
-
dev_desc.requiredFeatureCount = webgpu_ctx->features.featureCount;
|
|
701
|
-
dev_desc.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous,
|
|
702
|
-
[](const wgpu::Device& device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
|
703
|
-
GGML_UNUSED(device);
|
|
704
|
-
GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
|
|
705
|
-
});
|
|
706
|
-
dev_desc.SetUncapturedErrorCallback(
|
|
707
|
-
[](const wgpu::Device& device, wgpu::ErrorType reason, wgpu::StringView message) {
|
|
708
|
-
GGML_UNUSED(device);
|
|
709
|
-
GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
|
|
710
|
-
});
|
|
711
|
-
webgpu_ctx->instance.WaitAny(webgpu_ctx->adapter.RequestDevice(&dev_desc, wgpu::CallbackMode::WaitAnyOnly,
|
|
712
|
-
[webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
|
713
|
-
if (status != wgpu::RequestDeviceStatus::Success) {
|
|
714
|
-
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data);
|
|
715
|
-
return;
|
|
716
|
-
}
|
|
717
|
-
webgpu_ctx->device = device;
|
|
718
|
-
}),
|
|
719
|
-
UINT64_MAX
|
|
720
|
-
);
|
|
721
|
-
GGML_ASSERT(webgpu_ctx->device != nullptr);
|
|
722
|
-
|
|
723
|
-
// Initialize (compute) queue
|
|
724
|
-
webgpu_ctx->queue = webgpu_ctx->device.GetQueue();
|
|
725
|
-
|
|
726
|
-
ggml_webgpu_init_memset_pipeline(webgpu_ctx);
|
|
727
|
-
ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
|
|
728
|
-
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
|
|
729
|
-
webgpu_ctx->device_initialized = true;
|
|
730
|
-
}
|
|
1016
|
+
ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
|
|
1017
|
+
webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
|
|
731
1018
|
|
|
732
1019
|
static ggml_backend_webgpu_context backend_ctx;
|
|
733
|
-
backend_ctx.name
|
|
1020
|
+
backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
|
|
734
1021
|
backend_ctx.webgpu_ctx = webgpu_ctx;
|
|
735
1022
|
|
|
736
1023
|
// See GGML Backend Interface section
|
|
@@ -748,14 +1035,15 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
|
|
|
748
1035
|
// See GGML Backend Buffer Type Interface section
|
|
749
1036
|
static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
|
|
750
1037
|
/* .iface = */ {
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
/* .is_host = */ NULL,
|
|
1038
|
+
/* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
|
|
1039
|
+
/* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
|
|
1040
|
+
/* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment,
|
|
1041
|
+
/* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size,
|
|
1042
|
+
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
|
1043
|
+
/* .is_host = */ NULL, // defaults to false
|
|
757
1044
|
},
|
|
758
|
-
/* .device = */
|
|
1045
|
+
/* .device = */
|
|
1046
|
+
dev,
|
|
759
1047
|
/* .context = */ NULL,
|
|
760
1048
|
};
|
|
761
1049
|
|
|
@@ -764,7 +1052,7 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
|
|
|
764
1052
|
|
|
765
1053
|
static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
|
766
1054
|
GGML_UNUSED(dev);
|
|
767
|
-
return
|
|
1055
|
+
return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;
|
|
768
1056
|
}
|
|
769
1057
|
|
|
770
1058
|
static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
|
@@ -776,9 +1064,44 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|
|
776
1064
|
case GGML_OP_PERMUTE:
|
|
777
1065
|
return true;
|
|
778
1066
|
case GGML_OP_CPY:
|
|
1067
|
+
case GGML_OP_SET_ROWS:
|
|
779
1068
|
return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
|
|
780
1069
|
case GGML_OP_MUL_MAT:
|
|
781
|
-
|
|
1070
|
+
{
|
|
1071
|
+
switch (op->src[1]->type) {
|
|
1072
|
+
case GGML_TYPE_F16:
|
|
1073
|
+
return op->src[0]->type == GGML_TYPE_F16;
|
|
1074
|
+
case GGML_TYPE_F32:
|
|
1075
|
+
switch (op->src[0]->type) {
|
|
1076
|
+
case GGML_TYPE_F32:
|
|
1077
|
+
case GGML_TYPE_F16:
|
|
1078
|
+
case GGML_TYPE_Q4_0:
|
|
1079
|
+
case GGML_TYPE_Q4_1:
|
|
1080
|
+
case GGML_TYPE_Q5_0:
|
|
1081
|
+
case GGML_TYPE_Q5_1:
|
|
1082
|
+
case GGML_TYPE_Q8_0:
|
|
1083
|
+
case GGML_TYPE_Q2_K:
|
|
1084
|
+
case GGML_TYPE_Q3_K:
|
|
1085
|
+
case GGML_TYPE_Q4_K:
|
|
1086
|
+
case GGML_TYPE_Q5_K:
|
|
1087
|
+
case GGML_TYPE_Q6_K:
|
|
1088
|
+
case GGML_TYPE_IQ2_XXS:
|
|
1089
|
+
case GGML_TYPE_IQ2_XS:
|
|
1090
|
+
case GGML_TYPE_IQ2_S:
|
|
1091
|
+
case GGML_TYPE_IQ3_XXS:
|
|
1092
|
+
case GGML_TYPE_IQ3_S:
|
|
1093
|
+
case GGML_TYPE_IQ1_S:
|
|
1094
|
+
case GGML_TYPE_IQ1_M:
|
|
1095
|
+
case GGML_TYPE_IQ4_NL:
|
|
1096
|
+
case GGML_TYPE_IQ4_XS:
|
|
1097
|
+
return true;
|
|
1098
|
+
default:
|
|
1099
|
+
return false;
|
|
1100
|
+
}
|
|
1101
|
+
default:
|
|
1102
|
+
return false;
|
|
1103
|
+
}
|
|
1104
|
+
}
|
|
782
1105
|
default:
|
|
783
1106
|
return false;
|
|
784
1107
|
}
|
|
@@ -827,30 +1150,105 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
|
|
827
1150
|
webgpu_context ctx = reg_ctx->webgpu_ctx;
|
|
828
1151
|
|
|
829
1152
|
wgpu::RequestAdapterOptions options = {};
|
|
830
|
-
auto
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
1153
|
+
auto callback =
|
|
1154
|
+
[](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message, void * userdata) {
|
|
1155
|
+
if (status != wgpu::RequestAdapterStatus::Success) {
|
|
1156
|
+
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
|
1157
|
+
return;
|
|
1158
|
+
}
|
|
1159
|
+
*static_cast<wgpu::Adapter *>(userdata) = std::move(adapter);
|
|
1160
|
+
};
|
|
1161
|
+
void * userdata = &ctx->adapter;
|
|
1162
|
+
ctx->instance.WaitAny(
|
|
1163
|
+
ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous, callback, userdata), UINT64_MAX);
|
|
839
1164
|
GGML_ASSERT(ctx->adapter != nullptr);
|
|
840
1165
|
|
|
841
1166
|
ctx->adapter.GetLimits(&ctx->limits);
|
|
842
|
-
ctx->adapter.GetFeatures(&ctx->features);
|
|
843
1167
|
|
|
844
1168
|
wgpu::AdapterInfo info{};
|
|
845
1169
|
ctx->adapter.GetInfo(&info);
|
|
846
1170
|
|
|
1171
|
+
// Initialize device
|
|
1172
|
+
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
|
|
1173
|
+
wgpu::FeatureName::ImplicitDeviceSynchronization };
|
|
1174
|
+
wgpu::DeviceDescriptor dev_desc;
|
|
1175
|
+
dev_desc.requiredLimits = &ctx->limits;
|
|
1176
|
+
dev_desc.requiredFeatures = required_features.data();
|
|
1177
|
+
dev_desc.requiredFeatureCount = required_features.size();
|
|
1178
|
+
dev_desc.SetDeviceLostCallback(
|
|
1179
|
+
wgpu::CallbackMode::AllowSpontaneous,
|
|
1180
|
+
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
|
1181
|
+
GGML_UNUSED(device);
|
|
1182
|
+
GGML_LOG_ERROR(
|
|
1183
|
+
"ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), std::string(message).c_str());
|
|
1184
|
+
});
|
|
1185
|
+
dev_desc.SetUncapturedErrorCallback(
|
|
1186
|
+
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
|
|
1187
|
+
GGML_UNUSED(device);
|
|
1188
|
+
GGML_LOG_ERROR(
|
|
1189
|
+
"ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), std::string(message).c_str());
|
|
1190
|
+
});
|
|
1191
|
+
ctx->instance.WaitAny(ctx->adapter.RequestDevice(
|
|
1192
|
+
&dev_desc,
|
|
1193
|
+
wgpu::CallbackMode::AllowSpontaneous,
|
|
1194
|
+
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
|
1195
|
+
if (status != wgpu::RequestDeviceStatus::Success) {
|
|
1196
|
+
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
|
|
1197
|
+
return;
|
|
1198
|
+
}
|
|
1199
|
+
ctx->device = std::move(device);
|
|
1200
|
+
}),
|
|
1201
|
+
UINT64_MAX);
|
|
1202
|
+
GGML_ASSERT(ctx->device != nullptr);
|
|
1203
|
+
|
|
1204
|
+
// Initialize (compute) queue
|
|
1205
|
+
ctx->queue = ctx->device.GetQueue();
|
|
1206
|
+
|
|
1207
|
+
// Create buffer pool for shader parameters
|
|
1208
|
+
ctx->param_buf_pool.init(ctx->device,
|
|
1209
|
+
WEBGPU_NUM_PARAM_BUFS,
|
|
1210
|
+
WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
|
1211
|
+
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
|
|
1212
|
+
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
|
|
1213
|
+
ctx->set_rows_error_buf_pool.init(ctx->device,
|
|
1214
|
+
WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
|
|
1215
|
+
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
|
1216
|
+
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
|
|
1217
|
+
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
|
|
1218
|
+
|
|
1219
|
+
ggml_webgpu_init_memset_pipeline(ctx);
|
|
1220
|
+
ggml_webgpu_init_mul_mat_pipeline(ctx);
|
|
1221
|
+
ggml_webgpu_init_set_rows_pipeline(ctx);
|
|
1222
|
+
ggml_webgpu_init_cpy_pipeline(ctx);
|
|
1223
|
+
|
|
1224
|
+
#ifdef GGML_WEBGPU_DEBUG
|
|
1225
|
+
// Initialize debug buffers
|
|
1226
|
+
ggml_webgpu_create_buffer(ctx->device,
|
|
1227
|
+
ctx->debug_host_buf,
|
|
1228
|
+
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
|
1229
|
+
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
|
|
1230
|
+
"debug_host_buf");
|
|
1231
|
+
ggml_webgpu_create_buffer(ctx->device,
|
|
1232
|
+
ctx->debug_dev_buf,
|
|
1233
|
+
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
|
1234
|
+
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
|
|
1235
|
+
"debug_dev_buf");
|
|
1236
|
+
#endif
|
|
1237
|
+
|
|
847
1238
|
static ggml_backend_webgpu_device_context device_ctx;
|
|
848
|
-
device_ctx.webgpu_ctx
|
|
1239
|
+
device_ctx.webgpu_ctx = ctx;
|
|
849
1240
|
device_ctx.device_name = GGML_WEBGPU_NAME;
|
|
850
|
-
device_ctx.device_desc =
|
|
851
|
-
|
|
852
|
-
GGML_LOG_INFO(
|
|
853
|
-
|
|
1241
|
+
device_ctx.device_desc = info.description;
|
|
1242
|
+
|
|
1243
|
+
GGML_LOG_INFO(
|
|
1244
|
+
"ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
|
|
1245
|
+
"device_desc: %s\n",
|
|
1246
|
+
info.vendorID,
|
|
1247
|
+
std::string(info.vendor).c_str(),
|
|
1248
|
+
std::string(info.architecture).c_str(),
|
|
1249
|
+
info.deviceID,
|
|
1250
|
+
std::string(info.device).c_str(),
|
|
1251
|
+
std::string(info.description).c_str());
|
|
854
1252
|
|
|
855
1253
|
// See GGML Backend Device Interface section
|
|
856
1254
|
static ggml_backend_device device = {
|
|
@@ -861,7 +1259,6 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
|
|
861
1259
|
return &device;
|
|
862
1260
|
}
|
|
863
1261
|
|
|
864
|
-
|
|
865
1262
|
static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
|
|
866
1263
|
/* .get_name = */ ggml_backend_webgpu_reg_get_name,
|
|
867
1264
|
/* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count,
|
|
@@ -871,23 +1268,21 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
|
|
|
871
1268
|
|
|
872
1269
|
/* End GGML Backend Registration Interface */
|
|
873
1270
|
|
|
874
|
-
// TODO: Does this need to be thread safe? Is it only called once?
|
|
875
1271
|
ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
|
876
1272
|
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
|
|
877
1273
|
|
|
878
1274
|
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
|
|
879
|
-
webgpu_ctx->device_initialized = false;
|
|
880
1275
|
|
|
881
1276
|
static ggml_backend_webgpu_reg_context ctx;
|
|
882
|
-
ctx.webgpu_ctx
|
|
883
|
-
ctx.name
|
|
1277
|
+
ctx.webgpu_ctx = webgpu_ctx;
|
|
1278
|
+
ctx.name = GGML_WEBGPU_NAME;
|
|
884
1279
|
ctx.device_count = 1;
|
|
885
1280
|
|
|
886
|
-
wgpu::InstanceDescriptor
|
|
887
|
-
std::vector<wgpu::InstanceFeatureName> instance_features = {wgpu::InstanceFeatureName::TimedWaitAny};
|
|
888
|
-
instance_descriptor.requiredFeatures
|
|
889
|
-
instance_descriptor.requiredFeatureCount
|
|
890
|
-
webgpu_ctx->instance
|
|
1281
|
+
wgpu::InstanceDescriptor instance_descriptor{};
|
|
1282
|
+
std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
|
|
1283
|
+
instance_descriptor.requiredFeatures = instance_features.data();
|
|
1284
|
+
instance_descriptor.requiredFeatureCount = instance_features.size();
|
|
1285
|
+
webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
|
|
891
1286
|
GGML_ASSERT(webgpu_ctx->instance != nullptr);
|
|
892
1287
|
|
|
893
1288
|
static ggml_backend_reg reg = {
|