@novastera-oss/llamarn 0.2.6 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192) hide show
  1. package/android/src/main/cpp/include/llama.h +134 -36
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +2 -2
  11. package/cpp/LlamaCppModel.h +3 -3
  12. package/cpp/PureCppImpl.cpp +1 -1
  13. package/cpp/PureCppImpl.h +2 -2
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  16. package/cpp/llama.cpp/Makefile +2 -2
  17. package/cpp/llama.cpp/README.md +32 -13
  18. package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
  19. package/cpp/llama.cpp/common/arg.cpp +30 -6
  20. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  21. package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
  22. package/cpp/llama.cpp/common/chat-parser.h +2 -0
  23. package/cpp/llama.cpp/common/chat.cpp +12 -9
  24. package/cpp/llama.cpp/common/chat.h +1 -1
  25. package/cpp/llama.cpp/common/common.cpp +50 -40
  26. package/cpp/llama.cpp/common/common.h +5 -2
  27. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  28. package/cpp/llama.cpp/convert_hf_to_gguf.py +97 -56
  29. package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
  30. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  31. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +47 -13
  32. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  34. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  35. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +10 -2
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +5 -8
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +6 -8
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  70. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  72. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
  74. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
  76. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  84. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -38
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +431 -247
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  101. package/cpp/llama.cpp/ggml/src/ggml.c +0 -6
  102. package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
  103. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
  104. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
  105. package/cpp/llama.cpp/include/llama.h +134 -36
  106. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  107. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  108. package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
  109. package/cpp/llama.cpp/src/llama-arch.h +7 -1
  110. package/cpp/llama.cpp/src/llama-batch.cpp +270 -19
  111. package/cpp/llama.cpp/src/llama-batch.h +36 -11
  112. package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
  113. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  114. package/cpp/llama.cpp/src/llama-context.cpp +313 -213
  115. package/cpp/llama.cpp/src/llama-context.h +16 -12
  116. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  117. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  118. package/cpp/llama.cpp/src/llama-graph.cpp +249 -129
  119. package/cpp/llama.cpp/src/llama-graph.h +90 -34
  120. package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
  121. package/cpp/llama.cpp/src/llama-hparams.h +8 -2
  122. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +82 -50
  123. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  124. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +292 -174
  125. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +68 -38
  126. package/cpp/llama.cpp/src/llama-kv-cells.h +18 -13
  127. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
  128. package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
  129. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +266 -282
  130. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +54 -57
  131. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  132. package/cpp/llama.cpp/src/llama-memory.h +64 -23
  133. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  134. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  135. package/cpp/llama.cpp/src/llama-model.cpp +726 -141
  136. package/cpp/llama.cpp/src/llama-model.h +4 -0
  137. package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
  138. package/cpp/llama.cpp/src/llama-vocab.cpp +32 -23
  139. package/cpp/llama.cpp/src/llama.cpp +11 -7
  140. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  141. package/cpp/rn-completion.cpp +2 -2
  142. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  143. package/ios/include/chat.h +1 -1
  144. package/ios/include/common.h +5 -2
  145. package/ios/include/llama.h +134 -36
  146. package/ios/libs/llama.xcframework/Info.plist +18 -18
  147. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  148. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  149. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +134 -36
  150. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  151. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  152. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  153. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  154. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  155. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  156. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  157. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  158. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3624
  159. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +134 -36
  160. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +134 -36
  161. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  162. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +134 -36
  163. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  164. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  165. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  166. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  167. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +134 -36
  168. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  169. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  170. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  171. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  172. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  173. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  174. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  175. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4725
  176. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +134 -36
  177. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  178. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  179. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4746
  180. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3652
  181. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  182. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  183. package/package.json +1 -2
  184. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  185. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  186. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  187. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
  188. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
  189. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  190. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  191. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  192. /package/cpp/{rn-utils.hpp → rn-utils.h} +0 -0
@@ -993,31 +993,61 @@ kernel void kernel_neg(
993
993
  dst[tpig] = -src0[tpig];
994
994
  }
995
995
 
