@novastera-oss/llamarn 0.2.9 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247) hide show
  1. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  5. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  9. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  12. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  17. package/cpp/build-info.cpp +2 -2
  18. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  19. package/cpp/llama.cpp/README.md +4 -5
  20. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  21. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  22. package/cpp/llama.cpp/common/arg.cpp +17 -0
  23. package/cpp/llama.cpp/common/chat.cpp +37 -20
  24. package/cpp/llama.cpp/common/chat.h +2 -0
  25. package/cpp/llama.cpp/common/common.h +4 -0
  26. package/cpp/llama.cpp/convert_hf_to_gguf.py +745 -6
  27. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  28. package/cpp/llama.cpp/ggml/CMakeLists.txt +7 -2
  29. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  30. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  31. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +0 -1
  32. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  33. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  34. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  35. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1203 -163
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +33 -9
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  43. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +17 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  47. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +8 -6
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +185 -79
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  66. package/cpp/llama.cpp/ggml/src/ggml-impl.h +64 -0
  67. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  68. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +35 -9
  69. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +167 -39
  70. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +254 -57
  71. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +505 -40
  73. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  83. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  84. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +60 -9
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +711 -292
  92. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  93. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  94. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  95. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  105. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  106. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  112. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  113. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  114. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  116. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  117. package/cpp/llama.cpp/ggml/src/ggml.c +382 -61
  118. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  119. package/cpp/llama.cpp/gguf-py/gguf/constants.py +209 -0
  120. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  121. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +73 -21
  122. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  123. package/cpp/llama.cpp/include/llama.h +0 -40
  124. package/cpp/llama.cpp/src/llama-arch.cpp +210 -3
  125. package/cpp/llama.cpp/src/llama-arch.h +18 -1
  126. package/cpp/llama.cpp/src/llama-batch.cpp +27 -1
  127. package/cpp/llama.cpp/src/llama-batch.h +8 -1
  128. package/cpp/llama.cpp/src/llama-chat.cpp +15 -0
  129. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  130. package/cpp/llama.cpp/src/llama-graph.cpp +119 -184
  131. package/cpp/llama.cpp/src/llama-graph.h +47 -60
  132. package/cpp/llama.cpp/src/llama-hparams.cpp +7 -1
  133. package/cpp/llama.cpp/src/llama-hparams.h +3 -0
  134. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +28 -18
  135. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +4 -2
  136. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +214 -65
  137. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +62 -24
  138. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  139. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +9 -4
  140. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  141. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +20 -10
  142. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  143. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  144. package/cpp/llama.cpp/src/llama-model.cpp +2530 -685
  145. package/cpp/llama.cpp/src/llama-model.h +18 -0
  146. package/cpp/llama.cpp/src/llama-quant.cpp +1 -0
  147. package/cpp/llama.cpp/src/llama-vocab.cpp +13 -2
  148. package/cpp/llama.cpp/src/llama-vocab.h +41 -0
  149. package/ios/include/chat.h +2 -0
  150. package/ios/include/common.h +4 -0
  151. package/ios/include/llama.h +0 -40
  152. package/ios/libs/llama.xcframework/Info.plist +19 -19
  153. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  154. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5055 -4886
  155. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  156. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  157. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +0 -40
  158. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  159. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  160. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  161. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
  162. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  163. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  164. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  165. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  166. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  167. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  168. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3766
  169. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  170. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  171. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +0 -40
  172. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  173. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  174. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +0 -40
  175. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  176. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  177. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  178. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +0 -40
  179. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  180. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  181. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  182. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4890
  183. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  184. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  185. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +0 -40
  186. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  187. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  188. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  189. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
  190. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  191. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  192. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  193. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  194. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  195. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5091 -4922
  196. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  197. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  198. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +0 -40
  199. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  200. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  201. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4897
  202. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3794
  203. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  204. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  205. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  206. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  207. package/package.json +1 -1
  208. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  209. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  210. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  211. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  212. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  213. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  214. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  215. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  216. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  217. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  218. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  219. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  220. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  221. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  222. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  223. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  224. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  225. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  226. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  227. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  228. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  229. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  230. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  231. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  232. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  233. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  234. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  235. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  236. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  237. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  238. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  239. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  240. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  241. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  242. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  243. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  244. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  245. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  246. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  247. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -3,6 +3,7 @@
3
3
  #include "ggml-cpu.h"
4
4
  #include "ggml-impl.h"
5
5
  #include "binary-ops.h"
6
+ #include "ggml.h"
6
7
  #include "unary-ops.h"
7
8
  #include "vec.h"
8
9
 
@@ -3184,6 +3185,721 @@ void ggml_compute_forward_silu_back(
3184
3185
  }
3185
3186
  }
3186
3187
 
