@novastera-oss/llamarn 0.2.9 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247):
  1. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  5. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  9. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  12. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  17. package/cpp/build-info.cpp +2 -2
  18. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  19. package/cpp/llama.cpp/README.md +4 -5
  20. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  21. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  22. package/cpp/llama.cpp/common/arg.cpp +17 -0
  23. package/cpp/llama.cpp/common/chat.cpp +37 -20
  24. package/cpp/llama.cpp/common/chat.h +2 -0
  25. package/cpp/llama.cpp/common/common.h +4 -0
  26. package/cpp/llama.cpp/convert_hf_to_gguf.py +745 -6
  27. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  28. package/cpp/llama.cpp/ggml/CMakeLists.txt +7 -2
  29. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  30. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  31. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +0 -1
  32. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  33. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  34. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  35. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1203 -163
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +33 -9
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  43. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +17 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  47. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +8 -6
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +185 -79
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  66. package/cpp/llama.cpp/ggml/src/ggml-impl.h +64 -0
  67. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  68. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +35 -9
  69. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +167 -39
  70. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +254 -57
  71. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +505 -40
  73. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  83. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  84. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +60 -9
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +711 -292
  92. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  93. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  94. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  95. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  105. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  106. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  112. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  113. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  114. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  116. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  117. package/cpp/llama.cpp/ggml/src/ggml.c +382 -61
  118. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  119. package/cpp/llama.cpp/gguf-py/gguf/constants.py +209 -0
  120. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  121. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +73 -21
  122. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  123. package/cpp/llama.cpp/include/llama.h +0 -40
  124. package/cpp/llama.cpp/src/llama-arch.cpp +210 -3
  125. package/cpp/llama.cpp/src/llama-arch.h +18 -1
  126. package/cpp/llama.cpp/src/llama-batch.cpp +27 -1
  127. package/cpp/llama.cpp/src/llama-batch.h +8 -1
  128. package/cpp/llama.cpp/src/llama-chat.cpp +15 -0
  129. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  130. package/cpp/llama.cpp/src/llama-graph.cpp +119 -184
  131. package/cpp/llama.cpp/src/llama-graph.h +47 -60
  132. package/cpp/llama.cpp/src/llama-hparams.cpp +7 -1
  133. package/cpp/llama.cpp/src/llama-hparams.h +3 -0
  134. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +28 -18
  135. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +4 -2
  136. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +214 -65
  137. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +62 -24
  138. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  139. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +9 -4
  140. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  141. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +20 -10
  142. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  143. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  144. package/cpp/llama.cpp/src/llama-model.cpp +2530 -685
  145. package/cpp/llama.cpp/src/llama-model.h +18 -0
  146. package/cpp/llama.cpp/src/llama-quant.cpp +1 -0
  147. package/cpp/llama.cpp/src/llama-vocab.cpp +13 -2
  148. package/cpp/llama.cpp/src/llama-vocab.h +41 -0
  149. package/ios/include/chat.h +2 -0
  150. package/ios/include/common.h +4 -0
  151. package/ios/include/llama.h +0 -40
  152. package/ios/libs/llama.xcframework/Info.plist +19 -19
  153. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  154. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5055 -4886
  155. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  156. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  157. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +0 -40
  158. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  159. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  160. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  161. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
  162. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  163. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  164. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  165. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  166. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  167. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  168. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3766
  169. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  170. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  171. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +0 -40
  172. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  173. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  174. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +0 -40
  175. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  176. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  177. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  178. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +0 -40
  179. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  180. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  181. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  182. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4890
  183. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  184. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  185. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +0 -40
  186. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  187. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  188. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  189. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
  190. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  191. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  192. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  193. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  194. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  195. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5091 -4922
  196. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  197. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  198. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +0 -40
  199. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  200. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  201. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4897
  202. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3794
  203. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  204. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  205. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  206. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  207. package/package.json +1 -1
  208. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  209. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  210. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  211. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  212. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  213. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  214. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  215. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  216. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  217. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  218. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  219. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  220. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  221. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  222. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  223. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  224. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  225. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  226. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  227. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  228. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  229. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  230. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  231. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  232. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  233. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  234. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  235. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  236. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  237. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  238. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  239. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  240. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  241. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  242. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  243. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  244. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  245. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  246. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  247. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -20,6 +20,9 @@
