@novastera-oss/llamarn 0.2.9 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247)
  1. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  5. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  9. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  12. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  17. package/cpp/build-info.cpp +2 -2
  18. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  19. package/cpp/llama.cpp/README.md +4 -5
  20. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  21. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  22. package/cpp/llama.cpp/common/arg.cpp +17 -0
  23. package/cpp/llama.cpp/common/chat.cpp +37 -20
  24. package/cpp/llama.cpp/common/chat.h +2 -0
  25. package/cpp/llama.cpp/common/common.h +4 -0
  26. package/cpp/llama.cpp/convert_hf_to_gguf.py +745 -6
  27. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  28. package/cpp/llama.cpp/ggml/CMakeLists.txt +7 -2
  29. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  30. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  31. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +0 -1
  32. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  33. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  34. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  35. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1203 -163
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +33 -9
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  43. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +17 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  47. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +8 -6
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +185 -79
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  66. package/cpp/llama.cpp/ggml/src/ggml-impl.h +64 -0
  67. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  68. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +35 -9
  69. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +167 -39
  70. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +254 -57
  71. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +505 -40
  73. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  83. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  84. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +60 -9
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +711 -292
  92. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  93. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  94. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  95. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  105. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  106. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  112. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  113. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  114. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  116. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  117. package/cpp/llama.cpp/ggml/src/ggml.c +382 -61
  118. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  119. package/cpp/llama.cpp/gguf-py/gguf/constants.py +209 -0
  120. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  121. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +73 -21
  122. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  123. package/cpp/llama.cpp/include/llama.h +0 -40
  124. package/cpp/llama.cpp/src/llama-arch.cpp +210 -3
  125. package/cpp/llama.cpp/src/llama-arch.h +18 -1
  126. package/cpp/llama.cpp/src/llama-batch.cpp +27 -1
  127. package/cpp/llama.cpp/src/llama-batch.h +8 -1
  128. package/cpp/llama.cpp/src/llama-chat.cpp +15 -0
  129. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  130. package/cpp/llama.cpp/src/llama-graph.cpp +119 -184
  131. package/cpp/llama.cpp/src/llama-graph.h +47 -60
  132. package/cpp/llama.cpp/src/llama-hparams.cpp +7 -1
  133. package/cpp/llama.cpp/src/llama-hparams.h +3 -0
  134. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +28 -18
  135. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +4 -2
  136. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +214 -65
  137. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +62 -24
  138. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  139. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +9 -4
  140. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  141. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +20 -10
  142. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  143. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  144. package/cpp/llama.cpp/src/llama-model.cpp +2530 -685
  145. package/cpp/llama.cpp/src/llama-model.h +18 -0
  146. package/cpp/llama.cpp/src/llama-quant.cpp +1 -0
  147. package/cpp/llama.cpp/src/llama-vocab.cpp +13 -2
  148. package/cpp/llama.cpp/src/llama-vocab.h +41 -0
  149. package/ios/include/chat.h +2 -0
  150. package/ios/include/common.h +4 -0
  151. package/ios/include/llama.h +0 -40
  152. package/ios/libs/llama.xcframework/Info.plist +19 -19
  153. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  154. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5055 -4886
  155. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  156. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  157. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +0 -40
  158. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  159. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  160. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  161. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
  162. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  163. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  164. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  165. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  166. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  167. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  168. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3766
  169. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  170. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  171. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +0 -40
  172. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  173. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  174. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +0 -40
  175. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  176. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  177. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  178. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +0 -40
  179. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  180. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  181. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  182. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4890
  183. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  184. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  185. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +0 -40
  186. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  187. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  188. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  189. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
  190. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  191. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  192. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  193. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  194. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  195. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5091 -4922
  196. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  197. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  198. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +0 -40
  199. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  200. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  201. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4897
  202. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3794
  203. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  204. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  205. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  206. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  207. package/package.json +1 -1
  208. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  209. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  210. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  211. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  212. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  213. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  214. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  215. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  216. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  217. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  218. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  219. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  220. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  221. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  222. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  223. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  224. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  225. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  226. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  227. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  228. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  229. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  230. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  231. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  232. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  233. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  234. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  235. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  236. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  237. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  238. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  239. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  240. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  241. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  242. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  243. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  244. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  245. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  246. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  247. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -1,18 +1,18 @@
1
1
  #include "scale.cuh"
2
2
 
