@novastera-oss/llamarn 0.2.9 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247) hide show
  1. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  5. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  9. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  12. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  17. package/cpp/build-info.cpp +2 -2
  18. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  19. package/cpp/llama.cpp/README.md +4 -5
  20. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  21. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  22. package/cpp/llama.cpp/common/arg.cpp +17 -0
  23. package/cpp/llama.cpp/common/chat.cpp +37 -20
  24. package/cpp/llama.cpp/common/chat.h +2 -0
  25. package/cpp/llama.cpp/common/common.h +4 -0
  26. package/cpp/llama.cpp/convert_hf_to_gguf.py +745 -6
  27. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  28. package/cpp/llama.cpp/ggml/CMakeLists.txt +7 -2
  29. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  30. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  31. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +0 -1
  32. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  33. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  34. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  35. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1203 -163
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +33 -9
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  43. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +17 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  47. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +8 -6
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +185 -79
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  66. package/cpp/llama.cpp/ggml/src/ggml-impl.h +64 -0
  67. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  68. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +35 -9
  69. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +167 -39
  70. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +254 -57
  71. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +505 -40
  73. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  83. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  84. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +60 -9
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +711 -292
  92. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  93. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  94. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  95. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  105. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  106. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  112. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  113. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  114. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  116. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  117. package/cpp/llama.cpp/ggml/src/ggml.c +382 -61
  118. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  119. package/cpp/llama.cpp/gguf-py/gguf/constants.py +209 -0
  120. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  121. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +73 -21
  122. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  123. package/cpp/llama.cpp/include/llama.h +0 -40
  124. package/cpp/llama.cpp/src/llama-arch.cpp +210 -3
  125. package/cpp/llama.cpp/src/llama-arch.h +18 -1
  126. package/cpp/llama.cpp/src/llama-batch.cpp +27 -1
  127. package/cpp/llama.cpp/src/llama-batch.h +8 -1
  128. package/cpp/llama.cpp/src/llama-chat.cpp +15 -0
  129. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  130. package/cpp/llama.cpp/src/llama-graph.cpp +119 -184
  131. package/cpp/llama.cpp/src/llama-graph.h +47 -60
  132. package/cpp/llama.cpp/src/llama-hparams.cpp +7 -1
  133. package/cpp/llama.cpp/src/llama-hparams.h +3 -0
  134. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +28 -18
  135. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +4 -2
  136. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +214 -65
  137. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +62 -24
  138. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  139. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +9 -4
  140. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  141. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +20 -10
  142. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  143. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  144. package/cpp/llama.cpp/src/llama-model.cpp +2530 -685
  145. package/cpp/llama.cpp/src/llama-model.h +18 -0
  146. package/cpp/llama.cpp/src/llama-quant.cpp +1 -0
  147. package/cpp/llama.cpp/src/llama-vocab.cpp +13 -2
  148. package/cpp/llama.cpp/src/llama-vocab.h +41 -0
  149. package/ios/include/chat.h +2 -0
  150. package/ios/include/common.h +4 -0
  151. package/ios/include/llama.h +0 -40
  152. package/ios/libs/llama.xcframework/Info.plist +19 -19
  153. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  154. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5055 -4886
  155. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  156. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  157. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +0 -40
  158. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  159. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  160. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  161. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
  162. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  163. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  164. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  165. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  166. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  167. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  168. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3766
  169. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  170. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  171. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +0 -40
  172. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  173. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  174. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +0 -40
  175. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  176. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  177. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  178. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +0 -40
  179. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  180. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  181. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  182. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4890
  183. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  184. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  185. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +0 -40
  186. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  187. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  188. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  189. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
  190. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  191. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  192. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  193. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  194. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  195. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5091 -4922
  196. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  197. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  198. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +0 -40
  199. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  200. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  201. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4897
  202. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3794
  203. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  204. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  205. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  206. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  207. package/package.json +1 -1
  208. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  209. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  210. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  211. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  212. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  213. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  214. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  215. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  216. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  217. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  218. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  219. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  220. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  221. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  222. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  223. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  224. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  225. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  226. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  227. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  228. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  229. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  230. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  231. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  232. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  233. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  234. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  235. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  236. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  237. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  238. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  239. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  240. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  241. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  242. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  243. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  244. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  245. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  246. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  247. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -109,6 +109,7 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
