@novastera-oss/llamarn 0.2.9 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (314)
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  22. package/cpp/llama.cpp/CMakePresets.json +11 -0
  23. package/cpp/llama.cpp/CODEOWNERS +1 -0
  24. package/cpp/llama.cpp/README.md +8 -8
  25. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  26. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  27. package/cpp/llama.cpp/common/arg.cpp +62 -1
  28. package/cpp/llama.cpp/common/chat.cpp +37 -20
  29. package/cpp/llama.cpp/common/chat.h +2 -0
  30. package/cpp/llama.cpp/common/common.cpp +22 -6
  31. package/cpp/llama.cpp/common/common.h +22 -4
  32. package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
  33. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
  34. package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
  35. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  36. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  37. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  38. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  39. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  40. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  41. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
  42. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  44. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
  45. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  46. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  97. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
  99. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
  100. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  131. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
  132. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  138. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  139. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  140. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  141. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  142. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  143. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  144. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  168. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  169. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  170. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  173. package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  177. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  178. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  179. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
  180. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  181. package/cpp/llama.cpp/include/llama.h +15 -47
  182. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  183. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  184. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  185. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  186. package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
  187. package/cpp/llama.cpp/src/llama-arch.h +23 -1
  188. package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
  189. package/cpp/llama.cpp/src/llama-batch.h +31 -18
  190. package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
  191. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  192. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  193. package/cpp/llama.cpp/src/llama-context.h +26 -16
  194. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  195. package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
  196. package/cpp/llama.cpp/src/llama-graph.h +184 -122
  197. package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
  198. package/cpp/llama.cpp/src/llama-hparams.h +13 -2
  199. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  200. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  201. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  202. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  203. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  204. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  205. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  206. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
  207. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  208. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  209. package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
  210. package/cpp/llama.cpp/src/llama-model.h +21 -4
  211. package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
  212. package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
  213. package/cpp/llama.cpp/src/llama-vocab.h +43 -0
  214. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  215. package/cpp/llama.cpp/src/unicode.h +2 -0
  216. package/ios/include/chat.h +2 -0
  217. package/ios/include/common.h +22 -4
  218. package/ios/include/llama.h +15 -47
  219. package/ios/libs/llama.xcframework/Info.plist +13 -13
  220. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  221. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  223. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  224. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
  225. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  231. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  232. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  248. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  250. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  254. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  255. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  261. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  262. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
  263. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  264. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  265. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  267. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  268. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
  269. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
  270. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  274. package/package.json +4 -4
  275. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  276. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  277. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  278. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  279. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  280. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -0,0 +1,225 @@
