@novastera-oss/llamarn 0.2.9 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (314) hide show
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  22. package/cpp/llama.cpp/CMakePresets.json +11 -0
  23. package/cpp/llama.cpp/CODEOWNERS +1 -0
  24. package/cpp/llama.cpp/README.md +8 -8
  25. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  26. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  27. package/cpp/llama.cpp/common/arg.cpp +62 -1
  28. package/cpp/llama.cpp/common/chat.cpp +37 -20
  29. package/cpp/llama.cpp/common/chat.h +2 -0
  30. package/cpp/llama.cpp/common/common.cpp +22 -6
  31. package/cpp/llama.cpp/common/common.h +22 -4
  32. package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
  33. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
  34. package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
  35. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  36. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  37. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  38. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  39. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  40. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  41. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
  42. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  44. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
  45. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  46. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  97. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
  99. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
  100. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  131. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
  132. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  138. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  139. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  140. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  141. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  142. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  143. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  144. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  168. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  169. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  170. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  173. package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  177. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  178. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  179. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
  180. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  181. package/cpp/llama.cpp/include/llama.h +15 -47
  182. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  183. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  184. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  185. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  186. package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
  187. package/cpp/llama.cpp/src/llama-arch.h +23 -1
  188. package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
  189. package/cpp/llama.cpp/src/llama-batch.h +31 -18
  190. package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
  191. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  192. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  193. package/cpp/llama.cpp/src/llama-context.h +26 -16
  194. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  195. package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
  196. package/cpp/llama.cpp/src/llama-graph.h +184 -122
  197. package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
  198. package/cpp/llama.cpp/src/llama-hparams.h +13 -2
  199. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  200. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  201. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  202. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  203. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  204. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  205. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  206. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
  207. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  208. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  209. package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
  210. package/cpp/llama.cpp/src/llama-model.h +21 -4
  211. package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
  212. package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
  213. package/cpp/llama.cpp/src/llama-vocab.h +43 -0
  214. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  215. package/cpp/llama.cpp/src/unicode.h +2 -0
  216. package/ios/include/chat.h +2 -0
  217. package/ios/include/common.h +22 -4
  218. package/ios/include/llama.h +15 -47
  219. package/ios/libs/llama.xcframework/Info.plist +13 -13
  220. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  221. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  223. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  224. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
  225. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  231. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  232. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  248. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  250. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  254. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  255. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  261. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  262. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
  263. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  264. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  265. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  267. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  268. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
  269. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
  270. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  274. package/package.json +4 -4
  275. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  276. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  277. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  278. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  279. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  280. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -104,10 +104,12 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
104
104
  }
105
105
  }
106
106
 