3
- static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
3
+ static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) {
4
4
  const int i = blockDim.x*blockIdx.x + threadIdx.x;
5
5
 
6
6
  if (i >= k) {
7
7
  return;
8
8
  }
9
9
 
10
- dst[i] = scale * x[i];
10
+ dst[i] = scale * x[i] + bias;
11
11
  }
12
12
 
13
- static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
13
+ static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) {
14
14
  const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
15
- scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
15
+ scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k);
16
16
  }
17
17
 
18
18
  void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -25,7 +25,9 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
25
25
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
26
26
 
27
27
  float scale;
28
- memcpy(&scale, dst->op_params, sizeof(float));
28
+ float bias;
29
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
30
+ memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));
29
31
 
30
- scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
32
+ scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream);
31
33
  }
@@ -2,6 +2,7 @@
2
2
  #include "ggml.h"
3
3
  #include "softmax.cuh"
4
4
  #include <cstdint>
5
+ #include <utility>
5
6
 
6
7
  template <typename T>
7
8
  static __device__ __forceinline__ float t2f32(T val) {
@@ -13,6 +14,29 @@ __device__ float __forceinline__ t2f32<half>(half val) {
13
14
  return __half2float(val);
14
15
  }
15
16
 
17
+ struct soft_max_params {
18
+
19
+ int64_t nheads;
20
+ uint32_t n_head_log2;
21
+ int64_t ncols;
22
+ int64_t nrows_x;
23
+ int64_t nrows_y;
24
+ int64_t ne00;
25
+ int64_t ne01;
26
+ int64_t ne02;
27
+ int64_t ne03;
28
+ int64_t nb11;
29
+ int64_t nb12;
30
+ int64_t nb13;
31
+
32
+ int64_t ne12;
33
+ int64_t ne13;
34
+ float scale;
35
+ float max_bias;
36
+ float m0;
37
+ float m1;
38
+ };
39
+
16
40
  // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
17
41
  // As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
18
42
  #ifdef __clang__
@@ -21,16 +45,24 @@ __device__ float __forceinline__ t2f32<half>(half val) {
21
45
  #endif // __clang__
22
46
  template <bool use_shared, int ncols_template, int block_size_template, typename T>
23
47
  static __global__ void soft_max_f32(
24
- const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
25
- const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
26
- const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
48
+ const float * x, const T * mask, float * dst, const soft_max_params p) {
49
+ const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
27
50
 
28
51
  const int tid = threadIdx.x;
29
- const int rowx = blockIdx.x;
30
- const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
52
+
53
+ const int64_t i03 = blockIdx.z;
54
+ const int64_t i02 = blockIdx.y;
55
+ const int64_t i01 = blockIdx.x;
56
+
57
+ //TODO: noncontigous inputs/outputs
58
+ const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
59
+
60
+ const int64_t i11 = i01;
61
+ const int64_t i12 = i02 % p.ne12;
62
+ const int64_t i13 = i03 % p.ne13;
31
63
 
32
64
  x += int64_t(rowx)*ncols;
33
- mask += int64_t(rowy)*ncols * (mask != nullptr);
65
+ mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
34
66
  dst += int64_t(rowx)*ncols;
35
67
 
36
68
  const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
@@ -38,7 +70,7 @@ static __global__ void soft_max_f32(
38
70
  const int warp_id = threadIdx.x / WARP_SIZE;
39
71
  const int lane_id = threadIdx.x % WARP_SIZE;
40
72
 
41
- const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
73
+ const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
42
74
 
43
75
  extern __shared__ float data_soft_max_f32[];
44
76
  float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
@@ -55,7 +87,7 @@ static __global__ void soft_max_f32(
55
87
  break;
56
88
  }
57
89
 
58
- const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
90
+ const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
59
91
 
60
92
  vals[col] = val;
61
93
  max_val = max(max_val, val);
@@ -150,64 +182,58 @@ static __global__ void soft_max_back_f32(
150
182
  }
151
183
  }
152
184
 
185
+ template<int... Ns, typename T>
186
+ static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
187
+ const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
188
+ {
189
+ const int id = ggml_cuda_get_device();
190
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
191
+
192
+ auto launch_kernel = [=](auto I) -> bool {
193
+ constexpr int ncols = decltype(I)::value;
194
+ constexpr int block = (ncols > 1024 ? 1024 : ncols);
195
+
196
+ if (p.ncols == ncols) {
197
+ CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
198
+ soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
199
+ (x, mask, dst, p);
200
+ return true;
201
+ }
202
+ return false;
203
+ };
204
+
205
+ // unary fold over launch_kernel
206
+ if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
207
+ return;
208
+ }
209
+
210
+ //default case
211
+ CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
212
+ soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
213
+ }
214
+
215
+
153
216
  template<typename T>
154
- static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
217
+ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
155
218
  int nth = WARP_SIZE;
219
+ const int64_t ncols_x = params.ncols;
220
+
156
221
  while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
157
222
  const dim3 block_dims(nth, 1, 1);
158
- const dim3 block_nums(nrows_x, 1, 1);
223
+ const dim3 block_nums(params.ne01, params.ne02, params.ne03);
159
224
  const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
160
225
  static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
161
226
 
162
- const uint32_t n_head = nrows_x/nrows_y;
163
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
164
227
 
165
- const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
166
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
228
+ const int id = ggml_cuda_get_device();
229
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
167
230
 
168
- // FIXME: this limit could be raised by ~2-4x on Ampere or newer
169
- if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
170
- switch (ncols_x) {
171
- case 32:
172
- soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
173
- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
174
- break;
175
- case 64:
176
- soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
177
- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
178
- break;
179
- case 128:
180
- soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
181
- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
182
- break;
183
- case 256:
184
- soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
185
- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
186
- break;
187
- case 512:
188
- soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
189
- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
190
- break;
191
- case 1024:
192
- soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
193
- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
194
- break;
195
- case 2048:
196
- soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
197
- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
198
- break;
199
- case 4096:
200
- soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
201
- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
202
- break;
203
- default:
204
- soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
205
- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
206
- break;
207
- }
231
+
232
+ if (nbytes_shared <= smpbo) {
233
+ launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
208
234
  } else {
209
235
  const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
210
- soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
236
+ soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
211
237
  }
212
238
  }
213
239
 
@@ -235,10 +261,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
235
261
 
236
262
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
237
263
 
238
- const int64_t ne00 = src0->ne[0];
239
264
  const int64_t nrows_x = ggml_nrows(src0);
240
265
  const int64_t nrows_y = src0->ne[1];
241
266
 
267
+ const int64_t ne00 = src0->ne[0];
268
+
242
269
  float scale = 1.0f;
243
270
  float max_bias = 0.0f;
244
271
 
@@ -247,10 +274,44 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
247
274
 
248
275
  const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
249
276
 
277
+ const int64_t nb11 = src1 ? src1->nb[1] : 1;
278
+ const int64_t nb12 = src1 ? src1->nb[2] : 1;
279
+ const int64_t nb13 = src1 ? src1->nb[3] : 1;
280
+
281
+ const int64_t ne12 = src1 ? src1->ne[2] : 1;
282
+ const int64_t ne13 = src1 ? src1->ne[3] : 1;
283
+
284
+ const uint32_t n_head = src0->ne[2];
285
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
286
+
287
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
288
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
289
+
290
+
291
+ soft_max_params params = {};
292
+ params.nheads = src0->ne[2];
293
+ params.n_head_log2 = n_head_log2;
294
+ params.ncols = ne00;
295
+ params.nrows_x = nrows_x;
296
+ params.nrows_y = nrows_y;
297
+ params.ne00 = src0->ne[0];
298
+ params.ne01 = src0->ne[1];
299
+ params.ne02 = src0->ne[2];
300
+ params.ne03 = src0->ne[3];
301
+ params.nb11 = nb11;
302
+ params.nb12 = nb12;
303
+ params.nb13 = nb13;
304
+ params.ne12 = ne12;
305
+ params.ne13 = ne13;
306
+ params.scale = scale;
307
+ params.max_bias = max_bias;
308
+ params.m0 = m0;
309
+ params.m1 = m1;
310
+
250
311
  if (use_f16) {
251
- soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
312
+ soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream);
252
313
  } else {
253
- soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
314
+ soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
254
315
  }