1
+ #pragma once
2
+
3
+ #include "ggml-common.h"
4
+
5
+ template<typename src_t, typename dst_t>
6
+ static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) {
7
+ if constexpr (std::is_same_v<src_t, dst_t>) {
8
+ *dst = *src;
9
+ } else {
10
+ *dst = float(*src);
11
+ }
12
+ }
13
+
14
+ static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
15
+ if (x <= val[0]) return 0;
16
+ if (x >= val[n-1]) return n-1;
17
+ int ml = 0, mu = n-1;
18
+ while (mu-ml > 1) {
19
+ int mav = (ml+mu)/2;
20
+ if (x < val[mav]) mu = mav; else ml = mav;
21
+ }
22
+ return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
23
+ }
24
+
25
+ static __device__ void quantize_f32_q4_0_block(const float * __restrict__ x, block_q4_0 * __restrict__ y) {
26
+ float amax = 0.0f;
27
+ float vmax = 0.0f;
28
+
29
+ for (int j = 0; j < QK4_0; ++j) {
30
+ const float v = x[j];
31
+ if (amax < fabsf(v)) {
32
+ amax = fabsf(v);
33
+ vmax = v;
34
+ }
35
+ }
36
+
37
+ const float d = vmax / -8;
38
+ const float id = d ? 1.0f/d : 0.0f;
39
+
40
+ y->d = d;
41
+
42
+ for (int j = 0; j < QK4_0/2; ++j) {
43
+ const float x0 = x[0 + j]*id;
44
+ const float x1 = x[QK4_0/2 + j]*id;
45
+
46
+ const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
47
+ const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
48
+
49
+ y->qs[j] = xi0;
50
+ y->qs[j] |= xi1 << 4;
51
+ }
52
+ }
53
+
54
+ static __device__ void quantize_f32_q4_1_block(const float * __restrict__ x, block_q4_1 * __restrict__ y) {
55
+ float vmin = FLT_MAX;
56
+ float vmax = -FLT_MAX;
57
+
58
+ for (int j = 0; j < QK4_1; ++j) {
59
+ const float v = x[j];
60
+ if (v < vmin) vmin = v;
61
+ if (v > vmax) vmax = v;
62
+ }
63
+
64
+ const float d = (vmax - vmin) / ((1 << 4) - 1);
65
+ const float id = d ? 1.0f/d : 0.0f;
66
+
67
+ y->dm.x = d;
68
+ y->dm.y = vmin;
69
+
70
+ for (int j = 0; j < QK4_1/2; ++j) {
71
+ const float x0 = (x[0 + j] - vmin)*id;
72
+ const float x1 = (x[QK4_1/2 + j] - vmin)*id;
73
+
74
+ const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
75
+ const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
76
+
77
+ y->qs[j] = xi0;
78
+ y->qs[j] |= xi1 << 4;
79
+ }
80
+ }
81
+
82
+ static __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, block_q5_0 * __restrict__ y) {
83
+ float amax = 0.0f;
84
+ float vmax = 0.0f;
85
+
86
+ for (int j = 0; j < QK5_0; ++j) {
87
+ const float v = x[j];
88
+ if (amax < fabsf(v)) {
89
+ amax = fabsf(v);
90
+ vmax = v;
91
+ }
92
+ }
93
+
94
+ const float d = vmax / -16;
95
+ const float id = d ? 1.0f/d : 0.0f;
96
+
97
+ y->d = d;
98
+
99
+ uint32_t qh = 0;
100
+ for (int j = 0; j < QK5_0/2; ++j) {
101
+ const float x0 = x[0 + j]*id;
102
+ const float x1 = x[QK5_0/2 + j]*id;
103
+
104
+ const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
105
+ const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
106
+
107
+ y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
108
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
109
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
110
+ }
111
+ memcpy(y->qh, &qh, sizeof(qh));
112
+ }
113
+
114
+ static __device__ void quantize_f32_q5_1_block(const float * __restrict__ x, block_q5_1 * __restrict__ y) {
115
+ float min = x[0];
116
+ float max = x[0];
117
+
118
+ for (int j = 1; j < QK5_1; ++j) {
119
+ const float v = x[j];
120
+ min = v < min ? v : min;
121
+ max = v > max ? v : max;
122
+ }
123
+
124
+ const float d = (max - min) / 31;
125
+ const float id = d ? 1.0f/d : 0.0f;
126
+
127
+ y->dm.x = d;
128
+ y->dm.y = min;
129
+
130
+ uint32_t qh = 0;
131
+ for (int j = 0; j < QK5_1/2; ++j) {
132
+ const float x0 = (x[0 + j] - min)*id;
133
+ const float x1 = (x[QK5_1/2 + j] - min)*id;
134
+
135
+ const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
136
+ const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
137
+
138
+ y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
139
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
140
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
141
+ }
142
+ memcpy(y->qh, &qh, sizeof(qh));
143
+ }
144
+
145
+ static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, block_q8_0 * __restrict__ y) {
146
+ float amax = 0.0f; // absolute max
147
+
148
+ for (int j = 0; j < QK8_0; j++) {
149
+ const float v = x[j];
150
+ amax = fmaxf(amax, fabsf(v));
151
+ }
152
+
153
+ const float d = amax / ((1 << 7) - 1);
154
+ const float id = d ? 1.0f/d : 0.0f;
155
+
156
+ y->d = d;
157
+
158
+ for (int j = 0; j < QK8_0; ++j) {
159
+ const float x0 = x[j]*id;
160
+ y->qs[j] = roundf(x0);
161
+ }
162
+ }
163
+
164
+ static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) {
165
+ float amax = 0.0f;
166
+ float vmax = 0.0f;
167
+
168
+ for (int j = 0; j < QK4_NL; ++j) {
169
+ const float v = x[j];
170
+ if (amax < fabsf(v)) {
171
+ amax = fabsf(v);
172
+ vmax = v;
173
+ }
174
+ }
175
+
176
+ float d = vmax / kvalues_iq4nl[0];
177
+ const float id = d ? 1.0f/d : 0.0f;
178
+
179
+ float sumqx = 0, sumq2 = 0;
180
+ for (int j = 0; j < QK4_NL/2; ++j) {
181
+ const float x0 = x[0 + j]*id;
182
+ const float x1 = x[QK4_NL/2 + j]*id;
183
+ const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
184
+ const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
185
+ y->qs[j] = xi0 | (xi1 << 4);
186
+ const float v0 = kvalues_iq4nl[xi0];
187
+ const float v1 = kvalues_iq4nl[xi1];
188
+ const float w0 = x[0 + j]*x[0 + j];
189
+ const float w1 = x[QK4_NL/2 + j]*x[QK4_NL/2 + j];
190
+ sumqx += w0*v0*x[j] + w1*v1*x[QK4_NL/2 + j];
191
+ sumq2 += w0*v0*v0 + w1*v1*v1;
192
+ }
193
+
194
+ y->d = sumq2 > 0 ? sumqx/sumq2 : d;
195
+ }
196
+
197
+ // Wrapper functions for cpy.cu compatibility
198
+ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
199
+ quantize_f32_q4_0_block((const float *)cxi, (block_q4_0 *)cdsti);
200
+ }
201
+
202
+ static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
203
+ quantize_f32_q4_1_block((const float *)cxi, (block_q4_1 *)cdsti);
204
+ }
205
+
206
+ static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
207
+ quantize_f32_q5_0_block((const float *)cxi, (block_q5_0 *)cdsti);
208
+ }
209
+
210
+ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
211
+ quantize_f32_q5_1_block((const float *)cxi, (block_q5_1 *)cdsti);
212
+ }
213
+
214
+ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
215
+ quantize_f32_q8_0_block((const float *)cxi, (block_q8_0 *)cdsti);
216
+ }
217
+
218
+ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
219
+ quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
220
+ }
221
+
222
+ template<typename src_t, typename dst_t>
223
+ static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
224
+ convert_flt((const src_t *)cxi, (dst_t *)cdsti);
225
+ }
@@ -1,51 +1,17 @@
1
1
  #include "cpy.cuh"