20
20
 
21
21
  static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
22
22
 
23
// Capacity of the scratch (work) buffer that the CONV_2D implementation uses
// for its im2col staging: 16 * 2^20 = 16777216 (presumably bytes -- confirm
// against the conv_2d kernel that consumes it).
#define GGML_IM2COL_WORK_SIZE (16 << 20)
25
+
23
26
  #ifdef __cplusplus
24
27
  extern "C" {
25
28
  #endif
@@ -65,6 +68,7 @@ void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struc
65
68
  void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
66
69
  void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
67
70
  void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
71
+ void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
68
72
  void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
69
73
  void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
70
74
  void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -94,6 +98,7 @@ void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, st
94
98
  void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
95
99
  void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
96
100
  void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);
101
+ void ggml_compute_forward_glu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
97
102
  void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
98
103
  void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
99
104
  void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -106,6 +111,7 @@ void ggml_compute_forward_custom(const struct ggml_compute_params * params, stru
106
111
  void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst);
107
112
  void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
108
113
  void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
114
+ void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
109
115
 
110
116
  #ifdef __cplusplus
111
117
  }
@@ -189,7 +189,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
189
189
  #define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
190
190
  #define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
191
191
  #define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
192
- #define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c)
192
+ #define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
193
193
  #define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
194
194
  #define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
195
195
  #define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
@@ -37,35 +37,35 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
37
37
  for (int i = 0; i < np; i += ggml_f32_step) {
38
38
  ax1 = GGML_F32_VEC_LOAD(x + i);
39
39
  ay1 = GGML_F32_VEC_LOAD(y + i);
40
- sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
40
+ sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
41
41
 
42
42
  ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
43
43
  ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
44
- sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2);
44
+ sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2);
45
45
 
46
46
  ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
47
47
  ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
48
- sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3);
48
+ sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3);
49
49
 
50
50
  ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
51
51
  ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
52
- sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4);
52
+ sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4);
53
53
 
54
54
  ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
55
55
  ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
56
- sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5);
56
+ sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5);
57
57
 
58
58
  ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
59
59
  ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
60
- sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6);
60
+ sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6);
61
61
 
62
62
  ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
63
63
  ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
64
- sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7);
64
+ sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7);
65
65
 
66
66
  ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
67
67
  ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
68
- sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8);
68
+ sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8);
69
69
  }
70
70
  // leftovers
71
71
  // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop
@@ -73,7 +73,7 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
73
73
  for (int i = np; i < np2; i += ggml_f32_epr) {
74
74
  ax1 = GGML_F32_VEC_LOAD(x + i);
75
75
  ay1 = GGML_F32_VEC_LOAD(y + i);
76
- sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
76
+ sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
77
77
  }
78
78
  // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
79
79
  if (np2 < n) {
@@ -254,6 +254,30 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
254
254
  }
255
255
  }
256
256
 
257
+ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g) {
258
+ int i = 0;
259
+ #if defined(__AVX512F__) && defined(__AVX512DQ__)
260
+ for (; i + 15 < n; i += 16) {
261
+ _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i)));
262
+ }
263
+ #elif defined(__AVX2__) && defined(__FMA__)
264
+ for (; i + 7 < n; i += 8) {
265
+ _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i)));
266
+ }
267
+ #elif defined(__SSE2__)
268
+ for (; i + 3 < n; i += 4) {
269
+ _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
270
+ }
271
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
272
+ for (; i + 3 < n; i += 4) {
273
+ vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
274
+ }
275
+ #endif
276
+ for (; i < n; ++i) {
277
+ y[i] = ggml_silu_f32(x[i]) * g[i];
278
+ }
279
+ }
280
+
257
281
  ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
258
282
  int i = 0;
259
283
  ggml_float sum = 0;
@@ -163,49 +163,49 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
163
163
 
164
164
  ax1 = GGML_F32_VEC_LOAD(x + i);