109
109
  }
110
110
 
111
111
  void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
112
+ #pragma METAL fp math_mode(safe)
112
113
  float amax = 0.0f; // absolute max
113
114
  float max = 0.0f;
114
115
 
@@ -138,6 +139,7 @@ void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
138
139
  }
139
140
 
140
141
  void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
142
+ #pragma METAL fp math_mode(safe)
141
143
  float min = FLT_MAX;
142
144
  float max = -FLT_MAX;
143
145
 
@@ -166,6 +168,7 @@ void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
166
168
  }
167
169
 
168
170
  void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
171
+ #pragma METAL fp math_mode(safe)
169
172
  float amax = 0.0f; // absolute max
170
173
  float max = 0.0f;
171
174
 
@@ -203,6 +206,7 @@ void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
203
206
  }
204
207
 
205
208
  void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
209
+ #pragma METAL fp math_mode(safe)
206
210
  float max = src[0];
207
211
  float min = src[0];
208
212
 
@@ -239,6 +243,7 @@ void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
239
243
  }
240
244
 
241
245
  void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
246
+ #pragma METAL fp math_mode(safe)
242
247
  float amax = 0.0f; // absolute max
243
248
  float max = 0.0f;
244
249
 
@@ -458,6 +463,7 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
458
463
  }
459
464
 
460
465
  void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
466
+ #pragma METAL fp math_mode(safe)
461
467
  float amax = 0.0f; // absolute max
462
468
 
463
469
  for (int j = 0; j < QK8_0; j++) {
@@ -1008,16 +1014,18 @@ kernel void kernel_scale(
1008
1014
  device const float * src0,
1009
1015
  device float * dst,
1010
1016
  constant float & scale,
1017
+ constant float & bias,
1011
1018
  uint tpig[[thread_position_in_grid]]) {
1012
- dst[tpig] = src0[tpig] * scale;
1019
+ dst[tpig] = src0[tpig] * scale + bias;
1013
1020
  }
1014
1021
 
1015
1022
  kernel void kernel_scale_4(
1016
1023
  device const float4 * src0,
1017
1024
  device float4 * dst,
1018
1025
  constant float & scale,
1026
+ constant float & bias,
1019
1027
  uint tpig[[thread_position_in_grid]]) {
1020
- dst[tpig] = src0[tpig] * scale;
1028
+ dst[tpig] = src0[tpig] * scale + bias;
1021
1029
  }
1022
1030
 
1023
1031
  kernel void kernel_clamp(
@@ -1191,6 +1199,114 @@ kernel void kernel_neg(
1191
1199
  dst[tpig] = -src0[tpig];
1192
1200
  }
1193
1201
 
1202
+ kernel void kernel_reglu(
1203
+ device const char * src0,
1204
+ device const char * src1,
1205
+ device char * dst,
1206
+ constant ggml_metal_kargs_glu & args,
1207
+ uint tgpig[[threadgroup_position_in_grid]],
1208
+ uint tpitg[[thread_position_in_threadgroup]],
1209
+ uint ntg[[threads_per_threadgroup]]) {
1210
+ device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1211
+ device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1212
+ device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1213
+
1214
+ for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1215
+ const float x0 = src0_row[i0];
1216
+ const float x1 = src1_row[i0];
1217
+
1218
+ dst_row[i0] = x0*x1*(x0 > 0.0f);
1219
+ }
1220
+ }
1221
+
1222
+ kernel void kernel_geglu(
1223
+ device const char * src0,
1224
+ device const char * src1,
1225
+ device char * dst,
1226
+ constant ggml_metal_kargs_glu & args,
1227
+ uint tgpig[[threadgroup_position_in_grid]],
1228
+ uint tpitg[[thread_position_in_threadgroup]],
1229
+ uint ntg[[threads_per_threadgroup]]) {
1230
+ device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1231
+ device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1232
+ device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1233
+
1234
+ for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1235
+ const float x0 = src0_row[i0];
1236
+ const float x1 = src1_row[i0];
1237
+
1238
+ const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
1239
+
1240
+ dst_row[i0] = gelu*x1;
1241
+ }
1242
+ }
1243
+
1244
+ kernel void kernel_swiglu(
1245
+ device const char * src0,
1246
+ device const char * src1,
1247
+ device char * dst,
1248
+ constant ggml_metal_kargs_glu & args,
1249
+ uint tgpig[[threadgroup_position_in_grid]],
1250
+ uint tpitg[[thread_position_in_threadgroup]],
1251
+ uint ntg[[threads_per_threadgroup]]) {
1252
+ device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1253
+ device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1254
+ device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1255
+
1256
+ for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1257
+ const float x0 = src0_row[i0];
1258
+ const float x1 = src1_row[i0];
1259
+
1260
+ const float silu = x0 / (1.0f + exp(-x0));
1261
+
1262
+ dst_row[i0] = silu*x1;
1263
+ }
1264
+ }
1265
+
1266
+ kernel void kernel_geglu_erf(
1267
+ device const char * src0,
1268
+ device const char * src1,
1269
+ device char * dst,
1270
+ constant ggml_metal_kargs_glu & args,
1271
+ uint tgpig[[threadgroup_position_in_grid]],
1272
+ uint tpitg[[thread_position_in_threadgroup]],
1273
+ uint ntg[[threads_per_threadgroup]]) {
1274
+ device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1275
+ device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1276
+ device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1277
+
1278
+ for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1279
+ const float x0 = src0_row[i0];
1280
+ const float x1 = src1_row[i0];
1281
+
1282
+ const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
1283
+
1284
+ dst_row[i0] = gelu_erf*x1;
1285
+ }
1286
+ }
1287
+
1288
+ kernel void kernel_geglu_quick(
1289
+ device const char * src0,
1290
+ device const char * src1,
1291
+ device char * dst,
1292
+ constant ggml_metal_kargs_glu & args,
1293
+ uint tgpig[[threadgroup_position_in_grid]],
1294
+ uint tpitg[[thread_position_in_threadgroup]],
1295
+ uint ntg[[threads_per_threadgroup]]) {
1296
+ device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1297
+ device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1298
+ device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1299
+
1300
+ for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1301
+ const float x0 = src0_row[i0];
1302
+ const float x1 = src1_row[i0];
1303
+
1304
+ const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
1305
+
1306
+ dst_row[i0] = gelu_quick*x1;
1307
+ }
1308
+ }
1309
+
1194
1310
  template <bool norm>