2
2
  #include "dequantize.cuh"
3
- #ifdef GGML_USE_MUSA
3
+ #include "cpy-utils.cuh"
4
+ #if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
4
5
  #include "ggml-musa/mudnn.cuh"
5
- #endif // GGML_USE_MUSA
6
+ #endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
6
7
 
7
8
  typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
8
9
 
9
- static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
10
- const float * xi = (const float *) cxi;
11
- float * dsti = (float *) cdsti;
12
-
13
- *dsti = *xi;
14
- }
15
-
16
- static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
17
- const float * xi = (const float *) cxi;
18
- nv_bfloat16 * dsti = (nv_bfloat16 *) cdsti;
19
-
20
- *dsti = *xi;
21
- }
22
-
23
- static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
24
- const float * xi = (const float *) cxi;
25
- half * dsti = (half *) cdsti;
26
-
27
- *dsti = __float2half(*xi);
28
- }
29
-
30
- static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
31
- const half * xi = (const half *) cxi;
32
- half * dsti = (half *) cdsti;
33
-
34
- *dsti = *xi;
35
- }
36
-
37
- static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
38
- const half * xi = (const half *) cxi;
39
- float * dsti = (float *) cdsti;
40
-
41
- *dsti = *xi;
42
- }
43
-
44
10
  template <cpy_kernel_t cpy_1>
45
- static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
46
- const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
47
- const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
48
- const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
11
+ static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
12
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
13
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
14
+ const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
49
15
  const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
