@novastera-oss/llamarn 0.2.9 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (314)
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  22. package/cpp/llama.cpp/CMakePresets.json +11 -0
  23. package/cpp/llama.cpp/CODEOWNERS +1 -0
  24. package/cpp/llama.cpp/README.md +8 -8
  25. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  26. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  27. package/cpp/llama.cpp/common/arg.cpp +62 -1
  28. package/cpp/llama.cpp/common/chat.cpp +37 -20
  29. package/cpp/llama.cpp/common/chat.h +2 -0
  30. package/cpp/llama.cpp/common/common.cpp +22 -6
  31. package/cpp/llama.cpp/common/common.h +22 -4
  32. package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
  33. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
  34. package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
  35. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  36. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  37. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  38. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  39. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  40. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  41. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
  42. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  44. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
  45. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  46. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  97. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
  99. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
  100. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  131. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
  132. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  138. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  139. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  140. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  141. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  142. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  143. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  144. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  168. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  169. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  170. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  173. package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  177. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  178. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  179. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
  180. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  181. package/cpp/llama.cpp/include/llama.h +15 -47
  182. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  183. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  184. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  185. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  186. package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
  187. package/cpp/llama.cpp/src/llama-arch.h +23 -1
  188. package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
  189. package/cpp/llama.cpp/src/llama-batch.h +31 -18
  190. package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
  191. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  192. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  193. package/cpp/llama.cpp/src/llama-context.h +26 -16
  194. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  195. package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
  196. package/cpp/llama.cpp/src/llama-graph.h +184 -122
  197. package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
  198. package/cpp/llama.cpp/src/llama-hparams.h +13 -2
  199. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  200. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  201. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  202. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  203. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  204. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  205. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  206. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
  207. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  208. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  209. package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
  210. package/cpp/llama.cpp/src/llama-model.h +21 -4
  211. package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
  212. package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
  213. package/cpp/llama.cpp/src/llama-vocab.h +43 -0
  214. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  215. package/cpp/llama.cpp/src/unicode.h +2 -0
  216. package/ios/include/chat.h +2 -0
  217. package/ios/include/common.h +22 -4
  218. package/ios/include/llama.h +15 -47
  219. package/ios/libs/llama.xcframework/Info.plist +13 -13
  220. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  221. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  223. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  224. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
  225. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  231. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  232. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  248. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  250. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  254. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  255. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  261. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  262. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
  263. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  264. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  265. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  267. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  268. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
  269. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
  270. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  274. package/package.json +4 -4
  275. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  276. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  277. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  278. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  279. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  280. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -20,6 +20,9 @@
20
20
 
21
21
  static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
22
22
 
23
+ // Work buffer size for im2col operations in CONV2D
24
+ #define GGML_IM2COL_WORK_SIZE (16 * 1024 * 1024)
25
+
23
26
  #ifdef __cplusplus
24
27
  extern "C" {
25
28
  #endif
@@ -65,6 +68,7 @@ void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struc
65
68
  void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
66
69
  void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
67
70
  void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
71
+ void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
68
72
  void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
69
73
  void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
70
74
  void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -94,6 +98,7 @@ void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, st
94
98
  void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
95
99
  void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
96
100
  void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);
101
+ void ggml_compute_forward_glu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
97
102
  void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
98
103
  void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
99
104
  void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -106,6 +111,7 @@ void ggml_compute_forward_custom(const struct ggml_compute_params * params, stru
106
111
  void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst);
107
112
  void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
108
113
  void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
114
+ void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
109
115
 
110
116
  #ifdef __cplusplus
111
117
  }
@@ -14,7 +14,6 @@
14
14
  #include <cmath>
15
15
  #include <cstring>
16
16
  #include <cassert>
17
- #include <cstdlib> // for qsort
18
17
  #include <cstdio> // for GGML_ASSERT
19
18
 
20
19
  #include "repack.h"
@@ -189,7 +189,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
189
189
  #define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
190
190
  #define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
191
191
  #define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
192
- #define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c)
192
+ #define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
193
193
  #define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
194
194
  #define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
195
195
  #define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