107
- template <int block_size>
107
+ template <int block_size, bool do_multiply = false>
108
108
  static __global__ void rms_norm_f32(
109
109
  const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
110
- const int64_t stride_sample, const float eps) {
110
+ const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0,
111
+ const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0,
112
+ const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) {
111
113
  const int nrows = gridDim.x;
112
114
  const int nchannels = gridDim.y;
113
115
 
@@ -119,6 +121,13 @@ static __global__ void rms_norm_f32(
119
121
  x += sample*stride_sample + channel*stride_channel + row*stride_row;
120
122
  dst += ((sample*nchannels + channel)*nrows + row)*ncols;
121
123
 
124
+ if constexpr (do_multiply) {
125
+ const int mul_row = row % mul_nrows;
126
+ const int mul_channel = channel % mul_nchannels;
127
+ const int mul_sample = sample % mul_nsamples;
128
+ mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
129
+ }
130
+
122
131
  float tmp = 0.0f; // partial sum for thread in warp
123
132
 
124
133
  for (int col = tid; col < ncols; col += block_size) {
@@ -145,7 +154,12 @@ static __global__ void rms_norm_f32(
145
154
  const float scale = rsqrtf(mean + eps);
146
155
 
147
156
  for (int col = tid; col < ncols; col += block_size) {
148
- dst[col] = scale * x[col];
157
+ if constexpr (do_multiply) {
158
+ const int mul_col = col % mul_ncols;
159
+ dst[col] = scale * x[col] * mul[mul_col];
160
+ } else {
161
+ dst[col] = scale * x[col];
162
+ }
149
163
  }
150
164
  }
151
165
 
@@ -310,10 +324,30 @@ static void rms_norm_f32_cuda(
310
324
  const dim3 blocks_num(nrows, nchannels, nsamples);
311
325
  if (ncols < 1024) {
312
326
  const dim3 block_dims(WARP_SIZE, 1, 1);
313
- rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
327
+ rms_norm_f32<WARP_SIZE, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
328
+ } else {
329
+ const dim3 block_dims(1024, 1, 1);
330
+ rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
331
+ }
332
+ }
333
+
334
+ static void rms_norm_mul_f32_cuda(
335
+ const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
336
+ const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
337
+ const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
338
+ const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
339
+ const float eps, cudaStream_t stream) {
340
+ const dim3 blocks_num(nrows, nchannels, nsamples);
341
+ if (mul == nullptr) {
342
+ rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
343
+ return;
344
+ }
345
+ if (ncols < 1024) {
346
+ const dim3 block_dims(WARP_SIZE, 1, 1);
347
+ rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
314
348
  } else {
315
349
  const dim3 block_dims(1024, 1, 1);
316
- rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
350
+ rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
317
351
  }
318
352
  }
319
353
 
@@ -407,6 +441,59 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
407
441
  rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
408
442
  }
409
443
 
444
+ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
445
+ const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
446
+ float eps = 0.0f;
447
+
448
+ memcpy(&eps, dst->op_params, sizeof(float));
449
+
450
+ const float * src0_d = (const float *) rms_norm_src->data;
451
+ const float * mul_d = nullptr;
452
+ const ggml_tensor * mul_src = nullptr;
453
+
454
+ if (mul_tensor->src[0] == dst) {
455
+ mul_d = (float *) mul_tensor->src[1]->data;
456
+ mul_src = mul_tensor->src[1];
457
+ } else if(mul_tensor->src[1] == dst) {
458
+ mul_d = (float *) mul_tensor->src[0]->data;
459
+ mul_src = mul_tensor->src[0];
460
+ } else {
461
+ GGML_ASSERT(false);
462
+ }
463
+
464
+ float * dst_d = (float *) mul_tensor->data;
465
+ cudaStream_t stream = ctx.stream();
466
+
467
+ GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
468
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
469
+ GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
470
+ GGML_ASSERT(eps >= 0.0f);
471
+
472
+ const int64_t ne00 = rms_norm_src->ne[0];
473
+ const int64_t ne01 = rms_norm_src->ne[1];
474
+ const int64_t ne02 = rms_norm_src->ne[2];
475
+ const int64_t ne03 = rms_norm_src->ne[3];
476
+
477
+ const size_t ts0 = ggml_type_size(rms_norm_src->type);
478
+ GGML_ASSERT(rms_norm_src->nb[0] == ts0);
479
+ const int64_t s01 = rms_norm_src->nb[1] / ts0;
480
+ const int64_t s02 = rms_norm_src->nb[2] / ts0;
481
+ const int64_t s03 = rms_norm_src->nb[3] / ts0;
482
+
483
+ const size_t ts_mul = ggml_type_size(mul_src->type);
484
+ GGML_ASSERT(mul_src->nb[0] == ts_mul);
485
+ const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
486
+ const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
487
+ const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
488
+
489
+ const int mul_ncols = mul_src->ne[0];
490
+ const int mul_nrows = mul_src->ne[1];
491
+ const int mul_nchannels = mul_src->ne[2];
492
+ const int mul_nsamples = mul_src->ne[3];
493
+
494
+ rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream);
495
+ }
496
+
410
497
  void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
411
498
  const ggml_tensor * grad = dst->src[0]; // gradients
412
499
  const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
@@ -6,6 +6,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
6
6
 
7
7
  void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
8
8
 
9
+ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor);
10
+
9
11
  void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
10
12
 