50
16
 
51
17
  if (i >= ne) {
@@ -71,29 +37,6 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in
71
37
  cpy_1(cx + x_offset, cdst + dst_offset);
72
38
  }
73
39
 
74
- static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
75
- const float * xi = (const float *) cxi;
76
- block_q8_0 * dsti = (block_q8_0 *) cdsti;
77
-
78
- float amax = 0.0f; // absolute max
79
-
80
- for (int j = 0; j < QK8_0; j++) {
81
- const float v = xi[j];
82
- amax = fmaxf(amax, fabsf(v));
83
- }
84
-
85
- const float d = amax / ((1 << 7) - 1);
86
- const float id = d ? 1.0f/d : 0.0f;
87
-
88
- dsti->d = d;
89
-
90
- for (int j = 0; j < QK8_0; ++j) {
91
- const float x0 = xi[j]*id;
92
-
93
- dsti->qs[j] = roundf(x0);
94
- }
95
- }
96
-
97
40
  static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
98
41
  float * cdstf = (float *)(cdsti);
99
42
 
@@ -106,139 +49,6 @@ static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
106
49
  }
107
50
  }
108
51
 
109
- static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
110
- const float * xi = (const float *) cxi;
111
- block_q4_0 * dsti = (block_q4_0 *) cdsti;
112
-
113
- float amax = 0.0f;
114
- float vmax = 0.0f;
115
-
116
- for (int j = 0; j < QK4_0; ++j) {
117
- const float v = xi[j];
118
- if (amax < fabsf(v)) {
119
- amax = fabsf(v);
120
- vmax = v;
121
- }
122
- }
123
-
124
- const float d = vmax / -8;
125
- const float id = d ? 1.0f/d : 0.0f;
126
-
127
- dsti->d = d;
128
-
129
- for (int j = 0; j < QK4_0/2; ++j) {
130
- const float x0 = xi[0 + j]*id;
131
- const float x1 = xi[QK4_0/2 + j]*id;
132
-
133
- const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
134
- const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
135
-
136
- dsti->qs[j] = xi0;
137
- dsti->qs[j] |= xi1 << 4;
138
- }
139
- }
140
-
141
- static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
142
- const float * xi = (const float *) cxi;
143
- block_q4_1 * dsti = (block_q4_1 *) cdsti;
144
-
145
- float vmin = FLT_MAX;
146
- float vmax = -FLT_MAX;
147
-
148
- for (int j = 0; j < QK4_1; ++j) {
149
- const float v = xi[j];
150
-
151
- if (v < vmin) vmin = v;
152
- if (v > vmax) vmax = v;
153
- }
154
-
155
- const float d = (vmax - vmin) / ((1 << 4) - 1);
156
- const float id = d ? 1.0f/d : 0.0f;
157
-
158
- dsti->dm.x = d;
159
- dsti->dm.y = vmin;
160
-
161
- for (int j = 0; j < QK4_1/2; ++j) {
162
- const float x0 = (xi[0 + j] - vmin)*id;
163
- const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
164
-
165
- const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
166
- const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
167
-
168
- dsti->qs[j] = xi0;
169
- dsti->qs[j] |= xi1 << 4;
170
- }
171
- }
172
-
173
- static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
174
- const float * xi = (const float *) cxi;
175
- block_q5_0 * dsti = (block_q5_0 *) cdsti;
176
-
177
- float amax = 0.0f;
178
- float vmax = 0.0f;
179
-
180
- for (int j = 0; j < QK5_0; ++j) {
181
- const float v = xi[j];
182
- if (amax < fabsf(v)) {
183
- amax = fabsf(v);
184
- vmax = v;
185
- }
186
- }
187
-
188
- const float d = vmax / -16;
189
- const float id = d ? 1.0f/d : 0.0f;
190
-
191
- dsti->d = d;
192
-
193
- uint32_t qh = 0;
194
- for (int j = 0; j < QK5_0/2; ++j) {
195
- const float x0 = xi[0 + j]*id;
196
- const float x1 = xi[QK5_0/2 + j]*id;
197
-
198
- const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
199
- const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
200
-
201
- dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
202
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
203
- qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
204
- }
205
- memcpy(dsti->qh, &qh, sizeof(qh));
206
- }
207
-
208
- static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
209
- const float * xi = (const float *) cxi;
210
- block_q5_1 * dsti = (block_q5_1 *) cdsti;
211
-
212
- float min = xi[0];
213
- float max = xi[0];
214
-
215
- for (int j = 1; j < QK5_1; ++j) {
216
- const float v = xi[j];
217
- min = v < min ? v : min;
218
- max = v > max ? v : max;
219
- }
220
-
221
- const float d = (max - min) / 31;
222
- const float id = d ? 1.0f/d : 0.0f;
223
-
224
- dsti->dm.x = d;
225
- dsti->dm.y = min;
226
-
227
- uint32_t qh = 0;
228
- for (int j = 0; j < QK5_1/2; ++j) {
229
- const float x0 = (xi[0 + j] - min)*id;
230
- const float x1 = (xi[QK5_1/2 + j] - min)*id;
231
-
232
- const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
233
- const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
234
-
235
- dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
236
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
237
- qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
238
- }
239
- memcpy(dsti->qh, &qh, sizeof(qh));
240
- }
241
-
242
52
  template<dequantize_kernel_t dequant, int qk>
