@novastera-oss/llamarn 0.2.7 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (319)
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  13. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  18. package/cpp/LlamaCppModel.cpp +56 -22
  19. package/cpp/build-info.cpp +2 -2
  20. package/cpp/llama.cpp/CMakeLists.txt +1 -2
  21. package/cpp/llama.cpp/README.md +4 -5
  22. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  23. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  24. package/cpp/llama.cpp/common/arg.cpp +24 -0
  25. package/cpp/llama.cpp/common/chat.cpp +37 -20
  26. package/cpp/llama.cpp/common/chat.h +2 -0
  27. package/cpp/llama.cpp/common/common.cpp +3 -0
  28. package/cpp/llama.cpp/common/common.h +5 -0
  29. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  30. package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
  31. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  32. package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
  33. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  34. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  35. package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
  36. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  100. package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
  101. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
  103. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
  104. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
  105. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  133. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  134. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  135. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  136. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  137. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
  138. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  141. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  142. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  143. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  144. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  168. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  169. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  170. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  172. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  173. package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
  177. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
  178. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
  179. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  180. package/cpp/llama.cpp/include/llama.h +8 -43
  181. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  182. package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
  183. package/cpp/llama.cpp/src/llama-arch.h +36 -1
  184. package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
  185. package/cpp/llama.cpp/src/llama-batch.h +105 -70
  186. package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
  187. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  188. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  189. package/cpp/llama.cpp/src/llama-context.h +13 -13
  190. package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
  191. package/cpp/llama.cpp/src/llama-graph.h +78 -79
  192. package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
  193. package/cpp/llama.cpp/src/llama-hparams.h +11 -0
  194. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
  195. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  196. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
  197. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
  198. package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
  199. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
  200. package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
  201. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
  202. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  203. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  204. package/cpp/llama.cpp/src/llama-memory.h +21 -22
  205. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  206. package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
  207. package/cpp/llama.cpp/src/llama-model.h +40 -0
  208. package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
  209. package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
  210. package/cpp/llama.cpp/src/llama-vocab.h +42 -0
  211. package/cpp/rn-utils.h +3 -0
  212. package/ios/include/chat.h +2 -0
  213. package/ios/include/common.h +5 -0
  214. package/ios/include/llama.h +8 -43
  215. package/ios/libs/llama.xcframework/Info.plist +19 -19
  216. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  217. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  218. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  219. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  220. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
  221. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  223. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  224. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  225. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  231. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  232. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  248. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  250. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
  254. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
  255. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  261. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  262. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  263. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  264. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  265. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  267. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  268. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
  269. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
  270. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
  274. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  275. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  276. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  277. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  278. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  279. package/package.json +1 -1
  280. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  315. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  316. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  317. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  318. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  319. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -0,0 +1,19 @@
1
+ #include "mean.cuh"
2
+
3
+ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
4
+ const ggml_tensor * src0 = dst->src[0];
5
+ const float * src0_d = (const float *) src0->data;
6
+ float * dst_d = (float *) dst->data;
7
+ cudaStream_t stream = ctx.stream();
8
+
9
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
10
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
11
+ GGML_ASSERT(ggml_is_contiguous(src0));
12
+
13
+ const int64_t ncols = src0->ne[0];
14
+ const int64_t nrows = ggml_nrows(src0);
15
+
16
+ const dim3 block_dims(WARP_SIZE, 1, 1);
17
+ const dim3 block_nums(nrows, 1, 1);
18
+ reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
19
+ }
@@ -0,0 +1,3 @@
1
+ #include "common.cuh"
2
+
3
+ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -3016,14 +3016,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
3016
3016
 
3017
3017
  const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
3018
3018
 
3019
- #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
3020
- static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
3021
- if (!shared_memory_limit_raised[id]) {
3022
- CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
3023
- CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
3024
- shared_memory_limit_raised[id] = true;
3025
- }
3026
- #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
3019
+ CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, false>), nbytes_shared);
3020
+ CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, true>), nbytes_shared);
3027
3021
 