165
165
  ay1 = GGML_F32_VEC_LOAD(y + i);
166
- ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
166
+ ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
167
167
 
168
168
  GGML_F32_VEC_STORE(y + i, ay1);
169
169
 
170
170
  ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
171
171
  ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
172
- ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2);
172
+ ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);
173
173
 
174
174
  GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
175
175
 
176
176
  ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
177
177
  ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
178
- ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3);
178
+ ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);
179
179
 
180
180
  GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);
181
181
 
182
182
  ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
183
183
  ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
184
- ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4);
184
+ ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);
185
185
 
186
186
  GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);
187
187
 
188
188
  ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
189
189
  ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
190
- ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5);
190
+ ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);
191
191
 
192
192
  GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);
193
193
 
194
194
  ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
195
195
  ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
196
- ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6);
196
+ ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);
197
197
 
198
198
  GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);
199
199
 
200
200
  ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
201
201
  ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
202
- ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7);
202
+ ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);
203
203
 
204
204
  GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);
205
205
 
206
206
  ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
207
207
  ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
208
- ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8);
208
+ ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);
209
209
 
210
210
  GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
211
211
  }
@@ -215,7 +215,7 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
215
215
  for (int i = np; i < np2; i += ggml_f32_epr) {
216
216
  ax1 = GGML_F32_VEC_LOAD(x + i);
217
217
  ay1 = GGML_F32_VEC_LOAD(y + i);
218
- ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
218
+ ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
219
219
 
220
220
  GGML_F32_VEC_STORE(y + i, ay1);
221
221
  }
@@ -351,6 +351,45 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int
351
351
  #endif
352
352
  }
353
353
 
354
// Scale-and-bias: y[i] = x[i]*s + b for i in [0, n).
//
// Argument order of GGML_F32_VEC_FMA: FMA(a, b, c) evaluates a + b*c, with the
// accumulator first. That is how ggml_vec_dot_f32 / ggml_vec_mad_f32 call it,
// and GGML_F32xt_FMA_IMPL maps it to svmad_f32_m(pg, b, c, a). The bias vector
// must therefore be the FIRST argument. Passing the loaded x vector first
// would compute x + s*b instead of x*s + b.
inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {
#if defined(GGML_USE_ACCELERATE)
    // vDSP_vsmsa: vector * scalar + scalar, i.e. y = x*s + b.
    vDSP_vsmsa(x, 1, &s, &b, y, 1, n);
#elif defined(GGML_SIMD)
    #if defined(__ARM_FEATURE_SVE)
        // scalar ; TODO: Write SVE code
        for (int i = 0; i < n; ++i) {
            y[i] = x[i]*s + b;
        }
    #else
        // Vectorized prefix length; the mask trick assumes GGML_F32_STEP is a
        // power of two. Remaining elements are handled by the scalar loop below.
        const int np = (n & ~(GGML_F32_STEP - 1));

        GGML_F32_VEC vs = GGML_F32_VEC_SET1(s);
        GGML_F32_VEC vb = GGML_F32_VEC_SET1(b);

        GGML_F32_VEC ay[GGML_F32_ARR];

        for (int i = 0; i < np; i += GGML_F32_STEP) {
            for (int j = 0; j < GGML_F32_ARR; j++) {
                ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
                // vb + ay[j]*vs  ==  x*s + b  (accumulator-first FMA)
                ay[j] = GGML_F32_VEC_FMA(vb, ay[j], vs);

                GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
            }
        }

        // leftovers
        for (int i = np; i < n; ++i) {
            y[i] = x[i]*s + b;
        }
    #endif
#else
    // scalar
    for (int i = 0; i < n; ++i) {
        y[i] = x[i]*s + b;
    }
#endif
}
392
+
354
393
  //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
355
394
  inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
356
395
  #if defined(GGML_USE_ACCELERATE)
@@ -905,6 +944,100 @@ inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, con
905
944
  }
906
945
  }
907
946
 