243
53
  static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
244
54
  float * cdstf = (float *)(cdsti);
@@ -252,53 +62,6 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
252
62
  }
253
63
  }
254
64
 
255
- static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
256
- if (x <= val[0]) return 0;
257
- if (x >= val[n-1]) return n-1;
258
- int ml = 0, mu = n-1;
259
- while (mu-ml > 1) {
260
- int mav = (ml+mu)/2;
261
- if (x < val[mav]) mu = mav; else ml = mav;
262
- }
263
- return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
264
- }
265
-
266
- static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
267
- const float * xi = (const float *) cxi;
268
- block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
269
-
270
- float amax = 0.0f;
271
- float vmax = 0.0f;
272
-
273
- for (int j = 0; j < QK4_NL; ++j) {
274
- const float v = xi[j];
275
- if (amax < fabsf(v)) {
276
- amax = fabsf(v);
277
- vmax = v;
278
- }
279
- }
280
-
281
- float d = vmax / kvalues_iq4nl[0];
282
- const float id = d ? 1.0f/d : 0.0f;
283
-
284
- float sumqx = 0, sumq2 = 0;
285
- for (int j = 0; j < QK4_NL/2; ++j) {
286
- const float x0 = xi[0 + j]*id;
287
- const float x1 = xi[QK4_NL/2 + j]*id;
288
- const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
289
- const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
290
- dsti->qs[j] = xi0 | (xi1 << 4);
291
- const float v0 = kvalues_iq4nl[xi0];
292
- const float v1 = kvalues_iq4nl[xi1];
293
- const float w0 = xi[0 + j]*xi[0 + j];
294
- const float w1 = xi[QK4_NL/2 + j]*xi[QK4_NL/2 + j];
295
- sumqx += w0*v0*xi[j] + w1*v1*xi[QK4_NL/2 + j];
296
- sumq2 += w0*v0*v0 + w1*v1*v1;
297
- }
298
-
299
- dsti->d = sumq2 > 0 ? sumqx/sumq2 : d;
300
- }
301
-
302
65
  template <cpy_kernel_t cpy_blck, int qk>
303
66
  static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
304
67
  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -358,7 +121,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int
358
121
  // Copy destination pointers to GPU to be available when pointer indirection is in use
359
122
 
360
123
  void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
361
- #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
124
+ #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
362
125
  if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
363
126
  CUDA_CHECK(cudaStreamSynchronize(stream));
364
127
  if (cuda_graph->dest_ptrs_d != nullptr) {
@@ -376,43 +139,14 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
376
139
  #endif
377
140
  }
378
141
 