3028
3022
  const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
3029
3023
  const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
@@ -2,25 +2,26 @@
2
2
  #include "common.cuh"
3
3
  #include "mmv.cuh"
4
4
 
5
- template <typename T, typename type_acc, int block_size>
5
+ template <typename T, typename type_acc, int ncols_dst, int block_size>
6
6
  static __global__ void mul_mat_vec(
7
7
  const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
8
- const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
9
- const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
10
- const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
11
- const int64_t row = blockIdx.x;
12
- const int64_t channel_dst = blockIdx.y;
13
- const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
14
- const int64_t channel_y = ids ? channel_dst % nchannels_y : channel_dst;
15
- const int64_t sample_dst = blockIdx.z;
16
- const int64_t sample_x = sample_dst / sample_ratio;
17
- const int64_t sample_y = sample_dst;
18
- const int tid = threadIdx.x;
8
+ const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
9
+ const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
10
+ const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
11
+ const int row = blockIdx.x;
12
+ const int channel_dst = blockIdx.y;
13
+ const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
14
+ const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
15
+ const int sample_dst = blockIdx.z;
16
+ const int sample_x = sample_dst / sample_ratio;
17
+ const int sample_y = sample_dst;
18
+ const int tid = threadIdx.x;
19
+
19
20
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
20
21
 
21
- x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
22
- y += sample_y *stride_sample_y + channel_y *stride_channel_y;
23
- dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
22
+ x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
23
+ y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
24
+ dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
24
25
 
25
26
  const float2 * y2 = (const float2 *) y;
26
27
 
@@ -34,81 +35,108 @@ static __global__ void mul_mat_vec(
34
35
  __syncthreads();
35
36
  }
36
37
 
37
- float sumf = 0.0f;
38
+ float sumf[ncols_dst] = {0.0f};
38
39
 
39
40
  if constexpr (std::is_same<T, float>::value) {
40
41
  const float2 * x2 = (const float2 *) x;
41
42
 
42
- for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
43
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
43
44
  const float2 tmpx = x2[col2];
44
- const float2 tmpy = y2[col2];
45
- sumf += tmpx.x*tmpy.x;
46
- sumf += tmpx.y*tmpy.y;
45
+
46
+ #pragma unroll
47
+ for (int j = 0; j < ncols_dst; ++j) {
48
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
49
+ sumf[j] += tmpx.x*tmpy.x;
50
+ sumf[j] += tmpx.y*tmpy.y;
51
+ }
47
52
  }
48
53
  } else if constexpr (std::is_same<T, half>::value) {
49
54
  const half2 * x2 = (const half2 *) x;
50
55
 
51
56
  if (std::is_same<type_acc, float>::value) {
52
- for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
57
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
53
58
  const float2 tmpx = __half22float2(x2[col2]);
54
- const float2 tmpy = y2[col2];
55
- sumf += tmpx.x * tmpy.x;
56
- sumf += tmpx.y * tmpy.y;
59
+
60
+ #pragma unroll
61
+ for (int j = 0; j < ncols_dst; ++j) {
62
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
63
+ sumf[j] += tmpx.x * tmpy.x;
64
+ sumf[j] += tmpx.y * tmpy.y;
65
+ }
57
66
  }
58
67
  } else {
59
68
  #ifdef FP16_AVAILABLE
60
- half2 sumh2 = make_half2(0.0f, 0.0f);
69
+ half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
70
+
71
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
72
+ const half2 tmpx = x2[col2];
61
73
 
62
- for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
63
- const float2 tmp = y2[col2];
64
- sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
74
+ #pragma unroll
75
+ for (int j = 0; j < ncols_dst; ++j) {
76
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
77
+ sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
78
+ }
65
79
  }
66
80
 
67
- sumf = __low2float(sumh2) + __high2float(sumh2);
81
+ #pragma unroll
82
+ for (int j = 0; j < ncols_dst; ++j) {
83
+ sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
84
+ }
68
85
  #else