3188
+ // ggml_compute_forward_reglu
3189
+
3190
+ static void ggml_compute_forward_reglu_f32(
3191
+ const ggml_compute_params * params,
3192
+ ggml_tensor * dst) {
3193
+
3194
+ const ggml_tensor * src0 = dst->src[0];
3195
+ const ggml_tensor * src1 = dst->src[1];
3196
+ char * src0_d = (char *) src0->data;
3197
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3198
+ const size_t src0_o = src0->nb[1];
3199
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3200
+
3201
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3202
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3203
+
3204
+ if (src1) {
3205
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3206
+ GGML_ASSERT(src0->type == src1->type);
3207
+ }
3208
+
3209
+ const int ith = params->ith;
3210
+ const int nth = params->nth;
3211
+
3212
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3213
+ const int nr = ggml_nrows(src0);
3214
+
3215
+ GGML_ASSERT(dst->ne[0] == nc);
3216
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3217
+
3218
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3219
+
3220
+ // rows per thread
3221
+ const int dr = (nr + nth - 1)/nth;
3222
+
3223
+ // row range for this thread
3224
+ const int ir0 = dr*ith;
3225
+ const int ir1 = MIN(ir0 + dr, nr);
3226
+
3227
+ for (int i1 = ir0; i1 < ir1; i1++) {
3228
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3229
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3230
+
3231
+ if (!src1) {
3232
+ src0_p += swapped ? nc : 0;
3233
+ src1_p += swapped ? 0 : nc;
3234
+ }
3235
+
3236
+ ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3237
+
3238
+ #ifndef NDEBUG
3239
+ for (int k = 0; k < nc; k++) {
3240
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3241
+ GGML_UNUSED(x);
3242
+ assert(!isnan(x));
3243
+ assert(!isinf(x));
3244
+ }
3245
+ #endif
3246
+ }
3247
+ }
3248
+
3249
+ static void ggml_compute_forward_reglu_f16(
3250
+ const ggml_compute_params * params,
3251
+ ggml_tensor * dst) {
3252
+
3253
+ const ggml_tensor * src0 = dst->src[0];
3254
+ const ggml_tensor * src1 = dst->src[1];
3255
+ char * src0_d = (char *) src0->data;
3256
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3257
+ const size_t src0_o = src0->nb[1];
3258
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3259
+
3260
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3261
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3262
+
3263
+ if (src1) {
3264
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3265
+ GGML_ASSERT(src0->type == src1->type);
3266
+ }
3267
+
3268
+ const int ith = params->ith;
3269
+ const int nth = params->nth;
3270
+
3271
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3272
+ const int nr = ggml_nrows(src0);
3273
+
3274
+ GGML_ASSERT(dst->ne[0] == nc);
3275
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3276
+
3277
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3278
+
3279
+ // rows per thread
3280
+ const int dr = (nr + nth - 1)/nth;
3281
+
3282
+ // row range for this thread
3283
+ const int ir0 = dr*ith;
3284
+ const int ir1 = MIN(ir0 + dr, nr);
3285
+
3286
+ for (int i1 = ir0; i1 < ir1; i1++) {
3287
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3288
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3289
+
3290
+ if (!src1) {
3291
+ src0_p += swapped ? nc : 0;
3292
+ src1_p += swapped ? 0 : nc;
3293
+ }
3294
+
3295
+ ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3296
+
3297
+ #ifndef NDEBUG
3298
+ for (int k = 0; k < nc; k++) {
3299
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3300
+ const float v = GGML_FP16_TO_FP32(x);
3301
+ GGML_UNUSED(v);
3302
+ assert(!isnan(v));
3303
+ assert(!isinf(v));
3304
+ }
3305
+ #endif
3306
+ }
3307
+ }
3308
+
3309
+ static void ggml_compute_forward_reglu(
3310
+ const ggml_compute_params * params,
3311
+ ggml_tensor * dst) {
3312
+
3313
+ const ggml_tensor * src0 = dst->src[0];
3314
+
3315
+ switch (src0->type) {
3316
+ case GGML_TYPE_F32:
3317
+ {
3318
+ ggml_compute_forward_reglu_f32(params, dst);
3319
+ } break;
3320
+ case GGML_TYPE_F16:
3321
+ {
3322
+ ggml_compute_forward_reglu_f16(params, dst);
3323
+ } break;
3324
+ default:
3325
+ {
3326
+ GGML_ABORT("fatal error");
3327
+ }
3328
+ }
3329
+ }
3330
+
3331
+ // ggml_compute_forward_geglu
3332
+
3333
+ static void ggml_compute_forward_geglu_f32(
3334
+ const ggml_compute_params * params,
3335
+ ggml_tensor * dst) {
3336
+
3337
+ const ggml_tensor * src0 = dst->src[0];
3338
+ const ggml_tensor * src1 = dst->src[1];
3339
+ char * src0_d = (char *) src0->data;
3340
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3341
+ const size_t src0_o = src0->nb[1];
3342
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3343
+
3344
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3345
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3346
+
3347
+ if (src1) {
3348
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3349
+ GGML_ASSERT(src0->type == src1->type);
3350
+ }
3351
+
3352
+ const int ith = params->ith;
3353
+ const int nth = params->nth;
3354
+
3355
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3356
+ const int nr = ggml_nrows(src0);
3357
+
3358
+ GGML_ASSERT(dst->ne[0] == nc);
3359
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3360
+
3361
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3362
+
3363
+ // rows per thread
3364
+ const int dr = (nr + nth - 1)/nth;
3365
+
3366
+ // row range for this thread
3367
+ const int ir0 = dr*ith;
3368
+ const int ir1 = MIN(ir0 + dr, nr);
3369
+
3370
+ for (int i1 = ir0; i1 < ir1; i1++) {
3371
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3372
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3373
+
3374
+ if (!src1) {
3375
+ src0_p += swapped ? nc : 0;
3376
+ src1_p += swapped ? 0 : nc;
3377
+ }
3378
+
3379
+ ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3380
+
3381
+ #ifndef NDEBUG
3382
+ for (int k = 0; k < nc; k++) {
3383
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3384
+ GGML_UNUSED(x);
3385
+ assert(!isnan(x));
3386
+ assert(!isinf(x));
3387
+ }
3388
+ #endif
3389
+ }
3390
+ }
3391
+
3392
+ static void ggml_compute_forward_geglu_f16(
3393
+ const ggml_compute_params * params,
3394
+ ggml_tensor * dst) {
3395
+
3396
+ const ggml_tensor * src0 = dst->src[0];
3397
+ const ggml_tensor * src1 = dst->src[1];
3398
+ char * src0_d = (char *) src0->data;
3399
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3400
+ const size_t src0_o = src0->nb[1];
3401
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3402
+
3403
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3404
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3405
+
3406
+ if (src1) {
3407
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3408
+ GGML_ASSERT(src0->type == src1->type);
3409
+ }
3410
+
3411
+ const int ith = params->ith;
3412
+ const int nth = params->nth;
3413
+
3414
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3415
+ const int nr = ggml_nrows(src0);
3416
+
3417
+ GGML_ASSERT(dst->ne[0] == nc);
3418
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3419
+
3420
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3421
+
3422
+ // rows per thread
3423
+ const int dr = (nr + nth - 1)/nth;
3424
+
3425
+ // row range for this thread
3426
+ const int ir0 = dr*ith;
3427
+ const int ir1 = MIN(ir0 + dr, nr);
3428
+
3429
+ for (int i1 = ir0; i1 < ir1; i1++) {
3430
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3431
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3432
+
3433
+ if (!src1) {
3434
+ src0_p += swapped ? nc : 0;
3435
+ src1_p += swapped ? 0 : nc;
3436
+ }
3437
+
3438
+ ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3439
+
3440
+ #ifndef NDEBUG
3441
+ for (int k = 0; k < nc; k++) {
3442
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3443
+ const float v = GGML_FP16_TO_FP32(x);
3444
+ GGML_UNUSED(v);
3445
+ assert(!isnan(v));
3446
+ assert(!isinf(v));
3447
+ }
3448
+ #endif
3449
+ }
3450
+ }
3451
+
3452
+ static void ggml_compute_forward_geglu(
3453
+ const ggml_compute_params * params,
3454
+ ggml_tensor * dst) {
3455
+
3456
+ const ggml_tensor * src0 = dst->src[0];
3457
+
3458
+ switch (src0->type) {
3459
+ case GGML_TYPE_F32:
3460
+ {
3461
+ ggml_compute_forward_geglu_f32(params, dst);
3462
+ } break;
3463
+ case GGML_TYPE_F16:
3464
+ {
3465
+ ggml_compute_forward_geglu_f16(params, dst);
3466
+ } break;
3467
+ default:
3468
+ {
3469
+ GGML_ABORT("fatal error");
3470
+ }
3471
+ }
3472
+ }
3473
+
3474
+ // ggml_compute_forward_swiglu
3475
+
3476
+ static void ggml_compute_forward_swiglu_f32(
3477
+ const ggml_compute_params * params,
3478
+ ggml_tensor * dst) {
3479
+
3480
+ const ggml_tensor * src0 = dst->src[0];
3481
+ const ggml_tensor * src1 = dst->src[1];
3482
+ char * src0_d = (char *) src0->data;
3483
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3484
+ const size_t src0_o = src0->nb[1];
3485
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3486
+
3487
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3488
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3489
+
3490
+ if (src1) {
3491
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3492
+ GGML_ASSERT(src0->type == src1->type);
3493
+ }
3494
+
3495
+ const int ith = params->ith;
3496
+ const int nth = params->nth;
3497
+
3498
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3499
+ const int nr = ggml_nrows(src0);
3500
+
3501
+ GGML_ASSERT(dst->ne[0] == nc);
3502
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3503
+
3504
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3505
+
3506
+ // rows per thread
3507
+ const int dr = (nr + nth - 1)/nth;
3508
+
3509
+ // row range for this thread
3510
+ const int ir0 = dr*ith;
3511
+ const int ir1 = MIN(ir0 + dr, nr);
3512
+
3513
+ for (int i1 = ir0; i1 < ir1; i1++) {
3514
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3515
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3516
+
3517
+ if (!src1) {
3518
+ src0_p += swapped ? nc : 0;
3519
+ src1_p += swapped ? 0 : nc;
3520
+ }
3521
+
3522
+ ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3523
+
3524
+ #ifndef NDEBUG
3525
+ for (int k = 0; k < nc; k++) {
3526
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3527
+ GGML_UNUSED(x);
3528
+ assert(!isnan(x));
3529
+ assert(!isinf(x));
3530
+ }
3531
+ #endif
3532
+ }
3533
+ }
3534
+
3535
+ static void ggml_compute_forward_swiglu_f16(
3536
+ const ggml_compute_params * params,
3537
+ ggml_tensor * dst) {
3538
+
3539
+ const ggml_tensor * src0 = dst->src[0];
3540
+ const ggml_tensor * src1 = dst->src[1];
3541
+ char * src0_d = (char *) src0->data;
3542
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3543
+ const size_t src0_o = src0->nb[1];
3544
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3545
+
3546
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3547
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3548
+
3549
+ if (src1) {
3550
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3551
+ GGML_ASSERT(src0->type == src1->type);
3552
+ }
3553
+
3554
+ const int ith = params->ith;
3555
+ const int nth = params->nth;
3556
+
3557
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3558
+ const int nr = ggml_nrows(src0);
3559
+
3560
+ GGML_ASSERT(dst->ne[0] == nc);
3561
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3562
+
3563
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3564
+
3565
+ // rows per thread
3566
+ const int dr = (nr + nth - 1)/nth;
3567
+
3568
+ // row range for this thread
3569
+ const int ir0 = dr*ith;
3570
+ const int ir1 = MIN(ir0 + dr, nr);
3571
+
3572
+ for (int i1 = ir0; i1 < ir1; i1++) {
3573
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3574
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3575
+
3576
+ if (!src1) {
3577
+ src0_p += swapped ? nc : 0;
3578
+ src1_p += swapped ? 0 : nc;
3579
+ }
3580
+
3581
+ ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3582
+
3583
+ #ifndef NDEBUG
3584
+ for (int k = 0; k < nc; k++) {
3585
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3586
+ const float v = GGML_FP16_TO_FP32(x);
3587
+ GGML_UNUSED(v);
3588
+ assert(!isnan(v));
3589
+ assert(!isinf(v));
3590
+ }
3591
+ #endif
3592
+ }
3593
+ }
3594
+
3595
+ static void ggml_compute_forward_swiglu(
3596
+ const ggml_compute_params * params,
3597
+ ggml_tensor * dst) {
3598
+
3599
+ const ggml_tensor * src0 = dst->src[0];
3600
+
3601
+ switch (src0->type) {
3602
+ case GGML_TYPE_F32:
3603
+ {
3604
+ ggml_compute_forward_swiglu_f32(params, dst);
3605
+ } break;
3606
+ case GGML_TYPE_F16:
3607
+ {
3608
+ ggml_compute_forward_swiglu_f16(params, dst);
3609
+ } break;
3610
+ default:
3611
+ {
3612
+ GGML_ABORT("fatal error");
3613
+ }
3614
+ }
3615
+ }
3616
+
3617
+ // ggml_compute_forward_geglu_erf
3618
+
3619
+ static void ggml_compute_forward_geglu_erf_f32(
3620
+ const ggml_compute_params * params,
3621
+ ggml_tensor * dst) {
3622
+
3623
+ const ggml_tensor * src0 = dst->src[0];
3624
+ const ggml_tensor * src1 = dst->src[1];
3625
+ char * src0_d = (char *) src0->data;
3626
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3627
+ const size_t src0_o = src0->nb[1];
3628
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3629
+
3630
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3631
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3632
+
3633
+ if (src1) {
3634
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3635
+ GGML_ASSERT(src0->type == src1->type);
3636
+ }
3637
+
3638
+ const int ith = params->ith;
3639
+ const int nth = params->nth;
3640
+
3641
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3642
+ const int nr = ggml_nrows(src0);
3643
+
3644
+ GGML_ASSERT(dst->ne[0] == nc);
3645
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3646
+
3647
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3648
+
3649
+ // rows per thread
3650
+ const int dr = (nr + nth - 1)/nth;
3651
+
3652
+ // row range for this thread
3653
+ const int ir0 = dr*ith;
3654
+ const int ir1 = MIN(ir0 + dr, nr);
3655
+
3656
+ for (int i1 = ir0; i1 < ir1; i1++) {
3657
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3658
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3659
+
3660
+ if (!src1) {
3661
+ src0_p += swapped ? nc : 0;
3662
+ src1_p += swapped ? 0 : nc;
3663
+ }
3664
+
3665
+ ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3666
+
3667
+ #ifndef NDEBUG
3668
+ for (int k = 0; k < nc; k++) {
3669
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3670
+ GGML_UNUSED(x);
3671
+ assert(!isnan(x));
3672
+ assert(!isinf(x));
3673
+ }
3674
+ #endif
3675
+ }
3676
+ }
3677
+
3678
+ static void ggml_compute_forward_geglu_erf_f16(
3679
+ const ggml_compute_params * params,
3680
+ ggml_tensor * dst) {
3681
+
3682
+ const ggml_tensor * src0 = dst->src[0];
3683
+ const ggml_tensor * src1 = dst->src[1];
3684
+ char * src0_d = (char *) src0->data;
3685
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3686
+ const size_t src0_o = src0->nb[1];
3687
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3688
+
3689
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3690
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3691
+
3692
+ if (src1) {
3693
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3694
+ GGML_ASSERT(src0->type == src1->type);
3695
+ }
3696
+
3697
+ const int ith = params->ith;
3698
+ const int nth = params->nth;
3699
+
3700
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3701
+ const int nr = ggml_nrows(src0);
3702
+
3703
+ GGML_ASSERT(dst->ne[0] == nc);
3704
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3705
+
3706
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3707
+
3708
+ // rows per thread
3709
+ const int dr = (nr + nth - 1)/nth;
3710
+
3711
+ // row range for this thread
3712
+ const int ir0 = dr*ith;
3713
+ const int ir1 = MIN(ir0 + dr, nr);
3714
+
3715
+ for (int i1 = ir0; i1 < ir1; i1++) {
3716
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3717
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3718
+
3719
+ if (!src1) {
3720
+ src0_p += swapped ? nc : 0;
3721
+ src1_p += swapped ? 0 : nc;
3722
+ }
3723
+
3724
+ ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3725
+
3726
+ #ifndef NDEBUG
3727
+ for (int k = 0; k < nc; k++) {
3728
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3729
+ const float v = GGML_FP16_TO_FP32(x);
3730
+ GGML_UNUSED(v);
3731
+ assert(!isnan(v));
3732
+ assert(!isinf(v));
3733
+ }
3734
+ #endif
3735
+ }
3736
+ }
3737
+
3738
+ static void ggml_compute_forward_geglu_erf(
3739
+ const ggml_compute_params * params,
3740
+ ggml_tensor * dst) {
3741
+
3742
+ const ggml_tensor * src0 = dst->src[0];
3743
+
3744
+ switch (src0->type) {
3745
+ case GGML_TYPE_F32:
3746
+ {
3747
+ ggml_compute_forward_geglu_erf_f32(params, dst);
3748
+ } break;
3749
+ case GGML_TYPE_F16:
3750
+ {
3751
+ ggml_compute_forward_geglu_erf_f16(params, dst);
3752
+ } break;
3753
+ default:
3754
+ {
3755
+ GGML_ABORT("fatal error");
3756
+ }
3757
+ }
3758
+ }
3759
+
3760
+ // ggml_compute_forward_geglu_quick
3761
+
3762
+ static void ggml_compute_forward_geglu_quick_f32(
3763
+ const ggml_compute_params * params,
3764
+ ggml_tensor * dst) {
3765
+
3766
+ const ggml_tensor * src0 = dst->src[0];
3767
+ const ggml_tensor * src1 = dst->src[1];
3768
+ char * src0_d = (char *) src0->data;
3769
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3770
+ const size_t src0_o = src0->nb[1];
3771
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3772
+
3773
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3774
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3775
+
3776
+ if (src1) {
3777
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3778
+ GGML_ASSERT(src0->type == src1->type);
3779
+ }
3780
+
3781
+ const int ith = params->ith;
3782
+ const int nth = params->nth;
3783
+
3784
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3785
+ const int nr = ggml_nrows(src0);
3786
+
3787
+ GGML_ASSERT(dst->ne[0] == nc);
3788
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3789
+
3790
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3791
+
3792
+ // rows per thread
3793
+ const int dr = (nr + nth - 1)/nth;
3794
+
3795
+ // row range for this thread
3796
+ const int ir0 = dr*ith;
3797
+ const int ir1 = MIN(ir0 + dr, nr);
3798
+
3799
+ for (int i1 = ir0; i1 < ir1; i1++) {
3800
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3801
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3802
+
3803
+ if (!src1) {
3804
+ src0_p += swapped ? nc : 0;
3805
+ src1_p += swapped ? 0 : nc;
3806
+ }
3807
+
3808
+ ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3809
+
3810
+ #ifndef NDEBUG
3811
+ for (int k = 0; k < nc; k++) {
3812
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3813
+ GGML_UNUSED(x);
3814
+ assert(!isnan(x));
3815
+ assert(!isinf(x));
3816
+ }
3817
+ #endif
3818
+ }
3819
+ }
3820
+
3821
+ static void ggml_compute_forward_geglu_quick_f16(
3822
+ const ggml_compute_params * params,
3823
+ ggml_tensor * dst) {
3824
+
3825
+ const ggml_tensor * src0 = dst->src[0];
3826
+ const ggml_tensor * src1 = dst->src[1];
3827
+ char * src0_d = (char *) src0->data;
3828
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3829
+ const size_t src0_o = src0->nb[1];
3830
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3831
+
3832
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3833
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3834
+
3835
+ if (src1) {
3836
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3837
+ GGML_ASSERT(src0->type == src1->type);
3838
+ }
3839
+
3840
+ const int ith = params->ith;
3841
+ const int nth = params->nth;
3842
+
3843
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3844
+ const int nr = ggml_nrows(src0);
3845
+
3846
+ GGML_ASSERT(dst->ne[0] == nc);
3847
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3848
+
3849
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3850
+
3851
+ // rows per thread
3852
+ const int dr = (nr + nth - 1)/nth;
3853
+
3854
+ // row range for this thread
3855
+ const int ir0 = dr*ith;
3856
+ const int ir1 = MIN(ir0 + dr, nr);
3857
+
3858
+ for (int i1 = ir0; i1 < ir1; i1++) {
3859
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3860
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3861
+
3862
+ if (!src1) {
3863
+ src0_p += swapped ? nc : 0;
3864
+ src1_p += swapped ? 0 : nc;
3865
+ }
3866
+
3867
+ ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3868
+
3869
+ #ifndef NDEBUG
3870
+ for (int k = 0; k < nc; k++) {
3871
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3872
+ const float v = GGML_FP16_TO_FP32(x);
3873
+ GGML_UNUSED(v);
3874
+ assert(!isnan(v));
3875
+ assert(!isinf(v));
3876
+ }
3877
+ #endif
3878
+ }
3879
+ }
3880
+
3881
+ static void ggml_compute_forward_geglu_quick(
3882
+ const ggml_compute_params * params,
3883
+ ggml_tensor * dst) {
3884
+
3885
+ const ggml_tensor * src0 = dst->src[0];
3886
+
3887
+ switch (src0->type) {
3888
+ case GGML_TYPE_F32:
3889
+ {
3890
+ ggml_compute_forward_geglu_quick_f32(params, dst);
3891
+ } break;
3892
+ case GGML_TYPE_F16:
3893
+ {
3894
+ ggml_compute_forward_geglu_quick_f16(params, dst);
3895
+ } break;
3896
+ default:
3897
+ {
3898
+ GGML_ABORT("fatal error");
3899
+ }
3900
+ }
3901
+ }
3902
+
3187
3903
  // ggml_compute_forward_norm