11
13
  void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -50,21 +50,19 @@ static __global__ void rope_norm(
50
50
 
51
51
  const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
52
52
 
53
- if (i0 >= n_dims) {
54
- const int i = row_dst*ne0 + i0;
55
-
56
- dst[i + 0] = x[i + 0];
57
- dst[i + 1] = x[i + 1];
58
-
59
- return;
60
- }
61
-
62
53
  const int row_x = row_dst % ne1;
63
54
  const int channel_x = row_dst / ne1;
64
55
 
65
56
  const int idst = row_dst*ne0 + i0;
66
57
  const int ix = channel_x*s2 + row_x*s1 + i0;
67
58
 
59
+ if (i0 >= n_dims) {
60
+ dst[idst + 0] = x[ix + 0];
61
+ dst[idst + 1] = x[ix + 1];
62
+
63
+ return;
64
+ }
65
+
68
66
  const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
69
67
 
70
68
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -94,21 +92,19 @@ static __global__ void rope_neox(
94
92
 
95
93
  const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
96
94
 
97
- if (i0 >= n_dims) {
98
- const int i = row_dst*ne0 + i0;
99
-
100
- dst[i + 0] = x[i + 0];
101
- dst[i + 1] = x[i + 1];
102
-
103
- return;
104
- }
105
-
106
95
  const int row_x = row_dst % ne1;
107
96
  const int channel_x = row_dst / ne1;
108
97
 
109
98
  const int idst = row_dst*ne0 + i0/2;
110
99
  const int ix = channel_x*s2 + row_x*s1 + i0/2;
111
100
 
101
+ if (i0 >= n_dims) {
102
+ dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
103
+ dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
104
+
105
+ return;
106
+ }
107
+
112
108
  const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
113
109
 
114
110
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -138,21 +134,19 @@ static __global__ void rope_multi(
138
134
 
139
135
  const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
140
136
 
141
- if (i0 >= n_dims) {
142
- const int i = row_dst*ne0 + i0;
143
-
144
- dst[i + 0] = x[i + 0];
145
- dst[i + 1] = x[i + 1];
146
-
147
- return;
148
- }
149
-
150
137
  const int row_x = row_dst % ne1;
151
138
  const int channel_x = row_dst / ne1;
152
139
 
153
140
  const int idst = row_dst*ne0 + i0/2;
154
141
  const int ix = channel_x*s2 + row_x*s1 + i0/2;
155
142
 
143
+ if (i0 >= n_dims) {
144
+ dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
145
+ dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
146
+
147
+ return;
148
+ }
149
+
156
150
  const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
157
151
  const int sec_w = sections.v[1] + sections.v[0];
158
152
  const int sector = (i0 / 2) % sect_dims;
@@ -1,18 +1,18 @@
1
1
  #include "scale.cuh"
2
2
 
3
// Element-wise affine transform: dst[i] = scale * x[i] + bias for 0 <= i < k.
// One thread handles one element. The flat index and the element count are
// 64-bit: with 32-bit `int`, blockDim.x*blockIdx.x overflows and tensors with
// more than INT_MAX elements would be indexed incorrectly.
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t k) {
    const int64_t i = (int64_t) blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = scale * x[i] + bias;
}
12
12
 
13
// Launches scale_f32 over k contiguous floats on `stream`, one thread per element.
// k and the block count are 64-bit, so large element counts are not truncated to
// int (a truncated count could become negative and yield an invalid launch).
// An empty input launches nothing, matching the set_rows launchers.
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t k, cudaStream_t stream) {
    if (k <= 0) {
        return;
    }

    const int64_t num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
    // The maximum gridDim.x is 2^31 - 1; fail loudly instead of with an opaque launch error.
    GGML_ASSERT(num_blocks <= INT32_MAX);

    scale_f32<<<(unsigned int) num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k);
}
17
17
 
18
18
  void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -25,7 +25,9 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
25
25
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
26
26
 
27
27
  float scale;
28
- memcpy(&scale, dst->op_params, sizeof(float));
28
+ float bias;
29
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
30
+ memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));
29
31
 
30
- scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
32
+ scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream);
31
33
  }