255
316
  }
256
317
 
@@ -107,8 +107,11 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
107
107
  if (nc == 4) {
108
108
  ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
109
109
  dst, dst_nb0, dst_nb1, dst_nb2, n_t);
110
+ } else if (nc == 3) {
111
+ ssm_conv_f32<threads, 3><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
112
+ dst, dst_nb0, dst_nb1, dst_nb2, n_t);
110
113
  } else {
111
- GGML_ABORT("Only support kernel size = 4 now.");
114
+ GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
112
115
  }
113
116
  } else {
114
117
  if (nc == 4) {
@@ -116,8 +119,13 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
116
119
  dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
117
120
  ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
118
121
  src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
122
+ } else if (nc == 3) {
123
+ const int64_t split_n_t = 32;
124
+ dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
125
+ ssm_conv_long_token_f32<threads, 3, split_n_t><<<blocks, threads, 0, stream>>>(
126
+ src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
119
127
  } else {
120
- GGML_ABORT("Only support kernel size = 4 right now.");
128
+ GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
121
129
  }
122
130
  }
123
131
  }
@@ -4,16 +4,15 @@ template <size_t splitD, size_t N>
4
4
  __global__ void __launch_bounds__(splitD, 2)
5
5
  ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
6
6
  const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
7
- const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
8
- const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
9
- const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
10
- float * __restrict__ dst, const int64_t L) {
11
- GGML_UNUSED(src1_nb0);
12
- GGML_UNUSED(src2_nb0);
7
+ const int32_t * __restrict__ src6, float * __restrict__ dst,
8
+ const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
9
+ const int src2_nb1, const int src2_nb2, const int src3_nb1,
10
+ const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
11
+ const int64_t s_off, const int64_t d_inner, const int64_t L) {
13
12
 
14
13
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
15
- const int bidx = blockIdx.x; // split along B
16
- const int bidy = blockIdx.y; // split along D
14
+ const int bidx = blockIdx.x; // split along B (sequences)
15
+ const int bidy = blockIdx.y; // split along D (d_inner)
17
16
  const int tid = threadIdx.x;
18
17
  const int wid = tid / 32;
19
18
  const int wtid = tid % 32;
@@ -24,23 +23,23 @@ __global__ void __launch_bounds__(splitD, 2)
24
23
  float * smem_A = smem;
25
24
  float * smem_s0 = smem_A + splitD * stride_sA;
26
25
 
27
- const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
28
- const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
26
+ const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
27
+ const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
29
28
  const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
30
29
  const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
31
- const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2));
32
- const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2));
33
- float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
34
- float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
30
+ const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3));
31
+ const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3));
32
+ float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
33
+ float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);
35
34
 