@@ -37,35 +37,35 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
37
37
  for (int i = 0; i < np; i += ggml_f32_step) {
38
38
  ax1 = GGML_F32_VEC_LOAD(x + i);
39
39
  ay1 = GGML_F32_VEC_LOAD(y + i);
40
- sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
40
+ sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
41
41
 
42
42
  ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
43
43
  ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
44
- sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2);
44
+ sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2);
45
45
 
46
46
  ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
47
47
  ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
48
- sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3);
48
+ sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3);
49
49
 
50
50
  ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
51
51
  ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
52
- sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4);
52
+ sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4);
53
53
 
54
54
  ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
55
55
  ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
56
- sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5);
56
+ sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5);
57
57
 
58
58
  ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
59
59
  ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
60
- sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6);
60
+ sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6);
61
61
 
62
62
  ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
63
63
  ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
64
- sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7);
64
+ sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7);
65
65
 
66
66
  ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
67
67
  ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
68
- sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8);
68
+ sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8);
69
69
  }
70
70
  // leftovers
71
71
  // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop
@@ -73,7 +73,7 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
73
73
  for (int i = np; i < np2; i += ggml_f32_epr) {
74
74
  ax1 = GGML_F32_VEC_LOAD(x + i);
75
75
  ay1 = GGML_F32_VEC_LOAD(y + i);
76
- sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
76
+ sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
77
77
  }
78
78
  // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
79
79
  if (np2 < n) {
@@ -221,6 +221,9 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
221
221
  for (int i = np; i < n; ++i) {
222
222
  sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
223
223
  }
224
+
225
+ // if you hit this, you are likely running outside the FP range
226
+ assert(!isnan(sumf) && !isinf(sumf));
224
227
  #else
225
228
  for (int i = 0; i < n; ++i) {
226
229
  sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
@@ -254,6 +257,30 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
254
257
  }
255
258
  }
256
259
 
260
+ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g) {
261
+ int i = 0;
262
+ #if defined(__AVX512F__) && defined(__AVX512DQ__)
263
+ for (; i + 15 < n; i += 16) {
264
+ _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i)));
265
+ }
266
+ #elif defined(__AVX2__) && defined(__FMA__)
267
+ for (; i + 7 < n; i += 8) {
268
+ _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i)));
269
+ }
270
+ #elif defined(__SSE2__)
271
+ for (; i + 3 < n; i += 4) {
272
+ _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
273
+ }
274
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
275
+ for (; i + 3 < n; i += 4) {
276
+ vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
277
+ }
278
+ #endif
279
+ for (; i < n; ++i) {
280
+ y[i] = ggml_silu_f32(x[i]) * g[i];
281
+ }
282
+ }
283
+
257
284
  ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
258
285
  int i = 0;
259
286
  ggml_float sum = 0;
@@ -163,49 +163,49 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
163
163
 
164
164
  ax1 = GGML_F32_VEC_LOAD(x + i);
165
165
  ay1 = GGML_F32_VEC_LOAD(y + i);
166
- ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
166
+ ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
167
167
 
168
168
  GGML_F32_VEC_STORE(y + i, ay1);
169
169
 
170
170
  ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
171
171
  ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
172
- ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2);
172
+ ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);
173
173
 
174
174
  GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
175
175
 
176
176
  ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
177
177
  ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
178
- ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3);
178
+ ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);
179
179
 
180
180
  GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);
181
181
 
182
182
  ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
183
183
  ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
184
- ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4);
184
+ ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);
185
185
 
186
186
  GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);
187
187
 
188
188
  ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
189
189
  ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
190
- ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5);
190
+ ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);
191
191
 
192
192
  GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);
193
193
 
194
194
  ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
195
195
  ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
196
- ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6);
196
+ ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);
197
197
 
198
198
  GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);
199
199
 
200
200
  ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
201
201
  ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
202
- ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7);
202
+ ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);
203
203
 
204
204
  GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);
205
205
 
206
206
  ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
207
207
  ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
208
- ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8);
208
+ ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);
209
209
 
210
210
  GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
211
211
  }
@@ -215,7 +215,7 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
215
215
  for (int i = np; i < np2; i += ggml_f32_epr) {
216
216
  ax1 = GGML_F32_VEC_LOAD(x + i);
217
217
  ay1 = GGML_F32_VEC_LOAD(y + i);
218
- ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
218
+ ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
219
219
 
220
220
  GGML_F32_VEC_STORE(y + i, ay1);
221
221
  }
@@ -351,6 +351,45 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int
351
351
  #endif