947
+ inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) {
948
+ for (int i = 0; i < n; ++i) {
949
+ y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f;
950
+ }
951
+ }
952
+
953
+ inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
954
+ for (int i = 0; i < n; ++i) {
955
+ float v = GGML_CPU_FP16_TO_FP32(x[i]);
956
+ y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v * GGML_CPU_FP16_TO_FP32(g[i]) : 0.f);
957
+ }
958
+ }
959
+
960
+ #ifdef GGML_GELU_FP16
961
+ inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
962
+ uint16_t t;
963
+ for (int i = 0; i < n; ++i) {
964
+ if (x[i] <= -10.0f) {
965
+ y[i] = 0.0f;
966
+ } else if (x[i] >= 10.0f) {
967
+ y[i] = x[i] * g[i];
968
+ } else {
969
+ ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
970
+ memcpy(&t, &fp16, sizeof(uint16_t));
971
+ y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[t]) * g[i];
972
+ }
973
+ }
974
+ }
975
+ #else
976
+ inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
977
+ for (int i = 0; i < n; ++i) {
978
+ y[i] = ggml_gelu_f32(x[i]) * g[i];
979
+ }
980
+ }
981
+ #endif
982
+
983
+ inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
984
+ const uint16_t * i16 = (const uint16_t *) x;
985
+ for (int i = 0; i < n; ++i) {
986
+ float v = GGML_CPU_FP16_TO_FP32(g[i]);
987
+ y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v);
988
+ }
989
+ }
990
+
991
+ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g);
992
+
993
+ inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
994
+ for (int i = 0; i < n; ++i) {
995
+ float v = GGML_CPU_FP16_TO_FP32(x[i]);
996
+ float w = GGML_CPU_FP16_TO_FP32(g[i]);
997
+ y[i] = GGML_CPU_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
998
+ }
999
+ }
1000
+
1001
+ inline static void ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {
1002
+ for (int i = 0; i < n; ++i) {
1003
+ float xi = x[i];
1004
+ y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i];
1005
+ }
1006
+ }
1007
+
1008
+ inline static void ggml_vec_geglu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
1009
+ for (int i = 0; i < n; ++i) {
1010
+ float xi = GGML_CPU_FP16_TO_FP32(x[i]);
1011
+ float gi = GGML_CPU_FP16_TO_FP32(g[i]);
1012
+ y[i] = GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi);
1013
+ }
1014
+ }
1015
+
1016
+ #ifdef GGML_GELU_QUICK_FP16
1017
+ inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
1018
+ uint16_t t;
1019
+ for (int i = 0; i < n; ++i) {
1020
+ ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
1021
+ memcpy(&t, &fp16, sizeof(uint16_t));
1022
+ y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]) * g[i];
1023
+ }
1024
+ }
1025
+ #else
1026
+ inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
1027
+ for (int i = 0; i < n; ++i) {
1028
+ y[i] = ggml_gelu_quick_f32(x[i]) * g[i];
1029
+ }
1030
+ }
1031
+ #endif
1032
+
1033
+ inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
1034
+ const uint16_t * i16 = (const uint16_t *) x;
1035
+ for (int i = 0; i < n; ++i) {
1036
+ float v = GGML_CPU_FP16_TO_FP32(g[i]);
1037
+ y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[i16[i]]) * v);
1038
+ }
1039
+ }
1040
+
908
1041
  inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
909
1042
  #ifndef GGML_USE_ACCELERATE
910
1043
  ggml_float sum = 0.0;
@@ -175,6 +175,23 @@ static const char * cu_get_error_str(CUresult err) {
175
175
  #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
176
176
  #endif
177
177
 
178
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
179
+ # define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
180
+ do { \
181
+ static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \
182
+ const int id = ggml_cuda_get_device(); \
183
+ if (!shared_memory_limit_raised[id]) { \
184
+ CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
185
+ shared_memory_limit_raised[id] = true; \
186
+ } \
187
+ } while (0)
188
+ #else
189
+ # define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
190
+ do { \
191
+ GGML_UNUSED(nbytes); \
192
+ } while (0)
193
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
194
+
178
195
  #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
179
196
  #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
180
197
  #else
@@ -728,3 +728,25 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
728
728
  return nullptr;
729
729
  }