1195
1311
  kernel void kernel_sum_rows(
1196
1312
  constant ggml_metal_kargs_sum_rows & args,
@@ -1253,24 +1369,28 @@ kernel void kernel_soft_max(
1253
1369
  device char * dst,
1254
1370
  constant ggml_metal_kargs_soft_max & args,
1255
1371
  threadgroup float * buf [[threadgroup(0)]],
1256
- uint tgpig[[threadgroup_position_in_grid]],
1257
- uint tpitg[[thread_position_in_threadgroup]],
1372
+ uint3 tgpig[[threadgroup_position_in_grid]],
1373
+ uint3 tpitg[[thread_position_in_threadgroup]],
1258
1374
  uint sgitg[[simdgroup_index_in_threadgroup]],
1259
1375
  uint tiisg[[thread_index_in_simdgroup]],
1260
- uint ntg[[threads_per_threadgroup]]) {
1261
- const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
1262
- const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
1263
- const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
1376
+ uint3 tptg[[threads_per_threadgroup]]) {
1377
+ const int32_t i03 = tgpig.z;
1378
+ const int32_t i02 = tgpig.y;
1379
+ const int32_t i01 = tgpig.x;
1380
+
1381
+ const int32_t i13 = i03%args.ne13;
1382
+ const int32_t i12 = i02%args.ne12;
1383
+ const int32_t i11 = i01;
1264
1384
 
1265
- device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
1266
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr;
1267
- device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
1385
+ device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
1386
+ device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
1387
+ device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
1268
1388
 
1269
1389
  float slope = 1.0f;
1270
1390
 
1271
1391
  // ALiBi
1272
1392
  if (args.max_bias > 0.0f) {
1273
- const int64_t h = i02;
1393
+ const int32_t h = i02;
1274
1394
 
1275
1395
  const float base = h < args.n_head_log2 ? args.m0 : args.m1;
1276
1396
  const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
@@ -1281,13 +1401,13 @@ kernel void kernel_soft_max(
1281
1401
  // parallel max
1282
1402
  float lmax = -INFINITY;
1283
1403
 
1284
- for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1404
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
1285
1405
  lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
1286
1406
  }
1287
1407
 
1288
1408
  // find the max value in the block
1289
1409
  float max_val = simd_max(lmax);
1290
- if (ntg > N_SIMDWIDTH) {
1410
+ if (tptg.x > N_SIMDWIDTH) {
1291
1411
  if (sgitg == 0) {
1292
1412
  buf[tiisg] = -INFINITY;
1293
1413
  }
@@ -1306,7 +1426,7 @@ kernel void kernel_soft_max(
1306
1426
 
1307
1427
  // parallel sum
1308
1428
  float lsum = 0.0f;
1309
- for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1429
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
1310
1430
  const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
1311
1431
  lsum += exp_psrc0;
1312
1432
  pdst[i00] = exp_psrc0;
@@ -1318,7 +1438,7 @@ kernel void kernel_soft_max(
1318
1438
 
1319
1439
  float sum = simd_sum(lsum);
1320
1440
 
1321
- if (ntg > N_SIMDWIDTH) {
1441
+ if (tptg.x > N_SIMDWIDTH) {
1322
1442
  if (sgitg == 0) {
1323
1443
  buf[tiisg] = 0.0f;
1324
1444
  }
@@ -1337,7 +1457,7 @@ kernel void kernel_soft_max(
1337
1457
 
1338
1458
  const float inv_sum = 1.0f/sum;
1339
1459
 
1340
- for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1460
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
1341
1461
  pdst[i00] *= inv_sum;
1342
1462
  }
1343
1463
  }
@@ -1349,23 +1469,27 @@ kernel void kernel_soft_max_4(
1349
1469
  device char * dst,
1350
1470
  constant ggml_metal_kargs_soft_max & args,
1351
1471
  threadgroup float * buf [[threadgroup(0)]],
1352
- uint tgpig[[threadgroup_position_in_grid]],
1353
- uint tpitg[[thread_position_in_threadgroup]],
1472
+ uint3 tgpig[[threadgroup_position_in_grid]],
1473
+ uint3 tpitg[[thread_position_in_threadgroup]],
1354
1474
  uint sgitg[[simdgroup_index_in_threadgroup]],
1355
1475
  uint tiisg[[thread_index_in_simdgroup]],
1356
- uint ntg[[threads_per_threadgroup]]) {
1357
- const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
1358
- const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
1359
- const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
1476
+ uint3 tptg[[threads_per_threadgroup]]) {
1477
+ const int32_t i03 = tgpig.z;
1478
+ const int32_t i02 = tgpig.y;
1479
+ const int32_t i01 = tgpig.x;
1360
1480
 
1361
- device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
1362
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr;
1363
- device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
1481
+ const int32_t i13 = i03%args.ne13;
1482
+ const int32_t i12 = i02%args.ne12;
1483
+ const int32_t i11 = i01;
1484
+
1485
+ device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
1486
+ device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
1487
+ device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
1364
1488
 
1365
1489
  float slope = 1.0f;
1366
1490
 
1367
1491
  if (args.max_bias > 0.0f) {
1368
- const int64_t h = i02;
1492
+ const int32_t h = i02;
1369
1493
 
1370
1494
  const float base = h < args.n_head_log2 ? args.m0 : args.m1;
1371
1495
  const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
@@ -1376,14 +1500,14 @@ kernel void kernel_soft_max_4(
1376
1500
  // parallel max
1377
1501
  float4 lmax4 = -INFINITY;
1378
1502
 
1379
- for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1503
+ for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
1380
1504
  lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
1381
1505
  }
1382
1506
 
1383
1507
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
1384
1508
 
1385
1509
  float max_val = simd_max(lmax);
1386
- if (ntg > N_SIMDWIDTH) {
1510
+ if (tptg.x > N_SIMDWIDTH) {
1387
1511
  if (sgitg == 0) {
1388
1512
  buf[tiisg] = -INFINITY;
1389
1513
  }
@@ -1402,7 +1526,7 @@ kernel void kernel_soft_max_4(
1402
1526
 
1403
1527
  // parallel sum
1404
1528
  float4 lsum4 = 0.0f;
1405
- for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1529
+ for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
1406
1530
  const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
1407
1531
  lsum4 += exp_psrc4;
1408
1532
  pdst4[i00] = exp_psrc4;
@@ -1416,7 +1540,7 @@ kernel void kernel_soft_max_4(
1416
1540
 
1417
1541
  float sum = simd_sum(lsum);
1418
1542
 
1419
- if (ntg > N_SIMDWIDTH) {
1543
+ if (tptg.x > N_SIMDWIDTH) {
1420
1544
  if (sgitg == 0) {
1421
1545
  buf[tiisg] = 0.0f;
1422
1546
  }
@@ -1435,7 +1559,7 @@ kernel void kernel_soft_max_4(
1435
1559
 
1436
1560
  const float inv_sum = 1.0f/sum;
1437
1561
 
1438
- for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1562
+ for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
1439
1563
  pdst4[i00] *= inv_sum;
1440
1564
  }
1441
1565
  }
@@ -1521,7 +1645,7 @@ kernel void kernel_ssm_conv_f32(
1521
1645
  x[0] = sumf;
1522
1646
  }
1523
1647
 
1524
- // ref: ggml.c:ggml_compute_forward_ssm_scan_f32
1648
+ // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
1525
1649
  kernel void kernel_ssm_scan_f32(
1526
1650
  device const void * src0,
1527
1651
  device const void * src1,
@@ -1529,46 +1653,119 @@ kernel void kernel_ssm_scan_f32(
1529
1653
  device const void * src3,
1530
1654
  device const void * src4,
1531
1655
  device const void * src5,
1656
+ device const void * src6,
1532
1657
  device float * dst,
1533
1658
  constant ggml_metal_kargs_ssm_scan & args,
1534
1659
  uint3 tgpig[[threadgroup_position_in_grid]],
1535
1660
  uint3 tpitg[[thread_position_in_threadgroup]],
1536
1661
  uint3 ntg[[threads_per_threadgroup]]) {
1537
- const int64_t ir = tgpig.x;
1538
- const int64_t i3 = tgpig.y;
1662
+ const int64_t i1 = 0;
1663
+ const int64_t ir = tgpig.x; // current head
1664
+ const int64_t i3 = tgpig.y; // current seq
1665
+
1666
+ const uint64_t nb00 = sizeof(float);
1667
+ const uint64_t nb10 = sizeof(float);
1668
+ const uint64_t nb20 = sizeof(float);
1539
1669
 
1540
1670
  const int64_t nc = args.d_state;
1541
- // const int64_t nr = args.d_inner;
1671
+ const int64_t nr = args.d_inner;
1672
+ const int64_t nh = args.n_head;
1673
+ const int64_t ng = args.n_group;
1542
1674
  const int64_t n_t = args.n_seq_tokens;
1543
- // const int64_t n_s = args.n_seqs;
1675
+
1676
+ const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
1677
+
1678
+ device const int32_t * ids = (device const int32_t *) src6;
1679
+
1680
+ device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
1681
+ device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
1544
1682
 
1545
1683
  for (int64_t i2 = 0; i2 < n_t; ++i2) {
1546
- device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02);
1547
- device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12);
1548
- device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22);
1549
- device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
1550
- device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42);
1551
- device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52);
1552
- device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
1553
- device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13);
1554
-
1555
- if (i2 > 0) {
1556
- s0 = s;
1557
- }
1558
-
1559
- // i1 == 0
1560
- float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
1561
- float x_dt = x[0] * dt_soft_plus;
1684
+ device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
1685
+ device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
1686
+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
1687
+ device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
1688
+ device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
1689
+ device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
1690
+
1691
+ const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
1692
+ const float x_dt = x[0] * dt_soft_plus;
1562
1693
  float sumf = 0.0f;
1563
1694
 
1564
1695
  for (int64_t i0 = 0; i0 < nc; ++i0) {
1565
- int64_t i = i0;
1566
- float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
1696
+ const int64_t i = i0 + i1*nc;
1697
+ const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
1567
1698
  sumf += state * C[i0];
1568
1699
  s[i] = state;
1569
1700
  }
1570
1701
 
1571
1702
  y[0] = sumf;
1703
+
1704
+ // recurse
1705
+ s0 = s;
1706
+ }
1707
+ }
1708
+
1709
+ // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
1710
+ // TODO: optimize (e.g. by parallelizing over d_state)
1711
+ kernel void kernel_ssm_scan_f32_group(
1712
+ device const void * src0,
1713
+ device const void * src1,
1714
+ device const void * src2,
1715
+ device const void * src3,
1716
+ device const void * src4,
1717
+ device const void * src5,
1718
+ device const void * src6,
1719
+ device float * dst,
1720
+ constant ggml_metal_kargs_ssm_scan & args,
1721
+ uint3 tgpig[[threadgroup_position_in_grid]],
1722
+ uint3 tpitg[[thread_position_in_threadgroup]],
1723
+ uint3 ntg[[threads_per_threadgroup]]) {
1724
+ const int64_t i1 = tgpig.x;
1725
+ const int64_t ir = tgpig.y; // current head
1726
+ const int64_t i3 = tgpig.z; // current seq
1727
+
1728
+ const uint64_t nb00 = sizeof(float);
1729
+ const uint64_t nb10 = sizeof(float);
1730
+ const uint64_t nb20 = sizeof(float);
1731
+
1732
+ const int64_t nc = args.d_state;
1733
+ const int64_t nr = args.d_inner;
1734
+ const int64_t nh = args.n_head;
1735
+ const int64_t ng = args.n_group;
1736
+ const int64_t n_t = args.n_seq_tokens;
1737
+
1738
+ const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
1739
+
1740
+ device const int32_t * ids = (device const int32_t *) src6;
1741
+
1742
+ device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
1743
+ device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
1744
+
1745
+ for (int64_t i2 = 0; i2 < n_t; ++i2) {
1746
+ device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
1747
+ device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
1748
+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
1749
+ device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
1750
+ device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
1751
+ device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
1752
+
1753
+ const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
1754
+ const float x_dt = x[0] * dt_soft_plus;
1755
+ const float dA = exp(dt_soft_plus * A[0]);
1756
+ float sumf = 0.0f;
1757
+
1758
+ for (int64_t i0 = 0; i0 < nc; ++i0) {
1759
+ const int64_t i = i0 + i1*nc;
1760
+ const float state = (s0[i] * dA) + (B[i0] * x_dt);
1761
+ sumf += state * C[i0];
1762
+ s[i] = state;
1763
+ }
1764
+
1765
+ y[0] = sumf;
1766
+
1767
+ // recurse
1768
+ s0 = s;
1572
1769
  }
1573
1770
  }
1574
1771
 
@@ -3709,7 +3906,7 @@ kernel void kernel_flash_attn_ext(
3709
3906
  // load the mask in shared memory
3710
3907
  #pragma unroll(Q)
3711
3908
  for (short j = 0; j < Q; ++j) {
3712
- device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
3909
+ device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
3713
3910
 
3714
3911
  const float m = pm[ic + tiisg];
3715
3912
 
@@ -4195,7 +4392,7 @@ kernel void kernel_flash_attn_ext_vec(
4195
4392
  const bool has_mask = mask != q;
4196
4393
 
4197
4394
  // pointer to the mask
4198
- device const half * pm = (device const half *) (mask + iq1*args.nb31);
4395
+ device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
4199
4396
 
4200
4397
  float slope = 1.0f;
4201
4398
 
@@ -65,6 +65,7 @@ set(GGML_OPENCL_KERNELS
65
65
  gemv_noshuffle_general
66
66
  gemv_noshuffle
67
67
  get_rows
68
+ glu
68
69
  group_norm
69
70
  im2col_f32
70
71
  im2col_f16
@@ -87,6 +88,7 @@ set(GGML_OPENCL_KERNELS
87
88
  rms_norm
88
89
  rope
89
90
  scale
91
+ set_rows
90
92
  sigmoid
91
93
  silu
92
94
  softmax_4_f32
@@ -102,6 +104,7 @@ set(GGML_OPENCL_KERNELS
102
104
  tanh
103
105
  pad
104
106
  repeat
107
+ mul_mat_f16_f32
105
108
  )
106
109
 
107
110
  foreach (K ${GGML_OPENCL_KERNELS})