996
+ template <bool norm>
996
997
  kernel void kernel_sum_rows(
998
+ constant ggml_metal_kargs_sum_rows & args,
997
999
  device const float * src0,
998
1000
  device float * dst,
999
- constant ggml_metal_kargs_sum_rows & args,
1000
- uint3 tpig[[thread_position_in_grid]]) {
1001
- int64_t i3 = tpig.z;
1002
- int64_t i2 = tpig.y;
1003
- int64_t i1 = tpig.x;
1001
+ threadgroup float * shmem_f32 [[threadgroup(0)]],
1002
+ uint3 tgpig[[threadgroup_position_in_grid]],
1003
+ ushort3 tpitg[[thread_position_in_threadgroup]],
1004
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
1005
+ ushort tiisg[[thread_index_in_simdgroup]],
1006
+ ushort3 ntg[[threads_per_threadgroup]]) {
1007
+ int64_t i3 = tgpig.z;
1008
+ int64_t i2 = tgpig.y;
1009
+ int64_t i1 = tgpig.x;
1004
1010
 
1005
1011
  if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
1006
1012
  return;
1007
1013
  }
1008
1014
 
1015
+ if (sgitg == 0) {
1016
+ shmem_f32[tiisg] = 0.0f;
1017
+ }
1018
+
1009
1019
  device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
1010
1020
  device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
1011
1021
 
1012
- float row_sum = 0;
1022
+ float sumf = 0;
1023
+
1024
+ for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
1025
+ sumf += src_row[i0];
1026
+ }
1027
+
1028
+ sumf = simd_sum(sumf);
1029
+
1030
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1013
1031
 
1014
- for (int64_t i0 = 0; i0 < args.ne00; i0++) {
1015
- row_sum += src_row[i0];
1032
+ if (tiisg == 0) {
1033
+ shmem_f32[sgitg] = sumf;
1016
1034
  }
1017
1035
 
1018
- dst_row[0] = row_sum;
1036
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1037
+
1038
+ sumf = shmem_f32[tiisg];
1039
+ sumf = simd_sum(sumf);
1040
+
1041
+ if (tpitg.x == 0) {
1042
+ dst_row[0] = norm ? sumf / args.ne00 : sumf;
1043
+ }
1019
1044
  }
1020
1045
 
1046
+ typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
1047
+
1048
+ template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
1049
+ template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
1050
+
1021
1051
  template<typename T>
