@novastera-oss/llamarn 0.2.7 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (319)
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  13. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  18. package/cpp/LlamaCppModel.cpp +56 -22
  19. package/cpp/build-info.cpp +2 -2
  20. package/cpp/llama.cpp/CMakeLists.txt +1 -2
  21. package/cpp/llama.cpp/README.md +4 -5
  22. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  23. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  24. package/cpp/llama.cpp/common/arg.cpp +24 -0
  25. package/cpp/llama.cpp/common/chat.cpp +37 -20
  26. package/cpp/llama.cpp/common/chat.h +2 -0
  27. package/cpp/llama.cpp/common/common.cpp +3 -0
  28. package/cpp/llama.cpp/common/common.h +5 -0
  29. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  30. package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
  31. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  32. package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
  33. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  34. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  35. package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
  36. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  100. package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
  101. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
  103. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
  104. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
  105. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  133. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  134. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  135. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  136. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  137. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
  138. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  141. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  142. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  143. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  144. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  168. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  169. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  170. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  172. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  173. package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
  177. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
  178. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
  179. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  180. package/cpp/llama.cpp/include/llama.h +8 -43
  181. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  182. package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
  183. package/cpp/llama.cpp/src/llama-arch.h +36 -1
  184. package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
  185. package/cpp/llama.cpp/src/llama-batch.h +105 -70
  186. package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
  187. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  188. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  189. package/cpp/llama.cpp/src/llama-context.h +13 -13
  190. package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
  191. package/cpp/llama.cpp/src/llama-graph.h +78 -79
  192. package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
  193. package/cpp/llama.cpp/src/llama-hparams.h +11 -0
  194. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
  195. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  196. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
  197. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
  198. package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
  199. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
  200. package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
  201. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
  202. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  203. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  204. package/cpp/llama.cpp/src/llama-memory.h +21 -22
  205. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  206. package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
  207. package/cpp/llama.cpp/src/llama-model.h +40 -0
  208. package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
  209. package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
  210. package/cpp/llama.cpp/src/llama-vocab.h +42 -0
  211. package/cpp/rn-utils.h +3 -0
  212. package/ios/include/chat.h +2 -0
  213. package/ios/include/common.h +5 -0
  214. package/ios/include/llama.h +8 -43
  215. package/ios/libs/llama.xcframework/Info.plist +19 -19
  216. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  217. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  218. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  219. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  220. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
  221. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  223. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  224. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  225. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  231. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  232. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  248. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  250. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
  254. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
  255. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  261. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  262. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  263. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  264. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  265. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  267. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  268. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
  269. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
  270. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
  274. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  275. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  276. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  277. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  278. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  279. package/package.json +1 -1
  280. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  315. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  316. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  317. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  318. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  319. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -35,6 +35,17 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
35
35
  -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
36
36
  };
37
37
 
38
+ static inline int best_index_int8(int n, constant float * val, float x) {
39
+ if (x <= val[0]) return 0;
40
+ if (x >= val[n-1]) return n-1;
41
+ int ml = 0, mu = n-1;
42
+ while (mu-ml > 1) {
43
+ int mav = (ml+mu)/2;
44
+ if (x < val[mav]) mu = mav; else ml = mav;
45
+ }
46
+ return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
47
+ }
48
+
38
49
  // NOTE: this is not dequantizing - we are simply fitting the template
39
50
  template <typename type4x4>
40
51
  void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -97,6 +108,178 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
97
108
  }
98
109
  }
99
110
 