36
- const int stride_s0 = src0_nb1 / sizeof(float);
37
- const int stride_x = src1_nb1 / sizeof(float);
35
+ const int stride_s0 = src0_nb2 / sizeof(float);
36
+ const int stride_x = src1_nb2 / sizeof(float);
38
37
  const int stride_dt = src2_nb1 / sizeof(float);
39
38
  const int stride_A = src3_nb1 / sizeof(float);
40
- const int stride_B = src4_nb1 / sizeof(float);
41
- const int stride_C = src5_nb1 / sizeof(float);
39
+ const int stride_B = src4_nb2 / sizeof(float);
40
+ const int stride_C = src5_nb2 / sizeof(float);
42
41
  const int stride_s = stride_s0;
43
- const int stride_y = stride_x;
42
+ const int stride_y = d_inner;
44
43
 
45
44
  // can N not be 16? for example 32?
46
45
  if (N == 16) {
@@ -84,24 +83,167 @@ __global__ void __launch_bounds__(splitD, 2)
84
83
  }
85
84
  }
86
85
 
// Selective-scan kernel for the grouped (multi-head) SSM layout.
//
// Launch layout (as used by ssm_scan_f32_cuda):
//   grid  = (n_head * d_head / splitH, n_seq, 1)
//   block = d_state threads; thread t owns state column t of splitH consecutive head elements
//   shared memory: static, splitH * d_state floats (stateC)
//
// Inputs are addressed through byte strides (the *_nb* arguments):
//   src0 -> initial states (s0_block), row selected per sequence through src6[seq_idx]
//   src1 -> x, src2 -> dt, src3 -> A, src4 -> B, src5 -> C, src6 -> int32 state ids
// Output:
//   dst holds y first (n_head * d_head floats per token), followed at byte offset s_off
//   by the updated states.
//
// Preconditions:
//   - d_state and splitH are powers of two (static_assert below), splitH >= d_state / WARP_SIZE
//   - d_head is a multiple of splitH so that a block never straddles two heads
//   - n_head is a multiple of n_group (consecutive heads share one B/C group)
//
// assumes as many threads as d_state
template <int splitH, int d_state>
__global__ void __launch_bounds__(d_state, 1)
    ssm_scan_f32_group(
        const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
        const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
        const int32_t * __restrict__ src6, float * __restrict__ dst,
        const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
        const int src2_nb1, const int src2_nb2, const int src3_nb1,
        const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
        const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {

    // blockIdx.x * splitH is the flat (head, element-within-head) index of the first element of this block
    const int head_idx = (blockIdx.x * splitH) / d_head;
    const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float); // in bytes
    const int seq_idx  = blockIdx.y;

    // Heads are assigned to groups in contiguous runs of (n_head / n_group) heads.
    // The previous `head_idx & (n_group - 1)` mapped heads round-robin instead, which only
    // coincides with this when n_group == 1 or n_group == n_head.
    const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);

    const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
    const float * x_block  = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
    const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
    const float * A_block  = (const float *) ((const char *) src3 + head_idx * src3_nb1);
    const float * B_block  = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
    const float * C_block  = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
    float *       y_block  = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
    float *       s_block  = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);

    // strides across n_seq_tokens
    const int stride_x  = src1_nb2 / sizeof(float);
    const int stride_dt = src2_nb1 / sizeof(float);
    const int stride_B  = src4_nb2 / sizeof(float);
    const int stride_C  = src5_nb2 / sizeof(float);
    const int stride_y  = n_head * d_head;

    // per-thread slice of the recurrent state: one d_state column for each of the splitH head elements
    float state[splitH];
    // for the parallel accumulation
    __shared__ float stateC[splitH * d_state];