3188
3904
 
3189
3905
  static void ggml_compute_forward_norm_f32(
@@ -3927,9 +4643,11 @@ static void ggml_compute_forward_scale_f32(
3927
4643
  GGML_ASSERT(ggml_is_contiguous(dst));
3928
4644
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
3929
4645
 
3930
- // scale factor
3931
- float v;
3932
- memcpy(&v, dst->op_params, sizeof(float));
4646
+ float s; // scale factor
4647
+ float b; // bias
4648
+
4649
+ memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
4650
+ memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
3933
4651
 
3934
4652
  const int ith = params->ith;
3935
4653
  const int nth = params->nth;
@@ -3948,12 +4666,22 @@ static void ggml_compute_forward_scale_f32(
3948
4666
 
3949
4667
  const size_t nb1 = dst->nb[1];
3950
4668
 
3951
- for (int i1 = ir0; i1 < ir1; i1++) {
3952
- if (dst->data != src0->data) {
3953
- // src0 is same shape as dst => same indices
3954
- memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
4669
+ if (b == 0.0f) {
4670
+ for (int i1 = ir0; i1 < ir1; i1++) {
4671
+ if (dst->data != src0->data) {
4672
+ // src0 is same shape as dst => same indices
4673
+ // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
4674
+ memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
4675
+ }
4676
+ ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
4677
+ }
4678
+ } else {
4679
+ for (int i1 = ir0; i1 < ir1; i1++) {
4680
+ ggml_vec_mad1_f32(nc,
4681
+ (float *) ((char *) dst->data + i1*nb1),
4682
+ (float *) ((char *) src0->data + i1*nb1),
4683
+ s, b);
3955
4684
  }
3956
- ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
3957
4685
  }
3958
4686
  }
3959
4687
 
@@ -4802,14 +5530,17 @@ static void ggml_compute_forward_soft_max_f32(
4802
5530
  memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
4803
5531
  memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
4804
5532
 
4805
- // TODO: handle transposed/permuted matrices
4806
-
4807
5533
  const int ith = params->ith;
4808
5534
  const int nth = params->nth;
4809
5535
 
4810
5536
  GGML_TENSOR_UNARY_OP_LOCALS
4811
5537
 
4812
- //const int64_t ne11 = src1 ? src1->ne[1] : 1;
5538
+ const int64_t nb11 = src1 ? src1->nb[1] : 1;
5539
+ const int64_t nb12 = src1 ? src1->nb[2] : 1;
5540
+ const int64_t nb13 = src1 ? src1->nb[3] : 1;
5541
+
5542
+ const int64_t ne12 = src1 ? src1->ne[2] : 1;
5543
+ const int64_t ne13 = src1 ? src1->ne[3] : 1;
4813
5544
 
4814
5545
  // TODO: is this supposed to be ceil instead of floor?
4815
5546
  // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -4819,68 +5550,66 @@ static void ggml_compute_forward_soft_max_f32(
4819
5550
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
4820
5551
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
4821
5552
 
4822
- const int nc = src0->ne[0];
4823
- const int nr = ggml_nrows(src0);
4824
-
4825
- // rows per thread
4826
- const int dr = (nr + nth - 1)/nth;
4827
-
4828
- // row range for this thread
4829
- const int ir0 = dr*ith;
4830
- const int ir1 = MIN(ir0 + dr, nr);
4831
-
4832
- float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
5553
+ float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
4833
5554
 
4834
5555
  const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
4835
5556
 
4836
- for (int i1 = ir0; i1 < ir1; i1++) {
4837
- // ALiBi
4838
- const uint32_t h = (i1/ne01)%ne02; // head
4839
- const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
4840
-
4841
- float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
4842
- float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
4843
-
4844
- // broadcast the mask across rows
4845
- ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
4846
- float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
4847
-
4848
- ggml_vec_cpy_f32 (nc, wp, sp);
4849
- ggml_vec_scale_f32(nc, wp, scale);
4850
- if (mp_f32) {
4851
- if (use_f16) {
4852
- for (int i = 0; i < nc; ++i) {
4853
- wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
4854
- }
4855
- } else {
4856
- for (int i = 0; i < nc; ++i) {
4857
- wp[i] += slope*mp_f32[i];
5557
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5558
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5559
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
5560
+ const int64_t i11 = i01;
5561
+ const int64_t i12 = i02%ne12;
5562
+ const int64_t i13 = i03%ne13;
5563
+
5564
+ // ALiBi
5565
+ const uint32_t h = i02; // head
5566
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
5567
+
5568
+ float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5569
+ float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
5570
+
5571
+ // broadcast the mask across rows
5572
+ ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
5573
+ float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
5574
+
5575
+ ggml_vec_cpy_f32 (ne00, wp, sp);
5576
+ ggml_vec_scale_f32(ne00, wp, scale);
5577
+ if (mp_f32) {
5578
+ if (use_f16) {
5579
+ for (int i = 0; i < ne00; ++i) {
5580
+ wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
5581
+ }
5582
+ } else {
5583
+ for (int i = 0; i < ne00; ++i) {
5584
+ wp[i] += slope*mp_f32[i];
5585
+ }
5586
+ }
4858
5587
  }
4859
- }
4860
- }
4861
5588
 
4862
5589
  #ifndef NDEBUG
4863
- for (int i = 0; i < nc; ++i) {
4864
- //printf("p[%d] = %f\n", i, p[i]);
4865
- assert(!isnan(wp[i]));
4866
- }
5590
+ for (int i = 0; i < ne00; ++i) {
5591
+ //printf("p[%d] = %f\n", i, p[i]);
5592
+ assert(!isnan(wp[i]));
5593
+ }
4867
5594
  #endif
4868
5595
 
4869
- float max = -INFINITY;
4870
- ggml_vec_max_f32(nc, &max, wp);
5596
+ float max = -INFINITY;
5597
+ ggml_vec_max_f32(ne00, &max, wp);
4871
5598
 
4872
- ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
4873
- assert(sum > 0.0);
5599
+ ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
5600
+ assert(sum > 0.0);
4874
5601
 
4875
- sum = 1.0/sum;
4876
- ggml_vec_scale_f32(nc, dp, sum);
5602
+ sum = 1.0/sum;
5603
+ ggml_vec_scale_f32(ne00, dp, sum);
4877
5604
 
4878
5605
  #ifndef NDEBUG
4879
- for (int i = 0; i < nc; ++i) {
4880
- assert(!isnan(dp[i]));
4881
- assert(!isinf(dp[i]));
4882
- }
5606
+ for (int i = 0; i < ne00; ++i) {
5607
+ assert(!isnan(dp[i]));
5608
+ assert(!isinf(dp[i]));
5609
+ }
4883
5610
  #endif
5611
+ }
5612
+ }
4884
5613
  }
4885
5614
  }
4886
5615
 
@@ -6116,6 +6845,186 @@ void ggml_compute_forward_im2col_back_f32(
6116
6845
  }
6117
6846
  }
6118
6847
 
6848
+ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
6849
+ void * a, void * b, float * c) {
6850
+ const ggml_type_traits * traits = ggml_get_type_traits(type);
6851
+ struct ggml_tensor src1 = {};
6852
+ src1.type = type;
6853
+ src1.ne[0] = k;
6854
+ src1.ne[1] = m;
6855
+ src1.ne[2] = 1;
6856
+ src1.ne[3] = 1;
6857
+ src1.nb[0] = traits->type_size;
6858
+ src1.nb[1] = k * traits->type_size;
6859
+ src1.nb[2] = src1.nb[1];
6860
+ src1.nb[3] = src1.nb[2];
6861
+ src1.data = a;
6862
+
6863
+ struct ggml_tensor src0 = {};
6864
+ src0.type = type;
6865
+ src0.ne[0] = k;
6866
+ src0.ne[1] = n;
6867
+ src0.ne[2] = 1;
6868
+ src0.ne[3] = 1;
6869
+ src0.nb[0] = traits->type_size;
6870
+ src0.nb[1] = k * traits->type_size;
6871
+ src0.nb[2] = src0.nb[1];
6872
+ src0.nb[3] = src0.nb[2];
6873
+ src0.data = b;
6874
+
6875
+ struct ggml_tensor dst = {};
6876
+ dst.ne[0] = n;
6877
+ dst.ne[1] = m;
6878
+ dst.ne[2] = 1;
6879
+ dst.ne[3] = 1;
6880
+ dst.nb[0] = sizeof(float);
6881
+ dst.nb[1] = n * sizeof(float);
6882
+ dst.nb[2] = dst.nb[1];
6883
+ dst.nb[3] = dst.nb[2];
6884
+ dst.data = c;
6885
+ dst.src[0] = &src0;
6886
+ dst.src[1] = &src1;
6887
+
6888
+ ggml_compute_forward_mul_mat(params, &dst);
6889
+ }
6890
+
6891
+ // ggml_compute_forward_conv_2d
6892
+
6893
+ static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
6894
+ const ggml_tensor * kernel, // [KW, KH, IC, OC]
6895
+ const ggml_tensor * src, // [W, H, C, N]
6896
+ ggml_tensor * dst, // [OW, OH, OC, N]
6897
+ ggml_type kernel_type) {
6898
+
6899
+ GGML_ASSERT(ggml_is_contiguous(kernel));
6900
+ GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
6901
+ GGML_ASSERT(kernel->type == kernel_type);
6902
+
6903
+ const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
6904
+
6905
+ const int32_t stride_x = dst->op_params[0];
6906
+ const int32_t stride_y = dst->op_params[1];
6907
+ const int32_t pad_x = dst->op_params[2];
6908
+ const int32_t pad_y = dst->op_params[3];
6909
+ const int32_t dilation_x = dst->op_params[4];
6910
+ const int32_t dilation_y = dst->op_params[5];
6911
+
6912
+ const int64_t c_in = src->ne[2];
6913
+ const int64_t c_out = kernel->ne[3];
6914
+ GGML_ASSERT(c_in == kernel->ne[2]);
6915
+
6916
+ const int64_t src_w = src->ne[0];
6917
+ const int64_t src_h = src->ne[1];
6918
+ const int64_t knl_w = kernel->ne[0];
6919
+ const int64_t knl_h = kernel->ne[1];
6920
+ const int64_t dst_w = dst->ne[0];
6921
+ const int64_t dst_h = dst->ne[1];
6922
+
6923
+ const float * src_data = (float *) src->data;
6924
+ void * knl_data = kernel->data;
6925
+ float * dst_data = (float *) dst->data;
6926
+
6927
+ const int64_t knl_n = knl_w * knl_h * c_in;
6928
+ const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
6929
+
6930
+ const int64_t space_per_patch = knl_n * traits->type_size + c_out * sizeof(float);
6931
+ const int64_t batch_size = params->wsize / space_per_patch;
6932
+ const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
6933
+ const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
6934
+
6935
+ GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
6936
+
6937
+ void * tmp = params->wdata;
6938
+
6939
+ for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
6940
+
6941
+ const int64_t patch_start_batch = batch_i * patches_per_batch;
6942
+ const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch,
6943
+ patch_total);
6944
+ const int64_t patch_n = patch_end_batch - patch_start_batch;
6945
+
6946
+ const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
6947
+ const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
6948
+ const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
6949
+
6950
+ //im2col for a patch
6951
+ for (int64_t p = patch_start; p < patch_end; ++p) {
6952
+ const int64_t batch_n = p / (dst_w * dst_h);
6953
+ const int64_t src_x = (p / dst_w) % dst_h;
6954
+ const int64_t src_y = p % dst_w;
6955
+
6956
+ const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);
6957
+ char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
6958
+
6959
+ for (int64_t ic = 0; ic < c_in; ++ic) {
6960
+ for (int64_t ky = 0; ky < knl_h; ++ky) {
6961
+ for (int64_t kx = 0; kx < knl_w; ++kx) {
6962
+ const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
6963
+ const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
6964
+
6965
+ int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
6966
+
6967
+ float src_val;
6968
+ if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
6969
+ src_val = 0.0f;
6970
+ } else {
6971
+ const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
6972
+ src_val = *src_ptr;
6973
+ }
6974
+
6975
+ char * element_ptr = dst_row + dst_idx * traits->type_size;
6976
+ if (kernel_type == GGML_TYPE_F32) {
6977
+ *(float *) element_ptr = src_val;
6978
+ } else if (kernel_type == GGML_TYPE_F16) {
6979
+ *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
6980
+ }
6981
+ }
6982
+ }
6983
+ }
6984
+ } // patches handled by this thread
6985
+
6986
+ ggml_barrier(params->threadpool);
6987
+
6988
+ float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
6989
+
6990
+ GGML_ASSERT(gemm_output + patch_n * c_out <= (float*)tmp + params->wsize);
6991
+
6992
+ // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
6993
+ ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
6994
+
6995
+ ggml_barrier(params->threadpool);
6996
+
6997
+
6998
+ //permute back [OC, N, OH, OW] to [N, OC, OH, OW]
6999
+ const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
7000
+ const int64_t permute_start = params->ith * permute_per_thread;
7001
+ const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);
7002
+
7003
+ for (int64_t i = permute_start; i < permute_end; ++i) {
7004
+ const int64_t p = patch_start_batch + i;
7005
+ const int64_t batch_n = p / (dst_w * dst_h);
7006
+ const int64_t dst_y = (p / dst_w) % dst_h;
7007
+ const int64_t dst_x = p % dst_w;
7008
+
7009
+ for (int64_t oc = 0; oc < c_out; ++oc) {
7010
+ const float value = gemm_output[i * c_out + oc];
7011
+ float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
7012
+ *dst_ptr = value;
7013
+ }
7014
+ }
7015
+ }
7016
+ }
7017
+
7018
+ void ggml_compute_forward_conv_2d(
7019
+ const ggml_compute_params * params,
7020
+ ggml_tensor * dst) {
7021
+
7022
+ const ggml_tensor * src0 = dst->src[0];
7023
+ const ggml_tensor * src1 = dst->src[1];
7024
+
7025
+ ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
7026
+ }
7027
+
6119
7028
  // ggml_compute_forward_conv_transpose_2d