111
+ void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
112
+ #pragma METAL fp math_mode(safe)
113
+ float amax = 0.0f; // absolute max
114
+ float max = 0.0f;
115
+
116
+ for (int j = 0; j < QK4_0; j++) {
117
+ const float v = src[j];
118
+ if (amax < fabs(v)) {
119
+ amax = fabs(v);
120
+ max = v;
121
+ }
122
+ }
123
+
124
+ const float d = max / -8;
125
+ const float id = d ? 1.0f/d : 0.0f;
126
+
127
+ dst.d = d;
128
+
129
+ for (int j = 0; j < QK4_0/2; ++j) {
130
+ const float x0 = src[0 + j]*id;
131
+ const float x1 = src[QK4_0/2 + j]*id;
132
+
133
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
134
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
135
+
136
+ dst.qs[j] = xi0;
137
+ dst.qs[j] |= xi1 << 4;
138
+ }
139
+ }
140
+
141
+ void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
142
+ #pragma METAL fp math_mode(safe)
143
+ float min = FLT_MAX;
144
+ float max = -FLT_MAX;
145
+
146
+ for (int j = 0; j < QK4_1; j++) {
147
+ const float v = src[j];
148
+ if (min > v) min = v;
149
+ if (max < v) max = v;
150
+ }
151
+
152
+ const float d = (max - min) / ((1 << 4) - 1);
153
+ const float id = d ? 1.0f/d : 0.0f;
154
+
155
+ dst.d = d;
156
+ dst.m = min;
157
+
158
+ for (int j = 0; j < QK4_1/2; ++j) {
159
+ const float x0 = (src[0 + j] - min)*id;
160
+ const float x1 = (src[QK4_1/2 + j] - min)*id;
161
+
162
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
163
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
164
+
165
+ dst.qs[j] = xi0;
166
+ dst.qs[j] |= xi1 << 4;
167
+ }
168
+ }
169
+
170
+ void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
171
+ #pragma METAL fp math_mode(safe)
172
+ float amax = 0.0f; // absolute max
173
+ float max = 0.0f;
174
+
175
+ for (int j = 0; j < QK5_0; j++) {
176
+ const float v = src[j];
177
+ if (amax < fabs(v)) {
178
+ amax = fabs(v);
179
+ max = v;
180
+ }
181
+ }
182
+
183
+ const float d = max / -16;
184
+ const float id = d ? 1.0f/d : 0.0f;
185
+
186
+ dst.d = d;
187
+
188
+ uint32_t qh = 0;
189
+ for (int j = 0; j < QK5_0/2; ++j) {
190
+ const float x0 = src[0 + j]*id;
191
+ const float x1 = src[QK5_0/2 + j]*id;
192
+
193
+ const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
194
+ const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
195
+
196
+ dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
197
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
198
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
199
+ }
200
+
201
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
202
+
203
+ for (int j = 0; j < 4; ++j) {
204
+ dst.qh[j] = qh8[j];
205
+ }
206
+ }
207
+
208
+ void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
209
+ #pragma METAL fp math_mode(safe)
210
+ float max = src[0];
211
+ float min = src[0];
212
+
213
+ for (int j = 1; j < QK5_1; j++) {
214
+ const float v = src[j];
215
+ min = v < min ? v : min;
216
+ max = v > max ? v : max;
217
+ }
218
+
219
+ const float d = (max - min) / 31;
220
+ const float id = d ? 1.0f/d : 0.0f;
221
+
222
+ dst.d = d;
223
+ dst.m = min;
224
+
225
+ uint32_t qh = 0;
226
+ for (int j = 0; j < QK5_1/2; ++j) {
227
+ const float x0 = (src[0 + j] - min)*id;
228
+ const float x1 = (src[QK5_1/2 + j] - min)*id;
229
+
230
+ const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
231
+ const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
232
+
233
+ dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
234
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
235
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
236
+ }
237
+
238
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
239
+
240
+ for (int j = 0; j < 4; ++j) {
241
+ dst.qh[j] = qh8[j];
242
+ }
243
+ }
244
+
245
+ void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
246
+ #pragma METAL fp math_mode(safe)
247
+ float amax = 0.0f; // absolute max
248
+ float max = 0.0f;
249
+
250
+ for (int j = 0; j < QK4_NL; j++) {
251
+ const float v = src[j];
252
+ if (amax < fabs(v)) {
253
+ amax = fabs(v);
254
+ max = v;
255
+ }
256
+ }
257
+
258
+ const float d = max / kvalues_iq4nl_f[0];
259
+ const float id = d ? 1.0f/d : 0.0f;
260
+
261
+ float sumqx = 0, sumq2 = 0;
262
+ for (int j = 0; j < QK4_NL/2; ++j) {
263
+ const float x0 = src[0 + j]*id;
264
+ const float x1 = src[QK4_NL/2 + j]*id;
265
+
266
+ const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
267
+ const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
268
+
269
+ dst.qs[j] = xi0 | (xi1 << 4);
270
+
271
+ const float v0 = kvalues_iq4nl_f[xi0];
272
+ const float v1 = kvalues_iq4nl_f[xi1];
273
+ const float w0 = src[0 + j]*src[0 + j];
274
+ const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
275
+ sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
276
+ sumq2 += w0*v0*v0 + w1*v1*v1;
277
+
278
+ }
279
+
280
+ dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
281
+ }
282
+
100
283
  template <typename type4x4>
101
284
  void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
102
285
  device const uint16_t * qs = ((device const uint16_t *)xb + 2);
@@ -279,6 +462,27 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
279
462
  }
280
463
  }
281
464
 
465
+ void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
466
+ #pragma METAL fp math_mode(safe)
467
+ float amax = 0.0f; // absolute max
468
+
469
+ for (int j = 0; j < QK8_0; j++) {
470
+ const float v = src[j];
471
+ amax = MAX(amax, fabs(v));
472
+ }
473
+
474
+ const float d = amax / ((1 << 7) - 1);
475
+ const float id = d ? 1.0f/d : 0.0f;
476
+
477
+ dst.d = d;
478
+
479
+ for (int j = 0; j < QK8_0; ++j) {
480
+ const float x0 = src[j]*id;
481
+
482
+ dst.qs[j] = round(x0);
483
+ }
484
+ }
485
+
282
486
  template <typename type4x4>
283
487
  void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
284
488
  const float d = xb->d;
@@ -810,16 +1014,18 @@ kernel void kernel_scale(
810
1014
  device const float * src0,
811
1015
  device float * dst,
812
1016
  constant float & scale,
1017
+ constant float & bias,
813
1018
  uint tpig[[thread_position_in_grid]]) {
814
- dst[tpig] = src0[tpig] * scale;
1019
+ dst[tpig] = src0[tpig] * scale + bias;
815
1020
  }
816
1021
 
817
1022
// Element-wise affine transform over float4 lanes: dst = src0*scale + bias.
// One thread per float4 element.
kernel void kernel_scale_4(
        device const float4 * src0,
        device       float4 * dst,
        constant     float  & scale,
        constant     float  & bias,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig]*scale + bias;
}
824
1030
 