352
352
  }
353
353
 
354
// Scale-and-bias over a float vector: y[i] = x[i]*s + b for i in [0, n).
//
//   n - number of elements
//   y - destination, n floats
//   x - source, n floats
//   s - scale applied to every element
//   b - bias added after scaling
//
// Every code path (Accelerate, SVE fallback, generic SIMD, scalar) must produce
// the same x*s + b result.
inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {
#if defined(GGML_USE_ACCELERATE)
    vDSP_vsmsa(x, 1, &s, &b, y, 1, n);
#elif defined(GGML_SIMD)
    #if defined(__ARM_FEATURE_SVE)
        // scalar ; TODO: Write SVE code
        for (int i = 0; i < n; ++i) {
            y[i] = x[i]*s + b;
        }
    #else
        // largest multiple of GGML_F32_STEP not exceeding n
        const int np = (n & ~(GGML_F32_STEP - 1));

        GGML_F32_VEC vs = GGML_F32_VEC_SET1(s);
        GGML_F32_VEC vb = GGML_F32_VEC_SET1(b);

        GGML_F32_VEC ay[GGML_F32_ARR];

        for (int i = 0; i < np; i += GGML_F32_STEP) {
            for (int j = 0; j < GGML_F32_ARR; j++) {
                ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
                // GGML_F32_VEC_FMA takes the addend first (a + b*c), the same
                // operand order ggml_vec_mad_f32 and ggml_vec_dot_f32 use in this
                // file. The bias must therefore be the first operand; passing the
                // loaded x first would compute x + s*b instead of x*s + b.
                ay[j] = GGML_F32_VEC_FMA(vb, ay[j], vs);

                GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
            }
        }

        // leftovers
        for (int i = np; i < n; ++i) {
            y[i] = x[i]*s + b;
        }
    #endif
#else
    // scalar
    for (int i = 0; i < n; ++i) {
        y[i] = x[i]*s + b;
    }
#endif
}
392
+
354
393
  //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
355
394
  inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
356
395
  #if defined(GGML_USE_ACCELERATE)
@@ -905,6 +944,100 @@ inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, con
905
944
  }
906
945
  }
907
946
 
947
// Gated ReLU over n floats: y[k] = x[k] * g[k] when x[k] > 0, otherwise 0.
// g[k] is only read for elements whose x[k] is strictly positive.
inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) {
    for (int k = 0; k < n; ++k) {
        const float xk = x[k];
        if (xk > 0.f) {
            y[k] = xk * g[k];
        } else {
            y[k] = 0.f;
        }
    }
}
952
+
953
+ inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
954
+ for (int i = 0; i < n; ++i) {
955
+ float v = GGML_CPU_FP16_TO_FP32(x[i]);
956
+ y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v * GGML_CPU_FP16_TO_FP32(g[i]) : 0.f);
957
+ }
958
+ }
959
+
960
// Gated GELU over n floats: y[k] = gelu(x[k]) * g[k].
// With GGML_GELU_FP16 the activation comes from ggml_table_gelu_f16, indexed by
// the fp16 bit pattern of x[k]; inputs <= -10 yield 0 and inputs >= 10 pass
// x[k] through unchanged before gating. Without it, ggml_gelu_f32 is called.
inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
#ifdef GGML_GELU_FP16
    for (int k = 0; k < n; ++k) {
        const float xk = x[k];
        if (xk <= -10.0f) {
            y[k] = 0.0f;
            continue;
        }
        if (xk >= 10.0f) {
            y[k] = xk * g[k];
            continue;
        }
        // reinterpret the fp16 value as a 16-bit table index
        const ggml_fp16_t xh = GGML_CPU_FP32_TO_FP16(xk);
        uint16_t idx;
        memcpy(&idx, &xh, sizeof(uint16_t));
        y[k] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[idx]) * g[k];
    }
#else
    for (int k = 0; k < n; ++k) {
        y[k] = ggml_gelu_f32(x[k]) * g[k];
    }