@@ -0,0 +1,275 @@
1
+ #include "set-rows.cuh"
2
+ #include "cpy-utils.cuh"
3
+
4
// Function-pointer type taking a raw source pointer and a raw destination pointer
// (currently not referenced elsewhere in this file).
using set_rows_kernel_t = void (*)(const char * src, char * dst);
5
+
6
// Stores one element: converts *src_elem from src_t into *dst_elem as dst_t by
// delegating to convert_flt (presumably provided by cpy-utils.cuh).
template <typename src_t, typename dst_t>
__device__ __forceinline__ void set_rows_1(const src_t * src_elem, dst_t * dst_elem) {
    convert_flt(src_elem, dst_elem);
}
10
+
11
// Quantizing set_rows kernel. Each thread handles one quantization block: qk
// consecutive floats of a src0 row, which quantize_func writes into the dst row
// selected by the int64 index tensor src1.
// Strides: s01..s03 are in floats, s10..s12 in int64 elements, and s1..s3 in
// bytes (the dst byte offset is divided by sizeof(block_type)).
template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
static __global__ void k_set_rows_quant(
        const float * __restrict__ src0, const int64_t * __restrict__ src1, block_type * __restrict__ dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const int64_t s01, const int64_t s02, const int64_t s03,
        const int64_t s10, const int64_t s11, const int64_t s12,
        const int64_t s1, const int64_t s2, const int64_t s3) {

    const int64_t blk = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
    const int64_t n_quant_blocks = (ne00 * ne01 * ne02 * ne03) / qk;

    if (blk >= n_quant_blocks) {
        return;
    }

    // Unravel the flat index of the block's first value into (i00, i01, i02, i03),
    // innermost dimension first.
    int64_t rest = blk * qk;
    const int64_t i00 = rest % ne00; rest /= ne00;
    const int64_t i01 = rest % ne01; rest /= ne01;
    const int64_t i02 = rest % ne02;
    const int64_t i03 = rest / ne02;

    // The index tensor's dims 1 and 2 are repeated (modulo) across src0's dims 2 and 3.
    const int64_t idx_row = src1[i01*s10 + (i02 % ne11)*s11 + (i03 % ne12)*s12];

    const float * src_block   = src0 + i01*s01 + i02*s02 + i03*s03 + i00;
    block_type  * dst_row_ptr = dst + (idx_row*s1 + i02*s2 + i03*s3) / sizeof(block_type);

    quantize_func(src_block, dst_row_ptr + i00 / qk);

    GGML_UNUSED(ne10);
    GGML_UNUSED(ne13);
}
51
+
52
// Host-side launcher for k_set_rows_quant. It converts the src0 byte strides to
// float strides and the src1 byte strides to int64 strides, and keeps the dst
// strides in bytes. It then launches one thread per qk-value quantization block
// on `stream`. Empty tensors launch nothing.
template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
static void set_rows_cuda_quant(
        const float * src0_d, const int64_t * src1_d, block_type * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t nb10, const size_t nb11, const size_t nb12,
        const size_t nb1, const size_t nb2, const size_t nb3,
        cudaStream_t stream) {

    GGML_ASSERT(ne00 % qk == 0);

    const int64_t n_quant_blocks = (ne00 * ne01 * ne02 * ne03) / qk;
    if (n_quant_blocks <= 0) {
        return;
    }

    const int  num_blocks = (n_quant_blocks + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
    const dim3 grid_size(num_blocks);
    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);

    k_set_rows_quant<block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
        src0_d, src1_d, dst_d,
        ne00, ne01, ne02, ne03,
        ne10, ne11, ne12, ne13,
        nb01/sizeof(float),   nb02/sizeof(float),   nb03/sizeof(float),
        nb10/sizeof(int64_t), nb11/sizeof(int64_t), nb12/sizeof(int64_t),
        nb1, nb2, nb3);
}
89
+
90
// Non-quantized set_rows kernel: one thread per src0 element. The element at
// (i00, i01, i02, i03) is converted from src_t to dst_t and stored in the dst row
// src1[i01, i02 % ne11, i03 % ne12], at the same i00 / i02 / i03 coordinates.
// All strides (s01..s03 for src0, s10..s12 for src1, s1..s3 for dst) are in
// elements of the respective pointer type.
template<typename src_t, typename dst_t>
static __global__ void k_set_rows(
        const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const int64_t s01, const int64_t s02, const int64_t s03,
        const int64_t s10, const int64_t s11, const int64_t s12,
        const int64_t s1, const int64_t s2, const int64_t s3) {

    const int64_t idx = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;

    if (idx >= ne00 * ne01 * ne02 * ne03) {
        return;
    }

    // Unravel the flat index into (i00, i01, i02, i03), innermost dimension first.
    int64_t rest = idx;
    const int64_t i00 = rest % ne00; rest /= ne00;
    const int64_t i01 = rest % ne01; rest /= ne01;
    const int64_t i02 = rest % ne02;
    const int64_t i03 = rest / ne02;

    // The index tensor's dims 1 and 2 are repeated (modulo) across src0's dims 2 and 3.
    const int64_t dst_row = src1[i01*s10 + (i02 % ne11)*s11 + (i03 % ne12)*s12];

    const src_t * src_elem = src0 + i01*s01 + i02*s02 + i03*s03 + i00;
    dst_t       * dst_elem = dst + dst_row*s1 + i02*s2 + i03*s3 + i00;

    set_rows_1(src_elem, dst_elem);

    GGML_UNUSED(ne10);
    GGML_UNUSED(ne13);
}
127
+
128
// Host-side launcher for k_set_rows. It turns the byte strides nb* into element
// strides for src0 (src_t), src1 (int64_t) and dst (dst_t), then launches one
// thread per src0 element on `stream`. Empty tensors launch nothing.
template<typename src_t, typename dst_t>
static void set_rows_cuda(
        const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t nb10, const size_t nb11, const size_t nb12,
        const size_t nb1, const size_t nb2, const size_t nb3,
        cudaStream_t stream) {

    const int64_t n_elements = ne00 * ne01 * ne02 * ne03;
    if (n_elements <= 0) {
        return;
    }

    const int  num_blocks = (n_elements + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
    const dim3 grid_size(num_blocks);
    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);

    k_set_rows<<<grid_size, block_size, 0, stream>>>(
        src0_d, src1_d, dst_d,
        ne00, ne01, ne02, ne03,
        ne10, ne11, ne12, ne13,
        nb01/sizeof(src_t),   nb02/sizeof(src_t),   nb03/sizeof(src_t),
        nb10/sizeof(int64_t), nb11/sizeof(int64_t), nb12/sizeof(int64_t),
        nb1/sizeof(dst_t),    nb2/sizeof(dst_t),    nb3/sizeof(dst_t));
}
164
+
165
+
166
// CUDA entry point for set_rows. The F32 rows of dst->src[0] are scattered into
// dst at the I64 row indices stored in dst->src[1], and are converted (F32, F16,
// BF16) or quantized (Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL) to dst->type.
// Any other destination type aborts.
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_I64);

    GGML_TENSOR_BINARY_OP_LOCALS

    const float   * src0_d  = static_cast<const float *>(src0->data);
    const int64_t * src1_d  = static_cast<const int64_t *>(src1->data);
    void          * dst_raw = dst->data;
    cudaStream_t    stream  = ctx.stream();

    switch (dst->type) {
        case GGML_TYPE_F32:
            set_rows_cuda(src0_d, src1_d, static_cast<float *>(dst_raw),
                ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13,
                nb01, nb02, nb03, nb10, nb11, nb12, nb1, nb2, nb3, stream);
            break;
        case GGML_TYPE_F16:
            set_rows_cuda(src0_d, src1_d, static_cast<half *>(dst_raw),
                ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13,
                nb01, nb02, nb03, nb10, nb11, nb12, nb1, nb2, nb3, stream);
            break;
        case GGML_TYPE_BF16:
            set_rows_cuda(src0_d, src1_d, static_cast<nv_bfloat16 *>(dst_raw),
                ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13,
                nb01, nb02, nb03, nb10, nb11, nb12, nb1, nb2, nb3, stream);
            break;
        case GGML_TYPE_Q4_0:
            set_rows_cuda_quant<block_q4_0, QK4_0, quantize_f32_q4_0_block>(
                src0_d, src1_d, static_cast<block_q4_0 *>(dst_raw),
                ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13,
                nb01, nb02, nb03, nb10, nb11, nb12, nb1, nb2, nb3, stream);
            break;
        case GGML_TYPE_Q4_1:
            set_rows_cuda_quant<block_q4_1, QK4_1, quantize_f32_q4_1_block>(
                src0_d, src1_d, static_cast<block_q4_1 *>(dst_raw),
                ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13,
                nb01, nb02, nb03, nb10, nb11, nb12, nb1, nb2, nb3, stream);
            break;
        case GGML_TYPE_Q5_0:
            set_rows_cuda_quant<block_q5_0, QK5_0, quantize_f32_q5_0_block>(
                src0_d, src1_d, static_cast<block_q5_0 *>(dst_raw),
                ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13,
                nb01, nb02, nb03, nb10, nb11, nb12, nb1, nb2, nb3, stream);
            break;
        case GGML_TYPE_Q5_1:
            set_rows_cuda_quant<block_q5_1, QK5_1, quantize_f32_q5_1_block>(
                src0_d, src1_d, static_cast<block_q5_1 *>(dst_raw),
                ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13,
                nb01, nb02, nb03, nb10, nb11, nb12, nb1, nb2, nb3, stream);
            break;
        case GGML_TYPE_Q8_0:
            set_rows_cuda_quant<block_q8_0, QK8_0, quantize_f32_q8_0_block>(
                src0_d, src1_d, static_cast<block_q8_0 *>(dst_raw),
                ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13,
                nb01, nb02, nb03, nb10, nb11, nb12, nb1, nb2, nb3, stream);
            break;
        case GGML_TYPE_IQ4_NL:
            set_rows_cuda_quant<block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
                src0_d, src1_d, static_cast<block_iq4_nl *>(dst_raw),
                ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13,
                nb01, nb02, nb03, nb10, nb11, nb12, nb1, nb2, nb3, stream);
            break;
        default:
            GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
    }
}
@@ -0,0 +1,7 @@
1
+ #pragma once
2
+
3
+ #include "common.cuh"
4
+
5
// Threads per CUDA block used when launching the set_rows kernels.
#define CUDA_SET_ROWS_BLOCK_SIZE 256

// Computes the set_rows node `dst` on the CUDA backend. Rows of dst->src[0] (F32)
// are written into dst at the row indices held in dst->src[1] (I64), and are
// converted or quantized to dst->type.
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);