6120
7029
 
6121
7030
  void ggml_compute_forward_conv_transpose_2d(
@@ -6666,12 +7575,13 @@ static void ggml_compute_forward_upscale_f32(
6666
7575
 
6667
7576
  GGML_TENSOR_UNARY_OP_LOCALS
6668
7577
 
6669
- const float sf0 = (float)ne0/src0->ne[0];
6670
- const float sf1 = (float)ne1/src0->ne[1];
6671
- const float sf2 = (float)ne2/src0->ne[2];
6672
- const float sf3 = (float)ne3/src0->ne[3];
7578
+ float sf0 = (float)ne0/src0->ne[0];
7579
+ float sf1 = (float)ne1/src0->ne[1];
7580
+ float sf2 = (float)ne2/src0->ne[2];
7581
+ float sf3 = (float)ne3/src0->ne[3];
6673
7582
 
6674
- const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);
7583
+ const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
7584
+ const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
6675
7585
 
6676
7586
  if (mode == GGML_SCALE_MODE_NEAREST) {
6677
7587
  for (int64_t i3 = 0; i3 < ne3; i3++) {
@@ -6692,8 +7602,12 @@ static void ggml_compute_forward_upscale_f32(
6692
7602
  }
6693
7603
  }
6694
7604
  } else if (mode == GGML_SCALE_MODE_BILINEAR) {
6695
- // setting a pixel offset of 0 would replicate the behavior of pytorch interpolate with align_corners=True
6696
- const float pixel_offset = 0.5f;
7605
+ float pixel_offset = 0.5f;
7606
+ if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7607
+ pixel_offset = 0.0f;
7608
+ sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
7609
+ sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
7610
+ }
6697
7611
 
6698
7612
  for (int64_t i3 = 0; i3 < ne3; i3++) {
6699
7613
  const int64_t i03 = i3 / sf3;
@@ -7151,7 +8065,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
7151
8065
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
7152
8066
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
7153
8067
 
7154
- ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
8068
+ ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
7155
8069
  ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
7156
8070
  ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
7157
8071
  ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
@@ -7183,7 +8097,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
7183
8097
  memset(VKQ32, 0, DV*sizeof(float));
7184
8098
  }
7185
8099
 
7186
- const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
8100
+ const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
7187
8101
 
7188
8102
  // k indices
7189
8103
  const int ik3 = iq3 / rk3;
@@ -7721,120 +8635,210 @@ void ggml_compute_forward_ssm_conv(
7721
8635
  static void ggml_compute_forward_ssm_scan_f32(
7722
8636
  const ggml_compute_params * params,
7723
8637
  ggml_tensor * dst) {
7724
- const ggml_tensor * src0 = dst->src[0]; // s
7725
- const ggml_tensor * src1 = dst->src[1]; // x
7726
- const ggml_tensor * src2 = dst->src[2]; // dt
7727
- const ggml_tensor * src3 = dst->src[3]; // A
7728
- const ggml_tensor * src4 = dst->src[4]; // B
7729
- const ggml_tensor * src5 = dst->src[5]; // C
8638
+ const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
8639
+ const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
8640
+ const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
8641
+ const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
8642
+ const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
8643
+ const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
8644
+ const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
7730
8645
 
7731
8646
  const int ith = params->ith;
7732
8647
  const int nth = params->nth;
7733
8648
 
7734
- const int64_t nc = src0->ne[0]; // d_state
7735
- const int64_t nr = src0->ne[1]; // d_inner
7736
- const int64_t n_t = src1->ne[1]; // number of tokens per sequence
7737
- const int64_t n_s = src0->ne[2]; // number of sequences in the batch
8649
+ const int64_t nc = src0->ne[0]; // d_state
8650
+ const int64_t nr = src0->ne[1]; // dim
8651
+ const int64_t nh = src1->ne[1]; // n_head
8652
+ const int64_t ng = src4->ne[1];
8653
+ const int64_t nt = src1->ne[2]; // number of tokens per sequence
8654
+ const int64_t ns = src1->ne[3]; // number of sequences in the batch
8655
+
8656
+ // can't use ggml_nbytes because src1 is not necessarily contiguous
8657
+ const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
7738
8658
 
7739
- GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
8659
+ GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
7740
8660
  GGML_ASSERT(src0->nb[0] == sizeof(float));
7741
8661
  GGML_ASSERT(src1->nb[0] == sizeof(float));
7742
8662
  GGML_ASSERT(src2->nb[0] == sizeof(float));
7743
8663
  GGML_ASSERT(src3->nb[0] == sizeof(float));
7744
8664
  GGML_ASSERT(src4->nb[0] == sizeof(float));
7745
8665
  GGML_ASSERT(src5->nb[0] == sizeof(float));
7746
- // required for the dot product between s and C
7747
- GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
7748
- // required for per-sequence offsets for states
7749
- GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
7750
- // required to get correct offset for state destination (i.e. src1->nb[3])
7751
- GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
8666
+ GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
8667
+ // allows optimizing the modulo since n_group should be a power of 2
8668
+ GGML_ASSERT((ng & -ng) == ng);
7752
8669
 
7753
- // rows per thread
7754
- const int dr = (nr + nth - 1)/nth;
8670
+ // heads per thread
8671
+ const int dh = (nh + nth - 1)/nth;
7755
8672
 
7756
- // row range for this thread
7757
- const int ir0 = dr*ith;
7758
- const int ir1 = MIN(ir0 + dr, nr);
7759
- const int ir = ir1 - ir0;
8673
+ // head range for this thread
8674
+ const int ih0 = dh*ith;
8675
+ const int ih1 = MIN(ih0 + dh, nh);
8676
+
8677
+ const int32_t * ids = (const int32_t *) src6->data;
7760
8678
 
7761
- #ifdef __ARM_FEATURE_SVE
7762
- for (int i3 = 0; i3 < n_s; ++i3) {
7763
- for (int i2 = 0; i2 < n_t; ++i2) {
7764
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
7765
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7766
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
7767
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
7768
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
7769
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
7770
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7771
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
7772
-
7773
- // use the output as the source for the next token-wise iterations
7774
- if (i2 > 0) { s0 = s; }
7775
-
7776
- // d_inner
7777
- for (int i1 = 0; i1 < ir; ++i1) {
7778
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
7779
- float x_dt = x[i1] * dt_soft_plus;
7780
- svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
7781
- svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
7782
- svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
7783
-
7784
- for (int64_t k = 0; k < nc; k += svcntw()) {
7785
- svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
7786
- svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
7787
- svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
7788
- svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
7789
-
7790
- svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
7791
- t1 = exp_ps_sve(svptrue_b32(), t1);
7792
- svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
7793
-
7794
- vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
7795
- r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
7796
-
7797
- GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
8679
+ for (int i3 = 0; i3 < ns; ++i3) {
8680
+ const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
8681
+ float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
8682
+
8683
+ for (int i2 = 0; i2 < nt; ++i2) {
8684
+ const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
8685
+ const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
8686
+ const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
8687
+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
8688
+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
8689
+ float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
8690
+
8691
+ if (src3->ne[0] == 1) {
8692
+ // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
8693
+
8694
+ // n_head
8695
+ for (int h = ih0; h < ih1; ++h) {
8696
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8697
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
8698
+ const float dA = expf(dt_soft_plus * A[h]);
8699
+
8700
+ // dim
8701
+ for (int i1 = 0; i1 < nr; ++i1) {
8702
+ const int ii = i1 + h*nr;
8703
+ const float x_dt = x[ii] * dt_soft_plus;
8704
+ float sumf = 0.0f;
8705
+ #if defined(GGML_SIMD)
8706
+ #if defined(__ARM_FEATURE_SVE)
8707
+ const int ggml_f32_epr = svcntw();
8708
+ const int ggml_f32_step = 1 * ggml_f32_epr;
8709
+
8710
+ const int np = (nc & ~(ggml_f32_step - 1));
8711
+
8712
+ GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
8713
+
8714
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
8715
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
8716
+
8717
+ for (int i = 0; i < np; i += ggml_f32_step) {
8718
+ // TODO: maybe unroll more?
8719
+ for (int j = 0; j < 1; j++) {
8720
+ GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
8721
+ GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
8722
+ GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
8723
+
8724
+ t0 = GGML_F32_VEC_MUL(t0, adA);
8725
+ t1 = GGML_F32_VEC_MUL(t1, axdt);
8726
+
8727
+ t0 = GGML_F32_VEC_ADD(t0, t1);
8728
+
8729
+ sum = GGML_F32_VEC_FMA(sum, t0, t2);
8730
+
8731
+ GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
8732
+ }
8733
+ }
8734
+
8735
+ sumf = GGML_F32xt_REDUCE_ONE(sum);
8736
+ #else
8737
+ const int np = (nc & ~(GGML_F32_STEP - 1));
8738
+
8739
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8740
+
8741
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
8742
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
8743
+
8744
+ GGML_F32_VEC ax[GGML_F32_ARR];
8745
+ GGML_F32_VEC ay[GGML_F32_ARR];
8746
+ GGML_F32_VEC az[GGML_F32_ARR];
8747
+
8748
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
8749
+ for (int j = 0; j < GGML_F32_ARR; j++) {
8750
+ ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
8751
+ ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
8752
+ az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
8753
+
8754
+ ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
8755
+ ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
8756
+
8757
+ ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
8758
+
8759
+ sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
8760
+
8761
+ GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
8762
+ }
8763
+ }
8764
+
8765
+ // reduce sum0..sum3 to sum0
8766
+ GGML_F32_VEC_REDUCE(sumf, sum);
8767
+ #endif
8768
+ #else
8769
+ const int np = 0;
8770
+ #endif
8771
+ // d_state
8772
+ for (int i0 = np; i0 < nc; ++i0) {
8773
+ const int i = i0 + ii*nc;
8774
+ const int ig = i0 + (h & (ng - 1))*nc;
8775
+ // state = prev_state * dA + dB * x
8776
+ const float state = (s0[i] * dA) + (B[ig] * x_dt);
8777
+ // y = rowwise_dotprod(state, C)
8778
+ sumf += state * C[ig];
8779
+ s[i] = state;
8780
+ }
8781
+ y[ii] = sumf;
7798
8782
  }
7799
- y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
7800
8783
  }
7801
- }
7802
- }
7803
- #else
7804
- for (int i3 = 0; i3 < n_s; ++i3) {
7805
- for (int i2 = 0; i2 < n_t; ++i2) {
7806
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
7807
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7808
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
7809
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
7810
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
7811
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
7812
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7813
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
7814
-
7815
- // use the output as the source for the next token-wise iterations
7816
- if (i2 > 0) { s0 = s; }
7817
-
7818
- // d_inner
7819
- for (int i1 = 0; i1 < ir; ++i1) {
7820
- // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
7821
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
7822
- float x_dt = x[i1] * dt_soft_plus;
7823
- float sumf = 0.0f;
7824
- // d_state
7825
- for (int i0 = 0; i0 < nc; ++i0) {
7826
- int i = i0 + i1*nc;
7827
- // state = prev_state * dA + dB * x
7828
- float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
7829
- // y = rowwise_dotprod(state, C)
7830
- sumf += state * C[i0];
7831
- s[i] = state;
8784
+ } else {
8785
+ // Mamba-1 has an element-wise decay factor for the states
8786
+
8787
+ // n_head
8788
+ for (int h = ih0; h < ih1; ++h) {
8789
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8790
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
8791
+
8792
+ // dim
8793
+ for (int i1 = 0; i1 < nr; ++i1) {
8794
+ const int ii = i1 + h*nr;
8795
+ const float x_dt = x[ii] * dt_soft_plus;
8796
+ #if defined(__ARM_FEATURE_SVE)
8797
+ svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
8798
+ svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
8799
+ svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
8800
+
8801
+ // d_state
8802
+ // TODO: what happens when (d_state % svcntw()) != 0?
8803
+ for (int64_t k = 0; k < nc; k += svcntw()) {
8804
+ svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
8805
+ svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
8806
+ svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
8807
+ svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
8808
+
8809
+ svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
8810
+ t1 = exp_ps_sve(svptrue_b32(), t1);
8811
+ svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
8812
+
8813
+ vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
8814
+ r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
8815
+
8816
+ GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
8817
+ }
8818
+ y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
8819
+ #else
8820
+ float sumf = 0.0f;
8821
+ // NOTE: can't really use GGML_SIMD here because d_state is usually 16
8822
+ // and also because expf is used within the loop.
8823
+ // d_state
8824
+ for (int i0 = 0; i0 < nc; ++i0) {
8825
+ const int i = i0 + ii*nc;
8826
+ const int ig = i0 + (h & (ng - 1))*nc;
8827
+ // state = prev_state * dA + dB * x
8828
+ const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
8829
+ // y = rowwise_dotprod(state, C)
8830
+ sumf += state * C[ig];
8831
+ s[i] = state;
8832
+ }
8833
+ y[ii] = sumf;
8834
+ #endif
7832
8835
  }
7833
- y[i1] = sumf;
7834
8836
  }
7835
8837
  }
8838
+ // use the output as the source when it's not the first token-wise iteration
8839
+ s0 = s;
7836
8840
  }
7837
- #endif
8841
+ }
7838
8842
  }
7839
8843
 
7840
8844
  void ggml_compute_forward_ssm_scan(
@@ -8052,6 +9056,42 @@ void ggml_compute_forward_unary(
8052
9056
  }
8053
9057
  }
8054
9058
 
9059
+ //ggml_compute_forward_glu
9060
+
9061
+ void ggml_compute_forward_glu(
9062
+ const ggml_compute_params * params,
9063
+ ggml_tensor * dst) {
9064
+
9065
+ const ggml_glu_op op = ggml_get_glu_op(dst);
9066
+
9067
+ switch (op) {
9068
+ case GGML_GLU_OP_REGLU:
9069
+ {
9070
+ ggml_compute_forward_reglu(params, dst);
9071
+ } break;
9072
+ case GGML_GLU_OP_GEGLU:
9073
+ {
9074
+ ggml_compute_forward_geglu(params, dst);
9075
+ } break;
9076
+ case GGML_GLU_OP_SWIGLU:
9077
+ {
9078
+ ggml_compute_forward_swiglu(params, dst);
9079
+ } break;
9080
+ case GGML_GLU_OP_GEGLU_ERF:
9081
+ {
9082
+ ggml_compute_forward_geglu_erf(params, dst);
9083
+ } break;
9084
+ case GGML_GLU_OP_GEGLU_QUICK:
9085
+ {
9086
+ ggml_compute_forward_geglu_quick(params, dst);
9087
+ } break;
9088
+ default:
9089
+ {
9090
+ GGML_ABORT("fatal error");
9091
+ }
9092
+ }
9093
+ }
9094
+
8055
9095
  // ggml_compute_forward_get_rel_pos
8056
9096
 
8057
9097
  static void ggml_compute_forward_get_rel_pos_f16(