69
86
  NO_DEVICE_CODE;
70
87
  #endif // FP16_AVAILABLE
71
88
  }
72
89
  } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
73
90
  const int * x2 = (const int *) x;
74
- for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
75
- const int tmpx = x2[col2];
76
- const float2 tmpy = y2[col2];
77
- sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
78
- sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
91
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
92
+ const int tmpx = x2[col2];
93
+ #pragma unroll
94
+ for (int j = 0; j < ncols_dst; ++j) {
95
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
96
+ sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
97
+ sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
98
+ }
79
99
  }
80
100
  } else {
81
101
  static_assert(std::is_same<T, void>::value, "unsupported type");
82
102
  }
83
103
 
84
- sumf = warp_reduce_sum<warp_size>(sumf);
104
+ #pragma unroll
105
+ for (int j = 0; j < ncols_dst; ++j) {
106
+ sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
85
107
 
86
- if (block_size > warp_size) {
87
- buf_iw[tid/warp_size] = sumf;
88
- __syncthreads();
89
- if (tid >= warp_size) {
90
- return;
108
+ if (block_size > warp_size) {
109
+ buf_iw[tid/warp_size] = sumf[j];
110
+ __syncthreads();
111
+ if (tid < warp_size) {
112
+ sumf[j] = buf_iw[tid];
113
+ sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
114
+ }
115
+ if (j < ncols_dst) {
116
+ __syncthreads();
117
+ }
91
118
  }
92
- sumf = buf_iw[tid];
93
- sumf = warp_reduce_sum<warp_size>(sumf);
94
119
  }
95
120
 
96
- if (tid != 0) {
121
+ if (tid >= ncols_dst) {
97
122
  return;
98
123
  }
99
124
 
100
- dst[row] = sumf;
125
+ dst[tid*stride_col_dst + row] = sumf[tid];
101
126
  }
102
127
 
103
- template <typename T, typename type_acc>
128
+ template <typename T, typename type_acc, int ncols_dst>
104
129
  static void launch_mul_mat_vec_cuda(
105
130
  const T * x, const float * y, const int32_t * ids, float * dst,
106
- const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
131
+ const int64_t ncols, const int64_t nrows,
132
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
133
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
107
134
  const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
108
135
  const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
109
136
  cudaStream_t stream) {
110
- GGML_ASSERT(ncols % 2 == 0);
111
- GGML_ASSERT(stride_row % 2 == 0);
137
+ GGML_ASSERT(ncols % 2 == 0);
138
+ GGML_ASSERT(stride_row % 2 == 0);
139
+ GGML_ASSERT(stride_col_y % 2 == 0);
112
140
  GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
113
141
  GGML_ASSERT( nsamples_dst % nsamples_x == 0);
114
142
  const int64_t channel_ratio = nchannels_dst / nchannels_x;
@@ -138,44 +166,52 @@ static void launch_mul_mat_vec_cuda(
138
166
  const dim3 block_dims(block_size_best, 1, 1);
139
167
  switch (block_size_best) {
140
168
  case 32: {
141
- mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
142
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
143
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
169
+ mul_mat_vec<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, smem, stream>>>
170
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
171
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
172
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
144
173
  } break;
145
174
  case 64: {
146
- mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
147
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
148
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
175
+ mul_mat_vec<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, smem, stream>>>
176
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
177
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
178
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
149
179
  } break;
150
180
  case 96: {
151
- mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
152
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
153
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
181
+ mul_mat_vec<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, smem, stream>>>
182
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
183
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
184
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
154
185
  } break;
155
186
  case 128: {
156
- mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
157
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
158
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
187
+ mul_mat_vec<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, smem, stream>>>
188
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
189
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
190
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
159
191
  } break;
160
192
  case 160: {
161
- mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
162
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
163
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
193
+ mul_mat_vec<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, smem, stream>>>
194
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
195
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
196
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
164
197
  } break;
165
198
  case 192: {
166
- mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
167
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
168
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
199
+ mul_mat_vec<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, smem, stream>>>
200
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
201
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
202
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
169
203
  } break;
170
204
  case 224: {
171
- mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
172
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
173
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
205
+ mul_mat_vec<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, smem, stream>>>
206
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
207
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
208
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
174
209
  } break;
175
210
  case 256: {
176
- mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
177
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
178
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
211
+ mul_mat_vec<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, smem, stream>>>
212
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
213
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
214
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
179
215
  } break;
180
216
  default: {
181
217
  GGML_ABORT("fatal error");
@@ -183,23 +219,91 @@ static void launch_mul_mat_vec_cuda(
183
219
  }
184
220
  }
185
221
 
222
+ template <typename T, typename type_acc>
223
+ static void mul_mat_vec_cuda_switch_ncols_dst(
224
+ const T * x, const float * y, const int32_t * ids, float * dst,
225
+ const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
226
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
227
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
228
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
229
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
230
+ cudaStream_t stream) {
231
+ switch (ncols_dst) {
232
+ case 1:
233
+ launch_mul_mat_vec_cuda<T, type_acc, 1>
234
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
235
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
236
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
237
+ break;
238
+ case 2:
239
+ launch_mul_mat_vec_cuda<T, type_acc, 2>
240
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
241
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
242
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
243
+ break;
244
+ case 3:
245
+ launch_mul_mat_vec_cuda<T, type_acc, 3>
246
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
247
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
248
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
249
+ break;
250
+ case 4:
251
+ launch_mul_mat_vec_cuda<T, type_acc, 4>
252
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
253
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
254
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
255
+ break;
256
+ case 5:
257
+ launch_mul_mat_vec_cuda<T, type_acc, 5>
258
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
259
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
260
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
261
+ break;
262
+ case 6:
263
+ launch_mul_mat_vec_cuda<T, type_acc, 6>
264
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
265
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
266
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
267
+ break;
268
+ case 7:
269
+ launch_mul_mat_vec_cuda<T, type_acc, 7>
270
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
271
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
272
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
273
+ break;
274
+ case 8:
275
+ launch_mul_mat_vec_cuda<T, type_acc, 8>
276
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
277
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
278
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
279
+ break;
280
+ default:
281
+ GGML_ABORT("fatal error");
282
+ break;
283
+ }
284
+ }
285
+
186
286
  template<typename T>
187
287
  static void mul_mat_vec_cuda(
188
288
  const T * x, const float * y, const int32_t * ids, float * dst,
189
- const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
289
+ const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
290
+ const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
291
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
190
292
  const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
191
293
  const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
192
294
  enum ggml_prec prec, cudaStream_t stream) {
193
295
  if constexpr(std::is_same<T, half>::value) {
194
296
  if (prec == GGML_PREC_DEFAULT) {
195
- launch_mul_mat_vec_cuda<T, half>
196
- (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
297
+ mul_mat_vec_cuda_switch_ncols_dst<T, half>
298
+ (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
299
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
197
300
  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
198
301
  return;
199
302
  }
200
303
  }
201
- launch_mul_mat_vec_cuda<T, float>
202
- (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
304
+ mul_mat_vec_cuda_switch_ncols_dst<T, float>
305
+ (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
306
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
203
307
  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
204
308
  }
205
309
 
@@ -246,24 +350,24 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
246
350
  const int64_t stride_channel_dst = ids ? s1 : s2;
247
351
  const int64_t stride_channel_y = ids ? s11 : s12;
248
352
 
249
- GGML_ASSERT(ncols_dst == 1);
353
+ GGML_ASSERT(!ids || ncols_dst == 1);
250
354
 
251
355
  switch (src0->type) {
252
356
  case GGML_TYPE_F32: {
253
357
  const float * src0_d = (const float *) src0->data;
254
- mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
358
+ mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
255
359
  ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
256
360
  ne03, ne3, s03, s13, s3, prec, ctx.stream());
257
361
  } break;
258
362
  case GGML_TYPE_F16: {
259
363
  const half * src0_d = (const half *) src0->data;
260
- mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
364
+ mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
261
365
  ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
262
366
  ne03, ne3, s03, s13, s3, prec, ctx.stream());
263
367
  } break;
264
368
  case GGML_TYPE_BF16: {
265
369
  const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
266
- mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
370
+ mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
267
371
  ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
268
372
  ne03, ne3, s03, s13, s3, prec, ctx.stream());
269
373
  } break;
@@ -282,16 +386,19 @@ void ggml_cuda_op_mul_mat_vec(
282
386
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
283
387
 
284
388
  const int64_t ne00 = src0->ne[0];
389
+ const int64_t ne10 = src1->ne[0];
390
+ const int64_t ne0 = dst->ne[0];
285
391
  const int64_t row_diff = row_high - row_low;
286
392
 
287
- GGML_ASSERT(src1_ncols == 1);
288
-
289
- const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
393
+ const int id = ggml_cuda_get_device();
394
+ const int cc = ggml_cuda_info().devices[id].cc;
290
395
  const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
291
396
 
292
397
 
293
398
  // ggml_cuda_op provides single, contiguous matrices
294
399
  const int64_t stride_row = ne00;
400
+ const int64_t stride_col_y = ne10;
401
+ const int64_t stride_col_dst = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
295
402
  const int64_t nchannels_x = 1;
296
403
  const int64_t nchannels_y = 1;
297
404
  const int64_t nchannels_dst = 1;
@@ -307,19 +414,19 @@ void ggml_cuda_op_mul_mat_vec(
307
414
  switch (src0->type) {
308
415
  case GGML_TYPE_F32: {
309
416
  const float * src0_d = (const float *) src0_dd_i;
310
- mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
417
+ mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
311
418
  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
312
419
  nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
313
420
  } break;
314
421
  case GGML_TYPE_F16: {
315
422
  const half * src0_d = (const half *) src0_dd_i;
316
- mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
423
+ mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
317
424
  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
318
425
  nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
319
426
  } break;
320
427
  case GGML_TYPE_BF16: {
321
428
  const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
322
- mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
429
+ mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
323
430
  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
324
431
  nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
325
432
  } break;
@@ -334,3 +441,66 @@ void ggml_cuda_op_mul_mat_vec(
334
441
  GGML_UNUSED(src1_ncols);
335
442
  GGML_UNUSED(src1_padded_row_size);
336
443
  }
444
+
445
+ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
446
+ if (src0_ne[0] % 2 != 0) {
447
+ return false;
448
+ }
449
+ switch (type) {
450
+ case GGML_TYPE_F32:
451
+ if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
452
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
453
+ return ne11 <= 8;
454
+ }
455
+ if (cc >= GGML_CUDA_CC_TURING) {
456
+ return ne11 <= 4;
457
+ }
458
+ return ne11 <= 3;
459
+ } else if (GGML_CUDA_CC_IS_AMD(cc)) {
460
+ if (fp32_mma_hardware_available(cc)) {
461
+ return ne11 <= 3;
462
+ }
463
+ return ne11 <= 8;
464
+ }
465
+ return ne11 <= 8;
466
+ case GGML_TYPE_F16:
467
+ if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
468
+ const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
469
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
470
+ return src0_small && ne11 <= 4;
471
+ }
472
+ if (fp16_mma_hardware_available(cc)) {
473
+ return src0_small && ne11 <= 3;
474
+ }
475
+ return ne11 <= 8;
476
+ } else if (GGML_CUDA_CC_IS_AMD(cc)) {
477
+ if (fp16_mma_hardware_available(cc)) {
478
+ if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
479
+ return ne11 <= 5;
480
+ }
481
+ return ne11 <= 2;
482
+ }
483
+ return ne11 <= 8;
484
+ }
485
+ return ne11 <= 8;
486
+ case GGML_TYPE_BF16:
487
+ if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
488
+ const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
489
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
490
+ return src0_small && ne11 <= 4;
491
+ }
492
+ if (bf16_mma_hardware_available(cc)) {
493
+ return src0_small && ne11 <= 3;
494
+ }
495
+ return ne11 <= 8;
496
+ } else if (GGML_CUDA_CC_IS_AMD(cc)) {
497
+ if (bf16_mma_hardware_available(cc)) {
498
+ return ne11 <= 3;
499
+ }
500
+ return ne11 <= 8;
501
+ }
502
+ return ne11 <= 8;
503
+ default:
504
+ return false;
505
+ }
506
+ }
@@ -1,8 +1,5 @@
1
1
  #include "common.cuh"
2
2
 
3
- // maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
4
- #define MMV_MAX_ROWS 512
5
-
6
3
  void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
7
4
 
8
5
  void ggml_cuda_op_mul_mat_vec(
@@ -10,3 +7,5 @@ void ggml_cuda_op_mul_mat_vec(
10
7
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
11
8
  const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
12
9
  const int64_t src1_padded_row_size, cudaStream_t stream);
10
+
11
+ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
@@ -50,21 +50,19 @@ static __global__ void rope_norm(
50
50
 
51
51
  const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
52
52
 
53
- if (i0 >= n_dims) {
54
- const int i = row_dst*ne0 + i0;
55
-
56
- dst[i + 0] = x[i + 0];
57
- dst[i + 1] = x[i + 1];
58
-
59
- return;
60
- }
61
-
62
53
  const int row_x = row_dst % ne1;
63
54
  const int channel_x = row_dst / ne1;
64
55
 
65
56
  const int idst = row_dst*ne0 + i0;
66
57
  const int ix = channel_x*s2 + row_x*s1 + i0;
67
58
 
59
+ if (i0 >= n_dims) {
60
+ dst[idst + 0] = x[ix + 0];
61
+ dst[idst + 1] = x[ix + 1];
62
+
63
+ return;
64
+ }
65
+
68
66
  const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
69
67
 
70
68
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -94,21 +92,19 @@ static __global__ void rope_neox(
94
92
 
95
93
  const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
96
94
 
97
- if (i0 >= n_dims) {
98
- const int i = row_dst*ne0 + i0;
99
-
100
- dst[i + 0] = x[i + 0];
101
- dst[i + 1] = x[i + 1];
102
-
103
- return;
104
- }
105
-
106
95
  const int row_x = row_dst % ne1;
107
96
  const int channel_x = row_dst / ne1;
108
97
 
109
98
  const int idst = row_dst*ne0 + i0/2;
110
99
  const int ix = channel_x*s2 + row_x*s1 + i0/2;
111
100
 
101
+ if (i0 >= n_dims) {
102
+ dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
103
+ dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
104
+
105
+ return;
106
+ }
107
+
112
108
  const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
113
109
 
114
110
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -138,21 +134,19 @@ static __global__ void rope_multi(
138
134
 
139
135
  const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
140
136
 
141
- if (i0 >= n_dims) {
142
- const int i = row_dst*ne0 + i0;
143
-
144
- dst[i + 0] = x[i + 0];
145
- dst[i + 1] = x[i + 1];
146
-
147
- return;
148
- }
149
-
150
137
  const int row_x = row_dst % ne1;
151
138
  const int channel_x = row_dst / ne1;
152
139
 
153
140
  const int idst = row_dst*ne0 + i0/2;
154
141
  const int ix = channel_x*s2 + row_x*s1 + i0/2;
155
142
 
143
+ if (i0 >= n_dims) {
144
+ dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
145
+ dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
146
+
147
+ return;
148
+ }
149
+
156
150
  const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
157
151
  const int sec_w = sections.v[1] + sections.v[0];
158
152
  const int sector = (i0 / 2) % sect_dims;