730
730
  }
731
+
732
+ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
733
+ switch (type) {
734
+ case GGML_TYPE_F32:
735
+ return convert_unary_cuda<float, nv_bfloat16>;
736
+ case GGML_TYPE_F16:
737
+ return convert_unary_cuda<half, nv_bfloat16>;
738
+ default:
739
+ return nullptr;
740
+ }
741
+ }
742
+
743
+ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
744
+ switch (type) {
745
+ case GGML_TYPE_F16:
746
+ return convert_unary_cuda<half, float>;
747
+ case GGML_TYPE_BF16:
748
+ return convert_unary_cuda<nv_bfloat16, float>;
749
+ default:
750
+ return nullptr;
751
+ }
752
+ }
@@ -22,5 +22,10 @@ using to_t_nc_cuda_t = void (*)(const void * x, T * y,
22
22
  int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
23
23
  int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
24
24
 
25
+ typedef to_t_nc_cuda_t<float> to_fp32_nc_cuda_t;
25
26
  typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
27
+ typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
28
+
29
+ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
26
30
  to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
31
+ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
@@ -123,13 +123,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
123
123
  ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
124
124
 
125
125
  if (nbytes_shared <= smpbo) {
126
- #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
127
- static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
128
- if (!shared_memory_limit_raised[id]) {
129
- CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
130
- shared_memory_limit_raised[id] = true;
131
- }
132
- #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
126
+ CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo);
133
127
  cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
134
128
  } else {
135
129
  cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
@@ -175,13 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
175
169
  const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
176
170
 
177
171
  if (nbytes_shared <= smpbo) {
178
- #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
179
- static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
180
- if (!shared_memory_limit_raised[id]) {
181
- CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
182
- shared_memory_limit_raised[id] = true;
183
- }
184
- #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
172
+ CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);
185
173
  cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
186
174
  } else {
187
175
  cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
@@ -32,7 +32,9 @@ typedef void (* fattn_kernel_t)(
32
32
  const int ne12,
33
33
  const int ne13,
34
34
  const int ne31,
35
+ const int ne32,
35
36
  const int nb31,
37
+ const int nb32,
36
38
  const int nb01,
37
39
  const int nb02,
38
40
  const int nb03,
@@ -851,7 +853,8 @@ void launch_fattn(
851
853
  scale, max_bias, m0, m1, n_head_log2, logit_softcap,
852
854
  Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
853
855
  K->ne[0], K->ne[1], K->ne[2], K->ne[3],
854
- mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
856
+ mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
857
+ mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
855
858
  Q->nb[1], Q->nb[2], Q->nb[3],
856
859
  nb11, nb12, nb13,
857
860
  nb21, nb22, nb23,
@@ -1223,7 +1223,9 @@ static __global__ void flash_attn_ext_f16(
1223
1223
  const int ne12,
1224
1224
  const int ne13,
1225
1225
  const int ne31,
1226
+ const int ne32,
1226
1227
  const int nb31,
1228
+ const int nb32,
1227
1229
  const int nb01,
1228
1230
  const int nb02,
1229
1231
  const int nb03,
@@ -1288,7 +1290,8 @@ static __global__ void flash_attn_ext_f16(
1288
1290
 
1289
1291
  const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1290
1292
  const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1291
- const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
1293
+ const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1294
+ (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
1292
1295
  float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1293
1296
 
1294
1297
  const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1327,7 +1330,8 @@ static __global__ void flash_attn_ext_f16(
1327
1330
 
1328
1331
  const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1329
1332
  const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1330
- const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
1333
+ const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1334
+ (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
1331
1335
  float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1332
1336
 
1333
1337
  const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1348,8 +1352,8 @@ static __global__ void flash_attn_ext_f16(
1348
1352
  GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
1349
1353
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
1350
1354
  GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
1351
- GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
1352
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
1355
+ GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
1356
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
1353
1357
  GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
1354
1358
  GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
1355
1359
  GGML_UNUSED(ne2); GGML_UNUSED(ne3);
@@ -6,7 +6,7 @@
6
6
 
7
7
  template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
8
8
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
9
- __launch_bounds__(nwarps*WARP_SIZE, 1)
9
+ __launch_bounds__(nwarps*WARP_SIZE, 2)
10
10
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
11
11
  static __global__ void flash_attn_tile_ext_f16(
12
12
  const char * __restrict__ Q,
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f16(
30
30
  const int ne12,
31
31
  const int ne13,
32
32
  const int ne31,
33
+ const int ne32,
33
34
  const int nb31,
35
+ const int nb32,
34
36
  const int nb01,
35
37
  const int nb02,
36
38
  const int nb03,
@@ -64,7 +66,7 @@ static __global__ void flash_attn_tile_ext_f16(
64
66
  const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
65
67
  const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
66
68
  const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
67
- const half * maskh = (const half *) mask + ne11*ic0;
69
+ const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
68
70
 
69
71
  const int stride_KV2 = nb11 / sizeof(half2);
70
72
 
@@ -288,8 +290,8 @@ static __global__ void flash_attn_tile_ext_f16(
288
290
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
289
291
  GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
290
292
  GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
291
- GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
292
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
293
+ GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
294
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
293
295
  GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
294
296
  GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
295
297
  GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -6,7 +6,7 @@
6
6
 
7
7
  template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
8
8
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
9
- __launch_bounds__(nwarps*WARP_SIZE, 1)
9
+ __launch_bounds__(nwarps*WARP_SIZE, 2)
10
10
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
11
11
  static __global__ void flash_attn_tile_ext_f32(
12
12
  const char * __restrict__ Q,
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f32(
30
30
  const int ne12,
31
31
  const int ne13,
32
32
  const int ne31,
33
+ const int ne32,
33
34
  const int nb31,
35
+ const int nb32,
34
36
  const int nb01,
35
37
  const int nb02,
36
38
  const int nb03,
@@ -58,8 +60,8 @@ static __global__ void flash_attn_tile_ext_f32(
58
60
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
59
61
  GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
60
62
  GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
61
- GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
62
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
63
+ GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
64
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
63
65
  GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
64
66
  GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
65
67
  GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -76,7 +78,7 @@ static __global__ void flash_attn_tile_ext_f32(
76
78
  const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
77
79
  const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
78
80
  const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
79
- const half * maskh = (const half *) mask + ne11*ic0;
81
+ const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
80
82
 
81
83
  const int stride_KV2 = nb11 / sizeof(half2);
82
84
 
@@ -297,14 +299,14 @@ static __global__ void flash_attn_tile_ext_f32(
297
299
  GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
298
300
  GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
299
301
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
300
- GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
301
- GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
302
- GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
303
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
304
- GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
305
- GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
306
- GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
307
- GGML_UNUSED(ne2); GGML_UNUSED(ne3);
302
+ GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
303
+ GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
304
+ GGML_UNUSED(ne31); GGML_UNUSED(ne32);
305
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32);
306
+ GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
307
+ GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
308
+ GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
309
+ GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
308
310
  NO_DEVICE_CODE;
309
311
  #endif // FLASH_ATTN_AVAILABLE
310
312
  }
@@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f16(
27
27
  const int ne12,
28
28
  const int ne13,
29
29
  const int ne31,
30
+ const int ne32,
30
31
  const int nb31,
32
+ const int nb32,
31
33
  const int nb01,
32
34
  const int nb02,
33
35
  const int nb03,
@@ -68,7 +70,7 @@ static __global__ void flash_attn_vec_ext_f16(
68
70
  K += nb12*(blockIdx.z / gqa_ratio);
69
71
  V += nb22*(blockIdx.z / gqa_ratio);
70
72
 
71
- const half * maskh = (const half *) mask + ne11*ic0;
73
+ const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
72
74
 
73
75
  const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
74
76
  const half slopeh = __float2half(slopef);
@@ -342,8 +344,8 @@ static __global__ void flash_attn_vec_ext_f16(
342
344
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
343
345
  GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
344
346
  GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
345
- GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
346
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
347
+ GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
348
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
347
349
  GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
348
350
  GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
349
351
  GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);