#endif
}
982
+
983
+ inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
984
+ const uint16_t * i16 = (const uint16_t *) x;
985
+ for (int i = 0; i < n; ++i) {
986
+ float v = GGML_CPU_FP16_TO_FP32(g[i]);
987
+ y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v);
988
+ }
989
+ }
990
+
991
+ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g);
992
+
993
+ inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
994
+ for (int i = 0; i < n; ++i) {
995
+ float v = GGML_CPU_FP16_TO_FP32(x[i]);
996
+ float w = GGML_CPU_FP16_TO_FP32(g[i]);
997
+ y[i] = GGML_CPU_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
998
+ }
999
+ }
1000
+
1001
+ inline static void ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {
1002
+ for (int i = 0; i < n; ++i) {
1003
+ float xi = x[i];
1004
+ y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i];
1005
+ }
1006
+ }
1007
+
1008
+ inline static void ggml_vec_geglu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
1009
+ for (int i = 0; i < n; ++i) {
1010
+ float xi = GGML_CPU_FP16_TO_FP32(x[i]);
1011
+ float gi = GGML_CPU_FP16_TO_FP32(g[i]);
1012
+ y[i] = GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi);
1013
+ }
1014
+ }
1015
+
1016
// Gated "quick" GELU over n floats: y[k] = gelu_quick(x[k]) * g[k].
// With GGML_GELU_QUICK_FP16 the activation is read from
// ggml_table_gelu_quick_f16, indexed by the fp16 bit pattern of x[k];
// otherwise ggml_gelu_quick_f32 is called directly.
inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
#ifdef GGML_GELU_QUICK_FP16
    for (int k = 0; k < n; ++k) {
        // reinterpret the fp16 value as a 16-bit table index
        const ggml_fp16_t xh = GGML_CPU_FP32_TO_FP16(x[k]);
        uint16_t idx;
        memcpy(&idx, &xh, sizeof(uint16_t));
        y[k] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[idx]) * g[k];
    }
#else
    for (int k = 0; k < n; ++k) {
        y[k] = ggml_gelu_quick_f32(x[k]) * g[k];
    }
#endif
}
1032
+
1033
+ inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
1034
+ const uint16_t * i16 = (const uint16_t *) x;
1035
+ for (int i = 0; i < n; ++i) {
1036
+ float v = GGML_CPU_FP16_TO_FP32(g[i]);
1037
+ y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[i16[i]]) * v);
1038
+ }
1039
+ }
1040
+
908
1041
  inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
909
1042
  #ifndef GGML_USE_ACCELERATE
910
1043
  ggml_float sum = 0.0;
@@ -102,12 +102,12 @@ if (CUDAToolkit_FOUND)
102
102
  if (GGML_STATIC)
103
103
  if (WIN32)
104
104
  # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
105
- target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
105
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
106
106
  else ()
107
- target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
107
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static)
108
108
  endif()
109
109
  else()
110
- target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt)
110
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
111
111
  endif()
112
112
 
113
113
  if (GGML_CUDA_NO_VMM)
@@ -56,7 +56,7 @@
56
56
  #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
57
57
  #define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
58
58
  #define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a
59
- #define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
59
+ #define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
60
60
  #define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing
61
61
  #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300
62
62
 
@@ -72,8 +72,9 @@
72
72
  #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
73
73
  #define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
74
74
  #define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
75
- #define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
76
- #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
75
+ #define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
76
+ #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
77
+ #define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)
77
78
 
78
79
  // Moore Threads
79
80
  #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
@@ -175,6 +176,23 @@ static const char * cu_get_error_str(CUresult err) {
175
176
  #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
176
177
  #endif
177
178
 
179
// CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes)
//
// On HIP (AMD platform) and MUSA builds this expands to a no-op that only marks
// nbytes as used. On every other build it calls cudaFuncSetAttribute to set
// cudaFuncAttributeMaxDynamicSharedMemorySize of `kernel` to `nbytes`.
//
// Each expansion site owns a static per-device flag array, so the attribute is
// set at most once per device per expansion site; later calls on the same
// device are skipped even if nbytes differs.
#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || defined(GGML_USE_MUSA)
#    define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
    do {                                               \
        GGML_UNUSED(nbytes);                           \
    } while (0)
#else
#    define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes)                                                   \
    do {                                                                                                   \
        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false };                         \
        const int id = ggml_cuda_get_device();                                                             \
        if (!shared_memory_limit_raised[id]) {                                                             \
            CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
            shared_memory_limit_raised[id] = true;                                                         \
        }                                                                                                  \
    } while (0)