825
1031
  kernel void kernel_clamp(
@@ -993,6 +1199,114 @@ kernel void kernel_neg(
993
1199
  dst[tpig] = -src0[tpig];
994
1200
  }
995
1201
 
1202
// GLU with ReLU gating: dst = relu(src0) * src1.
// One threadgroup per row; threads stride cooperatively over the columns.
kernel void kernel_reglu(
        device const char * src0,
        device const char * src1,
        device       char * dst,
        constant ggml_metal_kargs_glu & args,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    // nb01/nb11/nb1 are row strides in bytes; i00/i10 are element offsets into the row
    device const float * row_a = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const float * row_g = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       float * row_d = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int col = tpitg; col < args.ne0; col += ntg) {
        const float a = row_a[col];
        const float g = row_g[col];

        // relu(a)*g written branch-free: the comparison contributes a 0/1 factor
        row_d[col] = a*g*(a > 0.0f);
    }
}
1221
+
1222
// GLU with GELU gating (tanh approximation): dst = gelu(src0) * src1.
// One threadgroup per row; threads stride cooperatively over the columns.
kernel void kernel_geglu(
        device const char * src0,
        device const char * src1,
        device       char * dst,
        constant ggml_metal_kargs_glu & args,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    // nb01/nb11/nb1 are row strides in bytes; i00/i10 are element offsets into the row
    device const float * row_a = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const float * row_g = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       float * row_d = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int col = tpitg; col < args.ne0; col += ntg) {
        const float a = row_a[col];
        const float g = row_g[col];

        // tanh-based GELU approximation; precise::tanh keeps full-precision math here
        const float gelu = 0.5f*a*(1.0f + precise::tanh(SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a)));

        row_d[col] = gelu*g;
    }
}
1243
+
1244
// GLU with SiLU (swish) gating: dst = silu(src0) * src1.
// One threadgroup per row; threads stride cooperatively over the columns.
kernel void kernel_swiglu(
        device const char * src0,
        device const char * src1,
        device       char * dst,
        constant ggml_metal_kargs_glu & args,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    // nb01/nb11/nb1 are row strides in bytes; i00/i10 are element offsets into the row
    device const float * row_a = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const float * row_g = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       float * row_d = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int col = tpitg; col < args.ne0; col += ntg) {
        const float a = row_a[col];
        const float g = row_g[col];

        // silu(a) = a * sigmoid(a)
        const float silu = a / (1.0f + exp(-a));

        row_d[col] = silu*g;
    }
}
1265
+
1266
// GLU with erf-based GELU gating: dst = gelu_erf(src0) * src1.
// One threadgroup per row; threads stride cooperatively over the columns.
kernel void kernel_geglu_erf(
        device const char * src0,
        device const char * src1,
        device       char * dst,
        constant ggml_metal_kargs_glu & args,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    // nb01/nb11/nb1 are row strides in bytes; i00/i10 are element offsets into the row
    device const float * row_a = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const float * row_g = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       float * row_d = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int col = tpitg; col < args.ne0; col += ntg) {
        const float a = row_a[col];
        const float g = row_g[col];

        // "exact" GELU via the (approximated) error function
        const float gelu_erf = 0.5f*a*(1.0f+erf_approx<float>(a*SQRT_2_INV));

        row_d[col] = gelu_erf*g;
    }
}
1287
+
1288
// GLU with "quick" GELU gating (sigmoid approximation): dst = gelu_quick(src0) * src1.
// One threadgroup per row; threads stride cooperatively over the columns.
kernel void kernel_geglu_quick(
        device const char * src0,
        device const char * src1,
        device       char * dst,
        constant ggml_metal_kargs_glu & args,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    // nb01/nb11/nb1 are row strides in bytes; i00/i10 are element offsets into the row
    device const float * row_a = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const float * row_g = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       float * row_d = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int col = tpitg; col < args.ne0; col += ntg) {
        const float a = row_a[col];
        const float g = row_g[col];

        // quick GELU: a * sigmoid(GELU_QUICK_COEF * a)
        const float gelu_quick = a*(1.0f/(1.0f+exp(GELU_QUICK_COEF*a)));

        row_d[col] = gelu_quick*g;
    }
}
1309
+
996
1310
  template <bool norm>