#pragma unroll
    for (int j = 0; j < splitH; j++) {
        state[j] = s0_block[j * d_state + threadIdx.x];
    }

    for (int64_t i = 0; i < n_tok; i++) {
        // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
        // TODO: only calculate B and C once per head group
        // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
        float dt_soft_plus = dt_block[i * stride_dt];
        if (dt_soft_plus <= 20.0f) {
            // softplus; for larger inputs the value is used unchanged
            dt_soft_plus = log1pf(expf(dt_soft_plus));
        }
        const float dA = expf(dt_soft_plus * A_block[0]);
        const float B  = B_block[i * stride_B + threadIdx.x];
        const float C  = C_block[i * stride_C + threadIdx.x];

        // across d_head
#pragma unroll
        for (int j = 0; j < splitH; j++) {
            const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;

            state[j] = (state[j] * dA) + (B * x_dt);

            stateC[j * d_state + threadIdx.x] = state[j] * C;
        }

        __syncthreads();

        // parallel accumulation for stateC
        // TODO: simplify
        {
            static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
            static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");

            // reduce until w matches the warp size
            // TODO: does this work even when the physical warp size is 64?
#pragma unroll
            for (int w = d_state; w > WARP_SIZE; w >>= 1) {
                // (assuming there are d_state threads)
#pragma unroll
                for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
                    // TODO: check for bank conflicts
                    const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
                    stateC[k] += stateC[k + (w >> 1)];
                }
                __syncthreads();
            }

            static_assert(splitH >= d_state / WARP_SIZE);

            // each warp finishes the rows it owns with a warp-level sum
#pragma unroll
            for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
                float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
                y = warp_reduce_sum(y);

                // store the above accumulations
                if (threadIdx.x % WARP_SIZE == 0) {
                    const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
                    y_block[i * stride_y + k] = y;
                }
            }
        }

        // The next token overwrites stateC (thread t writes column t of every row), while other
        // warps may still be reading columns 0..WARP_SIZE-1 of their rows above. Wait for all
        // reads to finish before continuing. n_tok is uniform across the block, so every thread
        // reaches this barrier.
        __syncthreads();
    }

    // write back the state