#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || defined(GGML_USE_MUSA)
195
+
178
196
  #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
179
197
  #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
180
198
  #else
@@ -209,6 +227,10 @@ typedef float2 dfloat2;
209
227
  #define FP16_MMA_AVAILABLE
210
228
  #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
211
229
 
230
+ #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3)
231
+ #define AMD_MFMA_AVAILABLE
232
+ #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3)
233
+
212
234
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
213
235
  #define NEW_MMA_AVAILABLE
214
236
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -271,6 +293,11 @@ static bool fp32_mma_hardware_available(const int cc) {
271
293
  return GGML_CUDA_CC_IS_CDNA(cc);
272
294
  }
273
295
 
296
// AMD CDNA3 matrix cores. Support for other CDNA generations will be added later.
// Returns true only when cc lies at or above the AMD offset and inside the CDNA3 range.
static bool amd_mfma_available(const int cc) {
    const bool is_amd = cc >= GGML_CUDA_CC_OFFSET_AMD;
    return is_amd && GGML_CUDA_CC_IS_CDNA3(cc);
}
300
+
274
301
  // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
275
302
  static bool new_mma_available(const int cc) {
276
303
  return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
@@ -748,7 +775,7 @@ struct ggml_tensor_extra_gpu {
748
775
  };
749
776
 
750
777
 
751
- #if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
778
+ #if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)) || defined(GGML_MUSA_GRAPHS)
752
779
  #define USE_CUDA_GRAPH
753
780
  #endif
754
781
 
@@ -6,24 +6,33 @@
6
6
  #define CUDA_Q8_0_NE_ALIGN 2048
7
7
 
8
8
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
9
- static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
10
- const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x);
9
+ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
10
+ const int64_t ne00, const int64_t ne01, const int64_t ne02,
11
+ const int64_t s01, const int64_t s02, const int64_t s03) {
12
+ const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
11
13
 
12
- if (i >= k) {
14
+ if (i00 >= ne00) {
13
15
  return;
14
16
  }
15
17
 
16
- const int64_t ib = i/qk; // block index
17
- const int64_t iqs = (i%qk)/qr; // quant index
18
- const int64_t iybs = i - i%qk; // y block start index
18
+ const int64_t i01 = blockIdx.y;
19
+ const int64_t i02 = blockIdx.z % ne02;
20
+ const int64_t i03 = blockIdx.z / ne02;
21
+
22
+ const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
23
+
24
+ const int64_t ib = ibx0 + i00/qk; // block index
25
+ const int64_t iqs = (i00%qk)/qr; // quant index
26
+ const int64_t iybs = i00 - i00%qk; // y block start index
19
27
  const int64_t y_offset = qr == 1 ? 1 : qk/2;
20
28
 
21
29
  // dequantize
22
30
  dfloat2 v;
23
31
  dequantize_kernel(vx, ib, iqs, v);
24
32
 
25
- y[iybs + iqs + 0] = v.x;
26
- y[iybs + iqs + y_offset] = v.y;
33
+ const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
34
+ y[iy0 + 0] = float(v.x);
35
+ y[iy0 + y_offset] = float(v.y);
27
36
  }
28
37
 
29
38
  template <bool need_check>
@@ -457,9 +466,17 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
457
466
  }
458
467
 
459
468
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
460
- static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
461
- const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
462
- dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
469
+ static void dequantize_block_cuda(const void * vx, dst_t * y,
470
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
471
+ const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
472
+ const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
473
+ dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
474
+ (vx, y, ne00, ne01, ne02, s01, s02, s03);
475
+ }
476
+
477
// Contiguous wrapper around dequantize_block_cuda: the k elements are passed as
// a single row (ne00 = k, ne01 = ne02 = ne03 = 1). All three strides are set to
// k/qk; with every outer extent equal to 1 they are only ever multiplied by
// index 0 in the kernel.
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cont_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
    const int64_t row_blocks = k/qk;
    dequantize_block_cuda<qk, qr, dequantize_kernel, dst_t>(
        vx, y,
        k, 1, 1, 1,
        row_blocks, row_blocks, row_blocks,
        stream);
}
464
481
 
465
482
  static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
@@ -624,14 +641,14 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
624
641
  case GGML_TYPE_Q4_1:
625
642
  return dequantize_row_q4_1_cuda;
626
643
  case GGML_TYPE_Q5_0:
627
- return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
644
+ return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
628
645
  case GGML_TYPE_Q5_1:
629
- return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
646
+ return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
630
647
  case GGML_TYPE_Q8_0:
631
648
  if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
632
649
  return dequantize_block_q8_0_f16_cuda;
633
650
  }
634
- return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
651
+ return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
635
652
  case GGML_TYPE_Q2_K:
636
653
  return dequantize_row_q2_K_cuda;
637
654
  case GGML_TYPE_Q3_K:
@@ -676,11 +693,11 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
676
693
  case GGML_TYPE_Q4_1:
677
694
  return dequantize_row_q4_1_cuda;
678
695
  case GGML_TYPE_Q5_0:
679
- return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
696
+ return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
680
697
  case GGML_TYPE_Q5_1:
681
- return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
698
+ return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
682
699
  case GGML_TYPE_Q8_0:
683
- return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
700
+ return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
684
701
  case GGML_TYPE_Q2_K:
685
702
  return dequantize_row_q2_K_cuda;
686
703
  case GGML_TYPE_Q3_K:
@@ -722,9 +739,61 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
722
739
  switch (type) {
723
740
  case GGML_TYPE_F32:
724
741
  return convert_unary_cuda<float>;
742
+ case GGML_TYPE_Q4_0:
743
+ return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
744
+ case GGML_TYPE_Q4_1:
745
+ return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
746
+ case GGML_TYPE_Q5_0:
747
+ return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
748
+ case GGML_TYPE_Q5_1:
749
+ return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
750
+ case GGML_TYPE_Q8_0:
751
+ return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
725
752
  case GGML_TYPE_BF16:
726
753
  return convert_unary_cuda<nv_bfloat16>;
727
754
  default:
728
755
  return nullptr;
729
756
  }
730
757
  }
758
+
759
// Selects the non-contiguous converter that writes nv_bfloat16 output for the
// given source type. F32 and F16 use convert_unary_cuda; Q4_0, Q4_1, Q5_0, Q5_1
// and Q8_0 use dequantize_block_cuda. Any other type yields nullptr.
to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
    // plain float sources
    if (type == GGML_TYPE_F32) {
        return convert_unary_cuda<float, nv_bfloat16>;
    }
    if (type == GGML_TYPE_F16) {
        return convert_unary_cuda<half, nv_bfloat16>;
    }

    // block-quantized sources
    if (type == GGML_TYPE_Q4_0) {
        return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
    }
    if (type == GGML_TYPE_Q4_1) {
        return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
    }
    if (type == GGML_TYPE_Q5_0) {
        return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
    }
    if (type == GGML_TYPE_Q5_1) {
        return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
    }
    if (type == GGML_TYPE_Q8_0) {
        return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
    }

    // unsupported source type
    return nullptr;
}
779
+
780
// Selects the non-contiguous converter that writes float output for the given
// source type. F16 and BF16 use convert_unary_cuda; Q4_0, Q4_1, Q5_0, Q5_1 and
// Q8_0 use dequantize_block_cuda. Any other type yields nullptr.
to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
    // 16-bit float sources
    if (type == GGML_TYPE_F16) {
        return convert_unary_cuda<half, float>;
    }
    if (type == GGML_TYPE_BF16) {
        return convert_unary_cuda<nv_bfloat16, float>;
    }

    // block-quantized sources
    if (type == GGML_TYPE_Q4_0) {
        return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
    }
    if (type == GGML_TYPE_Q4_1) {
        return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
    }
    if (type == GGML_TYPE_Q5_0) {
        return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
    }
    if (type == GGML_TYPE_Q5_1) {
        return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
    }
    if (type == GGML_TYPE_Q8_0) {
        return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
    }

    // unsupported source type
    return nullptr;
}
@@ -22,5 +22,10 @@ using to_t_nc_cuda_t = void (*)(const void * x, T * y,
22
22
  int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
23
23
  int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
24
24
 
25
+ typedef to_t_nc_cuda_t<float> to_fp32_nc_cuda_t;
25
26
  typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
27
+ typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
28
+
29
+ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
26
30
  to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
31
+ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);