997
1311
  kernel void kernel_sum_rows(
998
1312
  constant ggml_metal_kargs_sum_rows & args,
@@ -1055,24 +1369,28 @@ kernel void kernel_soft_max(
1055
1369
  device char * dst,
1056
1370
  constant ggml_metal_kargs_soft_max & args,
1057
1371
  threadgroup float * buf [[threadgroup(0)]],
1058
- uint tgpig[[threadgroup_position_in_grid]],
1059
- uint tpitg[[thread_position_in_threadgroup]],
1372
+ uint3 tgpig[[threadgroup_position_in_grid]],
1373
+ uint3 tpitg[[thread_position_in_threadgroup]],
1060
1374
  uint sgitg[[simdgroup_index_in_threadgroup]],
1061
1375
  uint tiisg[[thread_index_in_simdgroup]],
1062
- uint ntg[[threads_per_threadgroup]]) {
1063
- const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
1064
- const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
1065
- const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
1376
+ uint3 tptg[[threads_per_threadgroup]]) {
1377
+ const int32_t i03 = tgpig.z;
1378
+ const int32_t i02 = tgpig.y;
1379
+ const int32_t i01 = tgpig.x;
1380
+
1381
+ const int32_t i13 = i03%args.ne13;
1382
+ const int32_t i12 = i02%args.ne12;
1383
+ const int32_t i11 = i01;
1066
1384
 
1067
- device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
1068
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr;
1069
- device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
1385
+ device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
1386
+ device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
1387
+ device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
1070
1388
 
1071
1389
  float slope = 1.0f;
1072
1390
 
1073
1391
  // ALiBi
1074
1392
  if (args.max_bias > 0.0f) {
1075
- const int64_t h = i02;
1393
+ const int32_t h = i02;
1076
1394
 
1077
1395
  const float base = h < args.n_head_log2 ? args.m0 : args.m1;
1078
1396
  const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
@@ -1083,13 +1401,13 @@ kernel void kernel_soft_max(
1083
1401
  // parallel max
1084
1402
  float lmax = -INFINITY;
1085
1403
 
1086
- for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1404
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
1087
1405
  lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
1088
1406
  }
1089
1407
 
1090
1408
  // find the max value in the block
1091
1409
  float max_val = simd_max(lmax);
1092
- if (ntg > N_SIMDWIDTH) {
1410
+ if (tptg.x > N_SIMDWIDTH) {
1093
1411
  if (sgitg == 0) {
1094
1412
  buf[tiisg] = -INFINITY;
1095
1413
  }
@@ -1108,7 +1426,7 @@ kernel void kernel_soft_max(
1108
1426
 
1109
1427
  // parallel sum
1110
1428
  float lsum = 0.0f;
1111
- for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1429
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
1112
1430
  const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
1113
1431
  lsum += exp_psrc0;
1114
1432
  pdst[i00] = exp_psrc0;
@@ -1120,7 +1438,7 @@ kernel void kernel_soft_max(
1120
1438
 
1121
1439
  float sum = simd_sum(lsum);
1122
1440
 
1123
- if (ntg > N_SIMDWIDTH) {
1441
+ if (tptg.x > N_SIMDWIDTH) {
1124
1442
  if (sgitg == 0) {
1125
1443
  buf[tiisg] = 0.0f;
1126
1444
  }
@@ -1139,7 +1457,7 @@ kernel void kernel_soft_max(
1139
1457
 
1140
1458
  const float inv_sum = 1.0f/sum;
1141
1459
 
1142
- for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1460
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
1143
1461
  pdst[i00] *= inv_sum;
1144
1462
  }
1145
1463
  }
@@ -1151,23 +1469,27 @@ kernel void kernel_soft_max_4(
1151
1469
  device char * dst,
1152
1470
  constant ggml_metal_kargs_soft_max & args,
1153
1471
  threadgroup float * buf [[threadgroup(0)]],
1154
- uint tgpig[[threadgroup_position_in_grid]],
1155
- uint tpitg[[thread_position_in_threadgroup]],
1472
+ uint3 tgpig[[threadgroup_position_in_grid]],
1473
+ uint3 tpitg[[thread_position_in_threadgroup]],
1156
1474
  uint sgitg[[simdgroup_index_in_threadgroup]],
1157
1475
  uint tiisg[[thread_index_in_simdgroup]],
1158
- uint ntg[[threads_per_threadgroup]]) {
1159
- const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
1160
- const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
1161
- const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
1476
+ uint3 tptg[[threads_per_threadgroup]]) {
1477
+ const int32_t i03 = tgpig.z;
1478
+ const int32_t i02 = tgpig.y;
1479
+ const int32_t i01 = tgpig.x;
1162
1480
 
1163
- device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
1164
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr;
1165
- device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
1481
+ const int32_t i13 = i03%args.ne13;
1482
+ const int32_t i12 = i02%args.ne12;
1483
+ const int32_t i11 = i01;
1484
+
1485
+ device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
1486
+ device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
1487
+ device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
1166
1488
 
1167
1489
  float slope = 1.0f;
1168
1490
 
1169
1491
  if (args.max_bias > 0.0f) {
1170
- const int64_t h = i02;
1492
+ const int32_t h = i02;
1171
1493
 
1172
1494
  const float base = h < args.n_head_log2 ? args.m0 : args.m1;
1173
1495
  const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
@@ -1178,14 +1500,14 @@ kernel void kernel_soft_max_4(
1178
1500
  // parallel max
1179
1501
  float4 lmax4 = -INFINITY;
1180
1502
 
1181
- for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1503
+ for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
1182
1504
  lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
1183
1505
  }
1184
1506
 
1185
1507
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
1186
1508
 
1187
1509
  float max_val = simd_max(lmax);
1188
- if (ntg > N_SIMDWIDTH) {
1510
+ if (tptg.x > N_SIMDWIDTH) {
1189
1511
  if (sgitg == 0) {
1190
1512
  buf[tiisg] = -INFINITY;
1191
1513
  }
@@ -1204,7 +1526,7 @@ kernel void kernel_soft_max_4(
1204
1526
 
1205
1527
  // parallel sum
1206
1528
  float4 lsum4 = 0.0f;
1207
- for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1529
+ for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
1208
1530
  const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
1209
1531
  lsum4 += exp_psrc4;
1210
1532
  pdst4[i00] = exp_psrc4;
@@ -1218,7 +1540,7 @@ kernel void kernel_soft_max_4(
1218
1540
 
1219
1541
  float sum = simd_sum(lsum);
1220
1542
 
1221
- if (ntg > N_SIMDWIDTH) {
1543
+ if (tptg.x > N_SIMDWIDTH) {
1222
1544
  if (sgitg == 0) {
1223
1545
  buf[tiisg] = 0.0f;
1224
1546
  }
@@ -1237,7 +1559,7 @@ kernel void kernel_soft_max_4(
1237
1559
 
1238
1560
  const float inv_sum = 1.0f/sum;
1239
1561
 
1240
- for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1562
+ for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
1241
1563
  pdst4[i00] *= inv_sum;
1242
1564
  }
1243
1565
  }
@@ -1323,7 +1645,7 @@ kernel void kernel_ssm_conv_f32(
1323
1645
  x[0] = sumf;
1324
1646
  }
1325
1647
 
1326
- // ref: ggml.c:ggml_compute_forward_ssm_scan_f32
1648
+ // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
1327
1649
  kernel void kernel_ssm_scan_f32(
1328
1650
  device const void * src0,
1329
1651
  device const void * src1,
@@ -1331,46 +1653,119 @@ kernel void kernel_ssm_scan_f32(
1331
1653
  device const void * src3,
1332
1654
  device const void * src4,
1333
1655
  device const void * src5,
1656
+ device const void * src6,
1334
1657
  device float * dst,
1335
1658
  constant ggml_metal_kargs_ssm_scan & args,
1336
1659
  uint3 tgpig[[threadgroup_position_in_grid]],
1337
1660
  uint3 tpitg[[thread_position_in_threadgroup]],
1338
1661
  uint3 ntg[[threads_per_threadgroup]]) {
1339
- const int64_t ir = tgpig.x;
1340
- const int64_t i3 = tgpig.y;
1662
+ const int64_t i1 = 0;
1663
+ const int64_t ir = tgpig.x; // current head
1664
+ const int64_t i3 = tgpig.y; // current seq
1665
+
1666
+ const uint64_t nb00 = sizeof(float);
1667
+ const uint64_t nb10 = sizeof(float);
1668
+ const uint64_t nb20 = sizeof(float);
1341
1669
 
1342
1670
  const int64_t nc = args.d_state;
1343
- // const int64_t nr = args.d_inner;
1671
+ const int64_t nr = args.d_inner;
1672
+ const int64_t nh = args.n_head;
1673
+ const int64_t ng = args.n_group;
1344
1674
  const int64_t n_t = args.n_seq_tokens;
1345
- // const int64_t n_s = args.n_seqs;
1675
+
1676
+ const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
1677
+
1678
+ device const int32_t * ids = (device const int32_t *) src6;
1679
+
1680
+ device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
1681
+ device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
1346
1682
 
1347
1683
  for (int64_t i2 = 0; i2 < n_t; ++i2) {
1348
- device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02);
1349
- device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12);
1350
- device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22);
1351
- device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
1352
- device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42);
1353
- device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52);
1354
- device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
1355
- device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13);
1356
-
1357
- if (i2 > 0) {
1358
- s0 = s;
1359
- }
1360
-
1361
- // i1 == 0
1362
- float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
1363
- float x_dt = x[0] * dt_soft_plus;
1684
+ device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
1685
+ device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
1686
+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
1687
+ device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
1688
+ device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
1689
+ device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
1690
+
1691
+ const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
1692
+ const float x_dt = x[0] * dt_soft_plus;
1364
1693
  float sumf = 0.0f;
1365
1694
 
1366
1695
  for (int64_t i0 = 0; i0 < nc; ++i0) {
1367
- int64_t i = i0;
1368
- float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
1696
+ const int64_t i = i0 + i1*nc;
1697
+ const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
1369
1698
  sumf += state * C[i0];
1370
1699
  s[i] = state;
1371
1700
  }
1372
1701
 
1373
1702
  y[0] = sumf;
1703
+
1704
+ // recurse
1705
+ s0 = s;
1706
+ }
1707
+ }
1708
+
1709
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
// Selective state-space scan where A holds a single decay value per head
// (shape {1, nh}), unlike the Mamba-1 kernel above which uses {d_state, nh}.
// Grid mapping: tgpig.x = inner-dim element, tgpig.y = head, tgpig.z = sequence.
// Each thread walks the token dimension sequentially, carrying the recurrent
// state forward via `s` (the output buffer doubles as state storage past s_off).
// TODO: optimize (e.g. by parallelizing over d_state)
kernel void kernel_ssm_scan_f32_group(
        device const void * src0,   // input states; rows selected through the ids in src6
        device const void * src1,   // x
        device const void * src2,   // dt
        device const void * src3,   // A
        device const void * src4,   // B
        device const void * src5,   // C
        device const void * src6,   // int32 state ids, one per sequence
        device       float * dst,   // y followed (at byte offset s_off) by the updated states
        constant ggml_metal_kargs_ssm_scan & args,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i1 = tgpig.x;
    const int64_t ir = tgpig.y; // current head
    const int64_t i3 = tgpig.z; // current seq

    // innermost strides are contiguous floats
    const uint64_t nb00 = sizeof(float);
    const uint64_t nb10 = sizeof(float);
    const uint64_t nb20 = sizeof(float);

    const int64_t nc  = args.d_state;
    const int64_t nr  = args.d_inner;
    const int64_t nh  = args.n_head;
    const int64_t ng  = args.n_group;
    const int64_t n_t = args.n_seq_tokens;

    // byte offset in dst where the state section begins (after all y outputs)
    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);

    device const int32_t * ids = (device const int32_t *) src6;

    // initial state row is indirected through ids; the updated state is written
    // into dst's state section at the (direct) sequence index i3
    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
    device       float * s  = (device       float *) ((device       char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);

    for (int64_t i2 = 0; i2 < n_t; ++i2) {
        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
        // NOTE(review): (ir & (ng - 1)) maps head -> group; this assumes n_group
        // is a power of two — confirm against the host-side launch code
        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
        device       float * y  = (device       float *) ((device       char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}

        // softplus(dt); above 20 softplus is numerically ~identity, so pass through
        // to avoid exp() overflow
        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
        const float x_dt = x[0] * dt_soft_plus;
        // scalar per-head decay — hoisted out of the d_state loop (Mamba-2 specific)
        const float dA = exp(dt_soft_plus * A[0]);
        float sumf = 0.0f;

        for (int64_t i0 = 0; i0 < nc; ++i0) {
            const int64_t i = i0 + i1*nc;
            // state update: decay previous state, inject input scaled by B
            const float state = (s0[i] * dA) + (B[i0] * x_dt);
            // output contribution: project state through C
            sumf += state * C[i0];
            s[i] = state;
        }

        y[0] = sumf;

        // recurse: the next token reads the state this iteration just wrote
        s0 = s;
    }
}
1376
1771
 
@@ -2532,6 +2927,70 @@ template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<
2532
2927
  template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
2533
2928
  #endif
2534
2929
 
2930
// Matrix-vector product for the degenerate case ne00 == 4: each output element
// is a single 4-wide dot product. One SIMD lane handles one row of src0.
// T04/T14: 4-wide element types of src0/src1 (e.g. half4/float4).
template<typename T04, typename T14, typename args_t>
void kernel_mul_mv_c4_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3 tgpig,
        ushort tiisg) {
    const int row0   = tgpig.x*32 + tiisg;  // src0 row handled by this lane
    const int row1_0 = tgpig.y*N_MV_T_T;    // first src1 column of this group
    const int ibatch = tgpig.z;             // flattened batch index

    // guard: grid may overshoot the row count
    if (row0 >= args.ne01) {
        return;
    }

    const uint i12 = ibatch%args.ne12;
    const uint i13 = ibatch/args.ne12;

    // r2/r3 fold broadcast of src0 across the batch dims
    const uint64_t offset0 = row0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;

    device const T04 * x = (device const T04 *) (src0 + offset0);

    device float * dst_f32 = (device float *) dst + (uint64_t)ibatch*args.ne0*args.ne1;

    for (int k = 0; k < N_MV_T_T; ++k) {
        const int row1 = row1_0 + k;
        if (row1 >= args.ne11) {
            break;
        }

        const uint64_t offset1 = row1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

        device const T14 * y = (device const T14 *) (src1 + offset1);

        // the whole row fits in one 4-vector -> one dot() per output element
        dst_f32[(uint64_t)row1*args.ne0 + row0] = dot((float4) x[0], (float4) y[0]);
    }
}
2968
+
2969
// Kernel entry point for the ne00 == 4 mat-vec case; forwards straight to the
// shared implementation above.
template<typename T04, typename T14>
kernel void kernel_mul_mv_c4(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]]) {
    kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, tgpig, tiisg);
}
2985
+
2986
+ typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
2987
+
2988
+ template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<float4, float4>;
2989
+ template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, float4>;
2990
+ #if defined(GGML_METAL_USE_BF16)
2991
+ template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
2992
+ #endif
2993
+
2535
2994
  template<typename T, typename T4>
2536
2995
  kernel void kernel_mul_mv_1row(
2537
2996
  constant ggml_metal_kargs_mul_mv & args,
@@ -3447,7 +3906,7 @@ kernel void kernel_flash_attn_ext(
3447
3906
  // load the mask in shared memory
3448
3907
  #pragma unroll(Q)
3449
3908
  for (short j = 0; j < Q; ++j) {
3450
- device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
3909
+ device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
3451
3910
 
3452
3911
  const float m = pm[ic + tiisg];
3453
3912
 
@@ -3933,7 +4392,7 @@ kernel void kernel_flash_attn_ext_vec(
3933
4392
  const bool has_mask = mask != q;
3934
4393
 
3935
4394
  // pointer to the mask
3936
- device const half * pm = (device const half *) (mask + iq1*args.nb31);
4395
+ device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
3937
4396
 
3938
4397
  float slope = 1.0f;
3939
4398
 
@@ -4306,11 +4765,16 @@ kernel void kernel_cpy(
4306
4765
  device const char * src0,
4307
4766
  device char * dst,
4308
4767
  uint3 tgpig[[threadgroup_position_in_grid]],
4768
+ uint tiitg[[thread_index_in_threadgroup]],
4309
4769
  ushort3 tpitg[[thread_position_in_threadgroup]],
4310
- ushort3 ntg[[threads_per_threadgroup]]) {
4770
+ ushort3 tptg[[threads_per_threadgroup]]) {
4311
4771
  const int i03 = tgpig[2];
4312
4772
  const int i02 = tgpig[1];
4313
- const int i01 = tgpig[0];
4773
+ const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;
4774
+
4775
+ if (i01 >= args.ne01) {
4776
+ return;
4777
+ }
4314
4778
 
4315
4779
  const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
4316
4780
 
@@ -4321,7 +4785,7 @@ kernel void kernel_cpy(
4321
4785
 
4322
4786
  device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
4323
4787
 
4324
- for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
4788
+ for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
4325
4789
  device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4326
4790
  dst_data[i00] = (T1) src[0];
4327
4791
  }
@@ -4341,6 +4805,7 @@ template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bf
4341
4805
  template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
4342
4806
  #endif
4343
4807
 
4808
+ // TODO: templetify these kernels
4344
4809
  kernel void kernel_cpy_f32_q8_0(
4345
4810
  constant ggml_metal_kargs_cpy & args,
4346
4811
  device const char * src0,
@@ -4364,23 +4829,7 @@ kernel void kernel_cpy_f32_q8_0(
4364
4829
  for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
4365
4830
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4366
4831
 
4367
- float amax = 0.0f; // absolute max
4368
-
4369
- for (int j = 0; j < QK8_0; j++) {
4370
- const float v = src[j];
4371
- amax = MAX(amax, fabs(v));
4372
- }
4373
-
4374
- const float d = amax / ((1 << 7) - 1);
4375
- const float id = d ? 1.0f/d : 0.0f;
4376
-
4377
- dst_data[i00/QK8_0].d = d;
4378
-
4379
- for (int j = 0; j < QK8_0; ++j) {
4380
- const float x0 = src[j]*id;
4381
-
4382
- dst_data[i00/QK8_0].qs[j] = round(x0);
4383
- }
4832
+ quantize_q8_0(src, dst_data[i00/QK8_0]);
4384
4833
  }
4385
4834
  }
4386
4835
 
@@ -4407,32 +4856,7 @@ kernel void kernel_cpy_f32_q4_0(
4407
4856
  for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
4408
4857
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4409
4858
 
4410
- float amax = 0.0f; // absolute max
4411
- float max = 0.0f;
4412
-
4413
- for (int j = 0; j < QK4_0; j++) {
4414
- const float v = src[j];
4415
- if (amax < fabs(v)) {
4416
- amax = fabs(v);
4417
- max = v;
4418
- }
4419
- }
4420
-
4421
- const float d = max / -8;
4422
- const float id = d ? 1.0f/d : 0.0f;
4423
-
4424
- dst_data[i00/QK4_0].d = d;
4425
-
4426
- for (int j = 0; j < QK4_0/2; ++j) {
4427
- const float x0 = src[0 + j]*id;
4428
- const float x1 = src[QK4_0/2 + j]*id;
4429
-
4430
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
4431
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
4432
-
4433
- dst_data[i00/QK4_0].qs[j] = xi0;
4434
- dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
4435
- }
4859
+ quantize_q4_0(src, dst_data[i00/QK4_0]);
4436
4860
  }
4437
4861
  }
4438
4862
 
@@ -4459,31 +4883,7 @@ kernel void kernel_cpy_f32_q4_1(
4459
4883
  for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
4460
4884
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4461
4885
 
4462
- float min = FLT_MAX;
4463
- float max = -FLT_MAX;
4464
-
4465
- for (int j = 0; j < QK4_1; j++) {
4466
- const float v = src[j];
4467
- if (min > v) min = v;
4468
- if (max < v) max = v;
4469
- }
4470
-
4471
- const float d = (max - min) / ((1 << 4) - 1);
4472
- const float id = d ? 1.0f/d : 0.0f;
4473
-
4474
- dst_data[i00/QK4_1].d = d;
4475
- dst_data[i00/QK4_1].m = min;
4476
-
4477
- for (int j = 0; j < QK4_1/2; ++j) {
4478
- const float x0 = (src[0 + j] - min)*id;
4479
- const float x1 = (src[QK4_1/2 + j] - min)*id;
4480
-
4481
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
4482
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
4483
-
4484
- dst_data[i00/QK4_1].qs[j] = xi0;
4485
- dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
4486
- }
4886
+ quantize_q4_1(src, dst_data[i00/QK4_1]);
4487
4887
  }
4488
4888
  }
4489
4889
 
@@ -4510,38 +4910,7 @@ kernel void kernel_cpy_f32_q5_0(
4510
4910
  for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
4511
4911
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4512
4912
 
4513
- float amax = 0.0f; // absolute max
4514
- float max = 0.0f;
4515
-
4516
- for (int j = 0; j < QK5_0; j++) {
4517
- const float v = src[j];
4518
- if (amax < fabs(v)) {
4519
- amax = fabs(v);
4520
- max = v;
4521
- }
4522
- }
4523
-
4524
- const float d = max / -16;
4525
- const float id = d ? 1.0f/d : 0.0f;
4526
-
4527
- dst_data[i00/QK5_0].d = d;
4528
-
4529
- uint32_t qh = 0;
4530
- for (int j = 0; j < QK5_0/2; ++j) {
4531
- const float x0 = src[0 + j]*id;
4532
- const float x1 = src[QK5_0/2 + j]*id;
4533
-
4534
- const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
4535
- const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
4536
-
4537
- dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
4538
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
4539
- qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
4540
- }
4541
- thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
4542
- for (int j = 0; j < 4; ++j) {
4543
- dst_data[i00/QK5_0].qh[j] = qh8[j];
4544
- }
4913
+ quantize_q5_0(src, dst_data[i00/QK5_0]);
4545
4914
  }
4546
4915
  }
4547
4916
 
@@ -4568,51 +4937,10 @@ kernel void kernel_cpy_f32_q5_1(
4568
4937
  for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
4569
4938
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4570
4939
 
4571
- float max = src[0];
4572
- float min = src[0];
4573
-
4574
- for (int j = 1; j < QK5_1; j++) {
4575
- const float v = src[j];
4576
- min = v < min ? v : min;
4577
- max = v > max ? v : max;
4578
- }
4579
-
4580
- const float d = (max - min) / 31;
4581
- const float id = d ? 1.0f/d : 0.0f;
4582
-
4583
- dst_data[i00/QK5_1].d = d;
4584
- dst_data[i00/QK5_1].m = min;
4585
-
4586
- uint32_t qh = 0;
4587
- for (int j = 0; j < QK5_1/2; ++j) {
4588
- const float x0 = (src[0 + j] - min)*id;
4589
- const float x1 = (src[QK5_1/2 + j] - min)*id;
4590
-
4591
- const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
4592
- const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
4593
-
4594
- dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
4595
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
4596
- qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
4597
- }
4598
- thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
4599
- for (int j = 0; j < 4; ++j) {
4600
- dst_data[i00/QK5_1].qh[j] = qh8[j];
4601
- }
4940
+ quantize_q5_1(src, dst_data[i00/QK5_1]);
4602
4941
  }
4603
4942
  }
4604
4943
 
4605
// Return the index of the entry of `val` (n values, sorted ascending) closest
// to x, via binary search over the bracketing pair.
static inline int best_index_int8(int n, constant float * val, float x) {
    // clamp to the table ends
    if (x <= val[0])   return 0;
    if (x >= val[n-1]) return n-1;

    // binary search; invariant: val[lo] <= x < val[hi]
    int lo = 0;
    int hi = n-1;
    while (hi - lo > 1) {
        const int mid = (lo + hi)/2;
        if (x < val[mid]) {
            hi = mid;
        } else {
            lo = mid;
        }
    }

    // pick whichever bracketing neighbor is nearer to x
    return x - val[hi-1] < val[hi] - x ? hi-1 : hi;
}
4615
-
4616
4944
  kernel void kernel_cpy_f32_iq4_nl(
4617
4945
  constant ggml_metal_kargs_cpy & args,
4618
4946
  device const char * src0,
@@ -4636,40 +4964,7 @@ kernel void kernel_cpy_f32_iq4_nl(
4636
4964
  for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
4637
4965
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4638
4966
 
4639
- float amax = 0.0f; // absolute max
4640
- float max = 0.0f;
4641
-
4642
- for (int j = 0; j < QK4_NL; j++) {
4643
- const float v = src[j];
4644
- if (amax < fabs(v)) {
4645
- amax = fabs(v);
4646
- max = v;
4647
- }
4648
- }
4649
-
4650
- const float d = max / kvalues_iq4nl_f[0];
4651
- const float id = d ? 1.0f/d : 0.0f;
4652
-
4653
- float sumqx = 0, sumq2 = 0;
4654
- for (int j = 0; j < QK4_NL/2; ++j) {
4655
- const float x0 = src[0 + j]*id;
4656
- const float x1 = src[QK4_NL/2 + j]*id;
4657
-
4658
- const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
4659
- const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
4660
-
4661
- dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
4662
-
4663
- const float v0 = kvalues_iq4nl_f[xi0];
4664
- const float v1 = kvalues_iq4nl_f[xi1];
4665
- const float w0 = src[0 + j]*src[0 + j];
4666
- const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
4667
- sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
4668
- sumq2 += w0*v0*v0 + w1*v1*v1;
4669
-
4670
- }
4671
-
4672
- dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
4967
+ quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
4673
4968
  }
4674
4969
  }
4675
4970
 
@@ -6350,10 +6645,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
6350
6645
 
6351
6646
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
6352
6647
  kernel void kernel_get_rows_q(
6648
+ constant ggml_metal_kargs_get_rows & args,
6353
6649
  device const void * src0,
6354
6650
  device const void * src1,
6355
6651
  device float * dst,
6356
- constant ggml_metal_kargs_get_rows & args,
6357
6652
  uint3 tgpig[[threadgroup_position_in_grid]],
6358
6653
  uint tiitg[[thread_index_in_threadgroup]],
6359
6654
  uint3 tptg [[threads_per_threadgroup]]) {
@@ -6373,10 +6668,10 @@ kernel void kernel_get_rows_q(
6373
6668
 
6374
6669
  template<typename T>
6375
6670
  kernel void kernel_get_rows_f(
6671
+ constant ggml_metal_kargs_get_rows & args,
6376
6672
  device const void * src0,
6377
6673
  device const void * src1,
6378
6674
  device float * dst,
6379
- constant ggml_metal_kargs_get_rows & args,
6380
6675
  uint3 tgpig[[threadgroup_position_in_grid]],
6381
6676
  uint tiitg[[thread_index_in_threadgroup]],
6382
6677
  uint3 tptg [[threads_per_threadgroup]]) {
@@ -6394,10 +6689,10 @@ kernel void kernel_get_rows_f(
6394
6689
  }
6395
6690
 
6396
6691
  kernel void kernel_get_rows_i32(
6692
+ constant ggml_metal_kargs_get_rows & args,
6397
6693
  device const void * src0,
6398
6694
  device const void * src1,
6399
6695
  device int32_t * dst,
6400
- constant ggml_metal_kargs_get_rows & args,
6401
6696
  uint3 tgpig[[threadgroup_position_in_grid]],
6402
6697
  uint tiitg[[thread_index_in_threadgroup]],
6403
6698
  uint3 tptg [[threads_per_threadgroup]]) {
@@ -6414,6 +6709,67 @@ kernel void kernel_get_rows_i32(
6414
6709
  }
6415
6710
  }
6416
6711
 
6712
+ template<typename block_q, void (*quantize_func)(device const float *, device block_q &)>
6713
+ kernel void kernel_set_rows_q32(
6714
+ constant ggml_metal_kargs_set_rows & args,
6715
+ device const void * src0,
6716
+ device const void * src1,
6717
+ device float * dst,
6718
+ uint3 tgpig[[threadgroup_position_in_grid]],
6719
+ uint tiitg[[thread_index_in_threadgroup]],
6720
+ uint3 tptg [[threads_per_threadgroup]]) {
6721
+ const int32_t i03 = tgpig.z;
6722
+ const int32_t i02 = tgpig.y;
6723
+
6724
+ const int32_t i12 = i03%args.ne12;
6725
+ const int32_t i11 = i02%args.ne11;
6726
+
6727
+ const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
6728
+ if (i01 >= args.ne01) {
6729
+ return;
6730
+ }
6731
+
6732
+ const int32_t i10 = i01;
6733
+ const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
6734
+
6735
+ device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
6736
+ const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
6737
+
6738
+ for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
6739
+ quantize_func(src_row + 32*ind, dst_row[ind]);
6740
+ }
6741
+ }
6742
+
6743
+ template<typename T>
6744
+ kernel void kernel_set_rows_f(
6745
+ constant ggml_metal_kargs_set_rows & args,
6746
+ device const void * src0,
6747
+ device const void * src1,
6748
+ device float * dst,
6749
+ uint3 tgpig[[threadgroup_position_in_grid]],
6750
+ uint tiitg[[thread_index_in_threadgroup]],
6751
+ uint3 tptg [[threads_per_threadgroup]]) {
6752
+ const int32_t i03 = tgpig.z;
6753
+ const int32_t i02 = tgpig.y;
6754
+
6755
+ const int32_t i12 = i03%args.ne12;
6756
+ const int32_t i11 = i02%args.ne11;
6757
+
6758
+ const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
6759
+ if (i01 >= args.ne01) {
6760
+ return;
6761
+ }
6762
+
6763
+ const int32_t i10 = i01;
6764
+ const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
6765
+
6766
+ device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
6767
+ const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
6768
+
6769
+ for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
6770
+ dst_row[ind] = (T) src_row[ind];
6771
+ }
6772
+ }
6417
6773
 
6418
6774
  #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
6419
6775
  #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
@@ -6837,6 +7193,27 @@ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get
6837
7193
  template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
6838
7194
  template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6839
7195
 
7196
+ //
7197
+ // set rows
7198
+ //
7199
+
7200
+ typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
7201
+
7202
+ template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f<float>;
7203
+ template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f<half>;
7204
+ #if defined(GGML_METAL_USE_BF16)
7205
+ template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
7206
+ #endif
7207
+
7208
+ typedef decltype(kernel_set_rows_q32<block_q8_0, quantize_q8_0>) set_rows_q32_t;
7209
+
7210
+ template [[host_name("kernel_set_rows_q8_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q8_0, quantize_q8_0>;
7211
+ template [[host_name("kernel_set_rows_q4_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_0, quantize_q4_0>;
7212
+ template [[host_name("kernel_set_rows_q4_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_1, quantize_q4_1>;
7213
+ template [[host_name("kernel_set_rows_q5_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_0, quantize_q5_0>;
7214
+ template [[host_name("kernel_set_rows_q5_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_1, quantize_q5_1>;
7215
+ template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32<block_iq4_nl, quantize_iq4_nl>;
7216
+
6840
7217
  //
6841
7218
  // matrix-matrix multiplication
6842
7219
  //