#pragma unroll
    for (int j = 0; j < splitH; j++) {
        s_block[j * d_state + threadIdx.x] = state[j];
    }
}
196
+
// Host-side dispatcher that launches the matching SSM scan kernel on `stream`.
//
// src0..src5 are float device pointers, src6 is an int32 device pointer and dst is the float
// output; the *_nb* arguments are byte strides forwarded unchanged to the kernels. s_off is
// forwarded as the byte offset of the state section inside dst.
//
// Kernel selection:
//   - src3_nb1 == sizeof(float): grouped kernel (Mamba-2 / Falcon-H1), d_state must be 128 or 256.
//     Grid is (n_head * head_dim / splitH, n_seq), block is d_state threads, no dynamic shared memory.
//     (presumably A then holds a single value per head; confirm against the caller)
//   - otherwise: Mamba-1 kernel, d_state must be 16, head_dim == 1 and n_group == 1.
//     Grid is (n_seq, n_head / 128), block is 128 threads, with dynamic shared memory.
// Unsupported configurations abort.
static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
                              const float * src4, const float * src5, const int32_t * src6, float * dst,
                              const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
                              const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
                              const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                              const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                              cudaStream_t stream) {
    // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
    const bool grouped_layout = (src3_nb1 == sizeof(float));

    if (!grouped_layout) {
        // Mamba-1
        const int threads = 128;
        GGML_ASSERT(n_head % threads == 0);
        GGML_ASSERT(head_dim == 1);
        GGML_ASSERT(n_group == 1);

        if (d_state != 16) {
            GGML_ABORT("doesn't support d_state!=16.");
        }

        const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
        // dynamic shared memory requested by the Mamba-1 kernel
        const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);

        ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
            src0, src1, src2, src3, src4, src5, src6, dst,
            src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
            src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
        return;
    }

    // Mamba-2
    switch (d_state) {
        case 128: {
            const int threads = 128;
            GGML_ASSERT(d_state % threads == 0);
            // NOTE: can be any power of two between 4 and 64
            const int splitH = 16;
            GGML_ASSERT(head_dim % splitH == 0);
            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
            ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
                src0, src1, src2, src3, src4, src5, src6, dst,
                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
        } break;
        case 256: { // Falcon-H1
            const int threads = 256;
            // NOTE: can be any power of two between 8 and 64
            const int splitH = 16;
            GGML_ASSERT(head_dim % splitH == 0);
            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
            ssm_scan_f32_group<16, 256><<<blocks, threads, 0, stream>>>(
                src0, src1, src2, src3, src4, src5, src6, dst,
                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
        } break;
        default:
            GGML_ABORT("doesn't support d_state!=(128 or 256).");
    }
}
107
249
 
@@ -112,30 +254,25 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
112
254
  const struct ggml_tensor * src3 = dst->src[3]; // A
113
255
  const struct ggml_tensor * src4 = dst->src[4]; // B
114
256
  const struct ggml_tensor * src5 = dst->src[5]; // C
115
-
116
- // const int64_t d_state = src0->ne[0];
117
- // const int64_t d_inner = src0->ne[1];
118
- // const int64_t l = src1->ne[1];
119
- // const int64_t b = src0->ne[2];
257
+ const struct ggml_tensor * src6 = dst->src[6]; // ids
120
258
 
121
259
  const int64_t nc = src0->ne[0]; // d_state
122
- const int64_t nr = src0->ne[1]; // d_inner
123
- const int64_t n_t = src1->ne[1]; // number of tokens per sequence
124
- const int64_t n_s = src0->ne[2]; // number of sequences in the batch
260
+ const int64_t nr = src0->ne[1]; // head_dim or 1
261
+ const int64_t nh = src1->ne[1]; // n_head
262
+ const int64_t ng = src4->ne[1]; // n_group
263
+ const int64_t n_t = src1->ne[2]; // number of tokens per sequence
264
+ const int64_t n_s = src1->ne[3]; // number of sequences in the batch
265
+
266
+ const int64_t s_off = ggml_nelements(src1) * sizeof(float);
125
267
 
126
- GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
268
+ GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));
127
269
  GGML_ASSERT(src0->nb[0] == sizeof(float));
128
270
  GGML_ASSERT(src1->nb[0] == sizeof(float));
129
271
  GGML_ASSERT(src2->nb[0] == sizeof(float));
130
272
  GGML_ASSERT(src3->nb[0] == sizeof(float));
131
273
  GGML_ASSERT(src4->nb[0] == sizeof(float));
132
274
  GGML_ASSERT(src5->nb[0] == sizeof(float));
133
- // required for the dot product between s and C
134
- GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
135
- // required for per-sequence offsets for states
136
- GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float));
137
- // required to get correct offset for state destination (i.e. src1->nb[3])
138
- GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float));
275
+ GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
139
276
 
140
277
  const float * src0_d = (const float *) src0->data;
141
278
  const float * src1_d = (const float *) src1->data;
@@ -143,13 +280,16 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
143
280
  const float * src3_d = (const float *) src3->data;
144
281
  const float * src4_d = (const float *) src4->data;
145
282
  const float * src5_d = (const float *) src5->data;
283
+ const int32_t * src6_d = (const int32_t *) src6->data;
146
284
  float * dst_d = (float *) dst->data;
147
285
  cudaStream_t stream = ctx.stream();
148
286
 
149
287
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
288
+ GGML_ASSERT(src6->type == GGML_TYPE_I32);
150
289
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
151
290
 
152
- ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0],
153
- src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1],
154
- src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream);
291
+ ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,
292
+ src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
293
+ src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
294
+ s_off, nc, nr, nh, ng, n_t, n_s, stream);
155
295
  }