379
- static void ggml_cpy_f16_f32_cuda(
380
- const char * cx, char * cdst, const int ne,
381
- const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
382
- const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
383
-
384
- const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
385
- cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
386
- (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
387
- }
388
-
389
- static void ggml_cpy_f32_f32_cuda(
142
+ template<typename src_t, typename dst_t>
143
+ static void ggml_cpy_flt_cuda(
390
144
  const char * cx, char * cdst, const int ne,
391
145
  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
392
146
  const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
393
147
 
394
148
  const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
395
- cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
396
- (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
397
- }
398
-
399
- static void ggml_cpy_f32_bf16_cuda(
400
- const char * cx, char * cdst, const int ne,
401
- const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
402
- const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
403
-
404
- const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
405
- cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
406
- (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
407
- }
408
-
409
- static void ggml_cpy_f32_f16_cuda(
410
- const char * cx, char * cdst, const int ne,
411
- const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
412
- const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
413
-
414
- const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
415
- cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
149
+ cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
416
150
  (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
417
151
  }
418
152
 
@@ -544,16 +278,6 @@ static void ggml_cpy_f32_iq4_nl_cuda(
544
278
  (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
545
279
  }
546
280
 
547
- static void ggml_cpy_f16_f16_cuda(
548
- const char * cx, char * cdst, const int ne,
549
- const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
550
- const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
551
-
552
- const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
553
- cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
554
- (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
555
- }
556
-
557
281
  void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
558
282
  const int64_t ne = ggml_nelements(src0);
559
283
  GGML_ASSERT(ne == ggml_nelements(src1));
@@ -590,7 +314,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
590
314
 
591
315
  char ** dest_ptrs_d = nullptr;
592
316
  int graph_cpynode_index = -1;
593
- #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
317
+ #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
594
318
  if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
595
319
  dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
596
320
  graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
@@ -600,20 +324,20 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
600
324
  #endif
601
325
  if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
602
326
  GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
603
- #ifdef GGML_USE_MUSA
327
+ #if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
604
328
  if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
605
329
  CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));
606
330
  } else
607
- #endif // GGML_USE_MUSA
331
+ #endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
608
332
  {
609
333
  CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
610
334
  }
611
335
  } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
612
- ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
336
+ ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
613
337
  } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
614
- ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
338
+ ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
615
339
  } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
616
- ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
340
+ ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
617
341
  } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
618
342
  ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
619
343
  } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -640,14 +364,22 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
640
364
  } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
641
365
  ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
642
366
  } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
643
- ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
367
+ ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
368
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
369
+ ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
644
370
  } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
645
- ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
371
+ ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
372
+ } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
373
+ ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
374
+ } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
375
+ ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
376
+ } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
377
+ ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
646
378
  } else {
647
379
  GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
648
380
  ggml_type_name(src0->type), ggml_type_name(src1->type));
649
381
  }
650
- #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
382
+ #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
651
383
  if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
652
384
  ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
653
385
  }
@@ -667,11 +399,11 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
667
399
  if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
668
400
  return nullptr;
669
401
  } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
670
- return (void*) cpy_f32_f16<cpy_1_f32_f32>;
402
+ return (void*) cpy_flt<cpy_1_flt<float, float>>;
671
403
  } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
672
- return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
404
+ return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
673
405
  } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
674
- return (void*) cpy_f32_f16<cpy_1_f32_f16>;
406
+ return (void*) cpy_flt<cpy_1_flt<float, half>>;
675
407
  } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
676
408
  return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
677
409
  } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -695,9 +427,17 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
695
427
  } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
696
428
  return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
697
429
  } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
698
- return (void*) cpy_f32_f16<cpy_1_f32_f16>;
430
+ return (void*) cpy_flt<cpy_1_flt<half, half>>;
431
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
432
+ return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
699
433
  } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
700
- return (void*) cpy_f32_f16<cpy_1_f16_f32>;
434
+ return (void*) cpy_flt<cpy_1_flt<half, float>>;
435
+ } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
436
+ return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
437
+ } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
438
+ return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
439
+ } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
440
+ return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
701
441
  } else {
702
442
  GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
703
443
  ggml_type_name(src0->type), ggml_type_name(src1->type));