1022
1052
  kernel void kernel_soft_max(
1023
1053
  device const char * src0,
@@ -3328,14 +3358,12 @@ kernel void kernel_flash_attn_ext(
3328
3358
  constexpr short NW = N_SIMDWIDTH;
3329
3359
  constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
3330
3360
 
3331
- const short TS = nsg*SH; // shared memory size per query in (s_t == float)
3332
- const short T = DK + 2*TS; // shared memory size per query in (half)
3361
+ const short TS = nsg*SH; // shared memory size per query in (s_t == float)
3362
+ const short T = 2*DK + 2*TS; // shared memory size per query in (half)
3333
3363
 
3334
- threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3335
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
3336
- threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation
3337
- threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
3338
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
3364
+ threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3365
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
3366
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix
3339
3367
 
3340
3368
  threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
3341
3369
  threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
@@ -3354,7 +3382,7 @@ kernel void kernel_flash_attn_ext(
3354
3382
  if (iq1 + j < args.ne01) {
3355
3383
  sq4[j*DK4 + i] = (q4_t) q4[i];
3356
3384
  } else {
3357
- sq4[j*DK4 + i] = (q4_t) 0.0f;
3385
+ sq4[j*DK4 + i] = 0;
3358
3386
  }
3359
3387
  }
3360
3388
  }
@@ -3548,20 +3576,20 @@ kernel void kernel_flash_attn_ext(
3548
3576
 
3549
3577
  // O = diag(ms)*O
3550
3578
  {
3551
- s8x8_t mm;
3552
- simdgroup_load(mm, ss + 2*C, TS, 0, false);
3579
+ s8x8_t ms;
3580
+ simdgroup_load(ms, ss + 2*C, TS, 0, false);
3553
3581
 
3554
3582
  #pragma unroll(DV8)
3555
3583
  for (short i = 0; i < DV8; ++i) {
3556
- simdgroup_multiply(lo[i], mm, lo[i]);
3584
+ simdgroup_multiply(lo[i], ms, lo[i]);
3557
3585
  }
3558
3586
  }
3559
3587
 
3560
3588
  // O = O + (Q*K^T)*V
3561
3589
  {
3562
3590
  for (short cc = 0; cc < C/8; ++cc) {
3563
- s8x8_t ms;
3564
- simdgroup_load(ms, ss + 8*cc, TS, 0, false);
3591
+ s8x8_t vs;
3592
+ simdgroup_load(vs, ss + 8*cc, TS, 0, false);
3565
3593
 
3566
3594
  if (is_same<vd4x4_t, v4x4_t>::value) {
3567
3595
  // we can read directly from global memory
@@ -3572,7 +3600,7 @@ kernel void kernel_flash_attn_ext(
3572
3600
  v8x8_t mv;
3573
3601
  simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
3574
3602
 
3575
- simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
3603
+ simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]);
3576
3604
  }
3577
3605
  } else {
3578
3606
  for (short ii = 0; ii < DV16; ii += 4) {
@@ -3593,10 +3621,10 @@ kernel void kernel_flash_attn_ext(
3593
3621
  v8x8_t mv;
3594
3622
 
3595
3623
  simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
3596
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
3624
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
3597
3625
 
3598
3626
  simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
3599
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
3627
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
3600
3628
  }
3601
3629
  } else {
3602
3630
  if (ii + tx < DV16) {
@@ -3611,10 +3639,10 @@ kernel void kernel_flash_attn_ext(
3611
3639
  v8x8_t mv;
3612
3640
 
3613
3641
  simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
3614
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
3642
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
3615
3643
 
3616
3644
  simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
3617
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
3645
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
3618
3646
  }
3619
3647
  }
3620
3648
  }
@@ -3624,93 +3652,89 @@ kernel void kernel_flash_attn_ext(
3624
3652
  }
3625
3653
 
3626
3654
  // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
3627
- for (short j = 0; j < Q; ++j) {
3628
- if (tiisg == 0) {
3629
- ss[j*TS + 0] = S[j];
3630
- ss[j*TS + 1] = M[j];
3631
- }
3655
+ for (short j = tiisg; j < Q; j += NW) {
3656
+ ss[j*TS + 0] = S[j];
3657
+ ss[j*TS + 1] = M[j];
3632
3658
  }
3633
3659
  }
3634
3660
 
3635
- // reduce the warps sequentially
3636
- for (ushort sg = 1; sg < nsg; ++sg) {
3637
- float S = { 0.0f };
3638
- float M = { -__FLT_MAX__/2 };
3661
+ threadgroup_barrier(mem_flags::mem_threadgroup);
3639
3662
 
3640
- threadgroup_barrier(mem_flags::mem_threadgroup);
3663
+ threadgroup float * so = (threadgroup float *) (shmem_f16 + 0*DK); // reuse query data for accumulation
3664
+ threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK);
3641
3665
 
3642
- // each simdgroup stores its output to shared memory, reusing sq
3643
- if (sgitg == sg) {
3644
- for (short i = 0; i < DV8; ++i) {
3645
- simdgroup_store(lo[i], so + i*8, DV, 0, false);
3646
- }
3666
+ // store result to shared memory in F32
3667
+ if (sgitg == 0) {
3668
+ for (short i = 0; i < DV8; ++i) {
3669
+ //simdgroup_store(lo[i], so + i*8, DV, 0, false);
3670
+ simdgroup_float8x8 t(1.0f);
3671
+ simdgroup_multiply(t, lo[i], t);
3672
+ simdgroup_store(t, so + i*8, DV, 0, false);
3647
3673
  }
3674
+ }
3648
3675
 
3649
- threadgroup_barrier(mem_flags::mem_threadgroup);
3676
+ threadgroup_barrier(mem_flags::mem_threadgroup);
3650
3677
 
3651
- // the first simdgroup accumulates the results from the other simdgroups
3652
- if (sgitg == 0) {
3653
- for (short j = 0; j < Q; ++j) {
3654
- const float S0 = ss[j*TS + 0];
3655
- const float S1 = ss[j*TS + sg*SH + 0];
3678
+ // reduce the warps sequentially
3679
+ for (ushort sg = 1; sg < nsg; ++sg) {
3680
+ if (sgitg == sg) {
3681
+ for (short j = tiisg; j < Q; j += NW) {
3682
+ const float S0 = ss[j*TS - 1*SH + 0];
3683
+ const float S1 = ss[j*TS + 0];
3656
3684
 
3657
- const float M0 = ss[j*TS + 1];
3658
- const float M1 = ss[j*TS + sg*SH + 1];
3685
+ const float M0 = ss[j*TS - 1*SH + 1];
3686
+ const float M1 = ss[j*TS + 1];
3659
3687
 
3660
- M = max(M0, M1);
3688
+ const float M = max(M0, M1);
3661
3689
 
3662
- const float ms0 = exp(M0 - M);
3663
- const float ms1 = exp(M1 - M);
3690
+ float ms0 = exp(M0 - M);
3691
+ float ms1 = exp(M1 - M);
3664
3692
 
3665
- S = S0*ms0 + S1*ms1;
3693
+ const float S = S0*ms0 + S1*ms1;
3666
3694
 
3667
- if (tiisg == 0) {
3668
- ss[j*TS + 0] = S;
3669
- ss[j*TS + 1] = M;
3695
+ ss[j*TS + 0] = S;
3696
+ ss[j*TS + 1] = M;
3670
3697
 
3671
- ss[j*TS + 2*C + j ] = ms0;
3672
- ss[j*TS + 2*C + j + sg*SH] = ms1;
3673
- }
3698
+ ss[j*TS + 2*C + j - 1*SH] = ms0;
3699
+ ss[j*TS + 2*C + j ] = ms1;
3674
3700
  }
3675
3701
 
3702
+ //simdgroup_barrier(mem_flags::mem_threadgroup);
3703
+
3676
3704
  // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
3677
3705
  {
3678
3706
  s8x8_t ms0;
3679
3707
  s8x8_t ms1;
3680
3708
 
3681
- simdgroup_load(ms0, ss + 2*C, TS, 0, false);
3682
- simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
3709
+ simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false);
3710
+ simdgroup_load(ms1, ss + 2*C, TS, 0, false);
3683
3711
 
3684
3712
  #pragma unroll(DV8)
3685
3713
  for (short i = 0; i < DV8; ++i) {
3686
- o8x8_t t;
3714
+ simdgroup_float8x8 t;
3687
3715
 
3688
3716
  simdgroup_load (t, so + i*8, DV, 0, false);
3689
- simdgroup_multiply(t, ms1, t);
3717
+ simdgroup_multiply(t, ms0, t);
3690
3718
 
3691
- simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
3719
+ simdgroup_multiply_accumulate(t, ms1, lo[i], t);
3720
+ simdgroup_store(t, so + i*8, DV, 0, false);
3692
3721
  }
3693
3722
  }
3694
3723
  }
3695
- }
3696
3724
 
3697
- // store result to shared memory (reuse sq)
3698
- if (sgitg == 0) {
3699
- for (short i = 0; i < DV8; ++i) {
3700
- simdgroup_store(lo[i], so + i*8, DV, 0, false);
3701
- }
3725
+ threadgroup_barrier(mem_flags::mem_threadgroup);
3702
3726
  }
3703
3727
 
3704
- device float4 * dst4 = (device float4 *) dst;
3728
+ threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK);
3705
3729
 
3706
3730
  // final rescale with 1/S and store to global memory
3707
- if (sgitg == 0) {
3708
- for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
3709
- const float S = ss[j*TS + 0];
3731
+ for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
3732
+ const float S = 1.0f/sf[j*TS + 0];
3710
3733
 
3711
- for (short i = tiisg; i < DV4; i += NW) {
3712
- dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
3713
- }
3734
+ device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
3735
+
3736
+ for (short i = tiisg; i < DV4; i += NW) {
3737
+ dst4[i] = (float4) so4[j*DV4 + i]*S;
3714
3738
  }
3715
3739
  }
3716
3740
  }
@@ -3719,12 +3743,22 @@ kernel void kernel_flash_attn_ext(
3719
3743
  // template to be able to explore different combinations
3720
3744
  //
3721
3745
  #define FA_TYPES \
3722
- half, half4, simdgroup_half8x8, \
3723
- half, half4x4, simdgroup_half8x8, \
3724
- half, half4x4, simdgroup_half8x8, \
3725
- float, simdgroup_float8x8, \
3726
- float, simdgroup_float8x8, \
3727
- half, half4, simdgroup_half8x8
3746
+ float, float4, simdgroup_float8x8, \
3747
+ half, half4x4, simdgroup_half8x8, \
3748
+ half, half4x4, simdgroup_half8x8, \
3749
+ float, simdgroup_float8x8, \
3750
+ float, simdgroup_float8x8, \
3751
+ half, half4, simdgroup_half8x8
3752
+ //float, float4, simdgroup_float8x8
3753
+
3754
+ #define FA_TYPES_BF \
3755
+ bfloat, bfloat4, simdgroup_bfloat8x8, \
3756
+ bfloat, bfloat4x4, simdgroup_bfloat8x8, \
3757
+ bfloat, bfloat4x4, simdgroup_bfloat8x8, \
3758
+ float, simdgroup_float8x8, \
3759
+ float, simdgroup_float8x8, \
3760
+ half, half4, simdgroup_half8x8
3761
+ //float, float4, simdgroup_float8x8
3728
3762
 
3729
3763
  typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
3730
3764
 
@@ -3739,15 +3773,15 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
3739
3773
  template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
3740
3774
 
3741
3775
  #if defined(GGML_METAL_USE_BF16)
3742
- template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
3743
- template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
3744
- template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
3745
- template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
3746
- template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
3747
- template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
3748
- template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
3749
- template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
3750
- template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
3776
+ template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
3777
+ template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
3778
+ template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
3779
+ template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
3780
+ template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
3781
+ template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
3782
+ template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
3783
+ template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
3784
+ template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
3751
3785
  #endif
3752
3786
 
3753
3787
  template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
@@ -3801,6 +3835,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_at
3801
3835
  template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
3802
3836
 
3803
3837
  #undef FA_TYPES
3838
+ #undef FA_TYPES_BF
3804
3839
 
3805
3840
  template<
3806
3841
  typename q4_t, // query types in shared memory
@@ -3847,12 +3882,12 @@ kernel void kernel_flash_attn_ext_vec(
3847
3882
 
3848
3883
  const short T = DK + nsg*SH; // shared memory size per query in (half)
3849
3884
 
3850
- //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3851
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
3852
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
3853
- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
3854
- threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
3855
- threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
3885
+ //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3886
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
3887
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
3888
+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
3889
+ threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
3890
+ threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results
3856
3891
 
3857
3892
  // store the result for all queries in local memory (the O matrix from the paper)
3858
3893
  o4_t lo[DV4/NL];
@@ -4157,7 +4192,7 @@ kernel void kernel_flash_attn_ext_vec(
4157
4192
  half4, \
4158
4193
  float, \
4159
4194
  float, float4, \
4160
- half4
4195
+ float4
4161
4196
 
4162
4197
  typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
4163
4198
 
@@ -80,6 +80,7 @@ set(GGML_OPENCL_KERNELS
80
80
  mul_mv_q4_0_f32_1d_8x_flat
81
81
  mul_mv_q4_0_f32_1d_16x_flat
82
82
  mul_mv_q6_k
83
+ mul_mv_id_q4_0_f32_8x_flat
83
84
  mul
84
85
  norm
85
86
  relu
@@ -95,6 +96,12 @@ set(GGML_OPENCL_KERNELS
95
96
  sub
96
97
  sum_rows
97
98
  transpose
99
+ concat
100
+ tsembd
101
+ upscale
102
+ tanh
103
+ pad
104
+ repeat
98
105
  )
99
106
 
100
107
  foreach (K ${GGML_OPENCL_KERNELS})