@novastera-oss/llamarn 0.2.7 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (319)
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  13. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  18. package/cpp/LlamaCppModel.cpp +56 -22
  19. package/cpp/build-info.cpp +2 -2
  20. package/cpp/llama.cpp/CMakeLists.txt +1 -2
  21. package/cpp/llama.cpp/README.md +4 -5
  22. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  23. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  24. package/cpp/llama.cpp/common/arg.cpp +24 -0
  25. package/cpp/llama.cpp/common/chat.cpp +37 -20
  26. package/cpp/llama.cpp/common/chat.h +2 -0
  27. package/cpp/llama.cpp/common/common.cpp +3 -0
  28. package/cpp/llama.cpp/common/common.h +5 -0
  29. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  30. package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
  31. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  32. package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
  33. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  34. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  35. package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
  36. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  100. package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
  101. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
  103. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
  104. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
  105. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  133. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  134. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  135. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  136. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  137. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
  138. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  141. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  142. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  143. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  144. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  168. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  169. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  170. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  172. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  173. package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
  177. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
  178. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
  179. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  180. package/cpp/llama.cpp/include/llama.h +8 -43
  181. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  182. package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
  183. package/cpp/llama.cpp/src/llama-arch.h +36 -1
  184. package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
  185. package/cpp/llama.cpp/src/llama-batch.h +105 -70
  186. package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
  187. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  188. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  189. package/cpp/llama.cpp/src/llama-context.h +13 -13
  190. package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
  191. package/cpp/llama.cpp/src/llama-graph.h +78 -79
  192. package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
  193. package/cpp/llama.cpp/src/llama-hparams.h +11 -0
  194. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
  195. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  196. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
  197. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
  198. package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
  199. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
  200. package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
  201. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
  202. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  203. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  204. package/cpp/llama.cpp/src/llama-memory.h +21 -22
  205. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  206. package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
  207. package/cpp/llama.cpp/src/llama-model.h +40 -0
  208. package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
  209. package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
  210. package/cpp/llama.cpp/src/llama-vocab.h +42 -0
  211. package/cpp/rn-utils.h +3 -0
  212. package/ios/include/chat.h +2 -0
  213. package/ios/include/common.h +5 -0
  214. package/ios/include/llama.h +8 -43
  215. package/ios/libs/llama.xcframework/Info.plist +19 -19
  216. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  217. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  218. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  219. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  220. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
  221. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  223. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  224. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  225. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  231. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  232. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  248. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  250. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
  254. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
  255. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  261. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  262. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  263. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  264. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  265. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  267. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  268. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
  269. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
  270. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
  274. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  275. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  276. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  277. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  278. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  279. package/package.json +1 -1
  280. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  315. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  316. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  317. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  318. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  319. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -3,6 +3,7 @@
3
3
  #include "ggml-cpu.h"
4
4
  #include "ggml-impl.h"
5
5
  #include "binary-ops.h"
6
+ #include "ggml.h"
6
7
  #include "unary-ops.h"
7
8
  #include "vec.h"
8
9
 
@@ -108,7 +109,7 @@ static void ggml_compute_forward_dup_f16(
108
109
  for (int i01 = ir0; i01 < ir1; i01++) {
109
110
  const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
110
111
  for (int i00 = 0; i00 < ne00; i00++) {
111
- dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
112
+ dst_ptr[id] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
112
113
  id++;
113
114
  }
114
115
  }
@@ -130,7 +131,7 @@ static void ggml_compute_forward_dup_f16(
130
131
  const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
131
132
 
132
133
  for (int i00 = 0; i00 < ne00; i00++) {
133
- src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
134
+ src0_f32[i00] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
134
135
  }
135
136
 
136
137
  quantize_row_q(src0_f32, dst_ptr + id, ne00);
@@ -156,7 +157,7 @@ static void ggml_compute_forward_dup_f16(
156
157
  for (int i00 = 0; i00 < ne00; i00++) {
157
158
  const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
158
159
 
159
- dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
160
+ dst_ptr[id] = GGML_CPU_FP16_TO_FP32(*src0_ptr);
160
161
  id++;
161
162
  }
162
163
  }
@@ -267,7 +268,7 @@ static void ggml_compute_forward_dup_f16(
267
268
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
268
269
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
269
270
 
270
- *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
271
+ *(float *) dst_ptr = GGML_CPU_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
271
272
 
272
273
  if (++i10 == ne0) {
273
274
  i10 = 0;
@@ -372,7 +373,7 @@ static void ggml_compute_forward_dup_bf16(
372
373
  for (int i01 = ir0; i01 < ir1; i01++) {
373
374
  const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
374
375
  for (int i00 = 0; i00 < ne00; i00++) {
375
- dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
376
+ dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
376
377
  id++;
377
378
  }
378
379
  }
@@ -473,7 +474,7 @@ static void ggml_compute_forward_dup_bf16(
473
474
  for (int i00 = 0; i00 < ne00; i00++) {
474
475
  const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
475
476
 
476
- dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
477
+ dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
477
478
  id++;
478
479
  }
479
480
  }
@@ -566,7 +567,7 @@ static void ggml_compute_forward_dup_bf16(
566
567
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
567
568
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
568
569
 
569
- *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
570
+ *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
570
571
 
571
572
  if (++i10 == ne0) {
572
573
  i10 = 0;
@@ -696,24 +697,8 @@ static void ggml_compute_forward_dup_f32(
696
697
  if (ggml_is_contiguous(dst)) {
697
698
  // TODO: simplify
698
699
  if (nb00 == sizeof(float)) {
699
- if (dst->type == GGML_TYPE_F32) {
700
- size_t id = 0;
701
- const size_t rs = ne00 * nb00;
702
- char * dst_ptr = (char *) dst->data;
703
-
704
- for (int i03 = 0; i03 < ne03; i03++) {
705
- for (int i02 = 0; i02 < ne02; i02++) {
706
- id += rs * ir0;
707
- for (int i01 = ir0; i01 < ir1; i01++) {
708
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
709
- memcpy(dst_ptr + id, src0_ptr, rs);
710
- id += rs;
711
- }
712
- id += rs * (ne01 - ir1);
713
- }
714
- }
715
- } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
716
- ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
700
+ if (ggml_get_type_traits_cpu(dst->type)->from_float) {
701
+ ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
717
702
 
718
703
  size_t id = 0;
719
704
  size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
@@ -724,7 +709,7 @@ static void ggml_compute_forward_dup_f32(
724
709
  id += rs * ir0;
725
710
  for (int i01 = ir0; i01 < ir1; i01++) {
726
711
  const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
727
- quantize_row_q(src0_ptr, dst_ptr + id, ne00);
712
+ from_float(src0_ptr, dst_ptr + id, ne00);
728
713
  id += rs;
729
714
  }
730
715
  id += rs * (ne01 - ir1);
@@ -765,7 +750,7 @@ static void ggml_compute_forward_dup_f32(
765
750
  for (int i00 = 0; i00 < ne00; i00++) {
766
751
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
767
752
 
768
- dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
753
+ dst_ptr[id] = GGML_CPU_FP32_TO_FP16(*src0_ptr);
769
754
  id++;
770
755
  }
771
756
  }
@@ -878,7 +863,7 @@ static void ggml_compute_forward_dup_f32(
878
863
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
879
864
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
880
865
 
881
- *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
866
+ *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(*(const float *) src0_ptr);
882
867
 
883
868
  if (++i10 == ne0) {
884
869
  i10 = 0;
@@ -1419,7 +1404,7 @@ static void ggml_compute_forward_add1_f16_f32(
1419
1404
  ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
1420
1405
  ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
1421
1406
  for (int i = 0; i < ne0; i++) {
1422
- dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
1407
+ dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
1423
1408
  }
1424
1409
  }
1425
1410
  }
@@ -1435,7 +1420,7 @@ static void ggml_compute_forward_add1_f16_f16(
1435
1420
  GGML_ASSERT(ggml_is_scalar(src1));
1436
1421
 
1437
1422
  // scalar to add
1438
- const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
1423
+ const float v = GGML_CPU_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
1439
1424
 
1440
1425
  const int ith = params->ith;
1441
1426
  const int nth = params->nth;
@@ -1467,7 +1452,7 @@ static void ggml_compute_forward_add1_f16_f16(
1467
1452
  ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
1468
1453
  ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
1469
1454
  for (int i = 0; i < ne0; i++) {
1470
- dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
1455
+ dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
1471
1456
  }
1472
1457
  }
1473
1458
  }
@@ -1889,7 +1874,7 @@ static void ggml_compute_forward_sum_f16(
1889
1874
  }
1890
1875
  }
1891
1876
  }
1892
- ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
1877
+ ((ggml_fp16_t *) dst->data)[0] = GGML_CPU_FP32_TO_FP16(sum);
1893
1878
  }
1894
1879
 
1895
1880
  static void ggml_compute_forward_sum_bf16(
@@ -2300,6 +2285,12 @@ void ggml_compute_forward_repeat(
2300
2285
  {
2301
2286
  ggml_compute_forward_repeat_f32(params, dst);
2302
2287
  } break;
2288
+ // TODO: templateify the implemenation and support for I64
2289
+ // ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
2290
+ //case GGML_TYPE_I64:
2291
+ // {
2292
+ // ggml_compute_forward_repeat_i64(params, dst);
2293
+ // } break;
2303
2294
  default:
2304
2295
  {
2305
2296
  GGML_ABORT("fatal error");
@@ -2660,7 +2651,7 @@ static void ggml_compute_forward_gelu_f16(
2660
2651
  #ifndef NDEBUG
2661
2652
  for (int k = 0; k < nc; k++) {
2662
2653
  const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2663
- const float v = GGML_FP16_TO_FP32(x);
2654
+ const float v = GGML_CPU_FP16_TO_FP32(x);
2664
2655
  GGML_UNUSED(v);
2665
2656
  assert(!isnan(v));
2666
2657
  assert(!isinf(v));
@@ -2763,7 +2754,7 @@ static void ggml_compute_forward_gelu_erf_f16(
2763
2754
  #ifndef NDEBUG
2764
2755
  for (int k = 0; k < nc; k++) {
2765
2756
  const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2766
- const float v = GGML_FP16_TO_FP32(x);
2757
+ const float v = GGML_CPU_FP16_TO_FP32(x);
2767
2758
  GGML_UNUSED(v);
2768
2759
  assert(!isnan(v));
2769
2760
  assert(!isinf(v));
@@ -2866,7 +2857,7 @@ static void ggml_compute_forward_gelu_quick_f16(
2866
2857
  #ifndef NDEBUG
2867
2858
  for (int k = 0; k < nc; k++) {
2868
2859
  const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2869
- const float v = GGML_FP16_TO_FP32(x);
2860
+ const float v = GGML_CPU_FP16_TO_FP32(x);
2870
2861
  GGML_UNUSED(v);
2871
2862
  assert(!isnan(v));
2872
2863
  assert(!isinf(v));
@@ -2969,7 +2960,7 @@ static void ggml_compute_forward_silu_f16(
2969
2960
  #ifndef NDEBUG
2970
2961
  for (int k = 0; k < nc; k++) {
2971
2962
  const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
2972
- const float v = GGML_FP16_TO_FP32(x);
2963
+ const float v = GGML_CPU_FP16_TO_FP32(x);
2973
2964
  GGML_UNUSED(v);
2974
2965
  assert(!isnan(v));
2975
2966
  assert(!isinf(v));
@@ -3144,8 +3135,718 @@ static void ggml_compute_forward_silu_back_f16(
3144
3135
  const int ith = params->ith;
3145
3136
  const int nth = params->nth;
3146
3137
 
3147
- const int nc = src1->ne[0];
3148
- const int nr = ggml_nrows(src1);
3138
+ const int nc = src1->ne[0];
3139
+ const int nr = ggml_nrows(src1);
3140
+
3141
+ // rows per thread
3142
+ const int dr = (nr + nth - 1)/nth;
3143
+
3144
+ // row range for this thread
3145
+ const int ir0 = dr*ith;
3146
+ const int ir1 = MIN(ir0 + dr, nr);
3147
+
3148
+ for (int i1 = ir0; i1 < ir1; i1++) {
3149
+ ggml_vec_silu_backward_f16(nc,
3150
+ (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
3151
+ (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
3152
+ (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
3153
+
3154
+ #ifndef NDEBUG
3155
+ for (int k = 0; k < nc; k++) {
3156
+ const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3157
+ const float v = GGML_CPU_FP16_TO_FP32(x);
3158
+ GGML_UNUSED(v);
3159
+ assert(!isnan(v));
3160
+ assert(!isinf(v));
3161
+ }
3162
+ #endif
3163
+ }
3164
+ }
3165
+
3166
+ void ggml_compute_forward_silu_back(
3167
+ const ggml_compute_params * params,
3168
+ ggml_tensor * dst) {
3169
+
3170
+ const ggml_tensor * src0 = dst->src[0];
3171
+
3172
+ switch (src0->type) {
3173
+ case GGML_TYPE_F32:
3174
+ {
3175
+ ggml_compute_forward_silu_back_f32(params, dst);
3176
+ } break;
3177
+ case GGML_TYPE_F16:
3178
+ {
3179
+ ggml_compute_forward_silu_back_f16(params, dst);
3180
+ } break;
3181
+ default:
3182
+ {
3183
+ GGML_ABORT("fatal error");
3184
+ }
3185
+ }
3186
+ }
3187
+
3188
+ // ggml_compute_forward_reglu
3189
+
3190
+ static void ggml_compute_forward_reglu_f32(
3191
+ const ggml_compute_params * params,
3192
+ ggml_tensor * dst) {
3193
+
3194
+ const ggml_tensor * src0 = dst->src[0];
3195
+ const ggml_tensor * src1 = dst->src[1];
3196
+ char * src0_d = (char *) src0->data;
3197
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3198
+ const size_t src0_o = src0->nb[1];
3199
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3200
+
3201
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3202
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3203
+
3204
+ if (src1) {
3205
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3206
+ GGML_ASSERT(src0->type == src1->type);
3207
+ }
3208
+
3209
+ const int ith = params->ith;
3210
+ const int nth = params->nth;
3211
+
3212
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3213
+ const int nr = ggml_nrows(src0);
3214
+
3215
+ GGML_ASSERT(dst->ne[0] == nc);
3216
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3217
+
3218
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3219
+
3220
+ // rows per thread
3221
+ const int dr = (nr + nth - 1)/nth;
3222
+
3223
+ // row range for this thread
3224
+ const int ir0 = dr*ith;
3225
+ const int ir1 = MIN(ir0 + dr, nr);
3226
+
3227
+ for (int i1 = ir0; i1 < ir1; i1++) {
3228
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3229
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3230
+
3231
+ if (!src1) {
3232
+ src0_p += swapped ? nc : 0;
3233
+ src1_p += swapped ? 0 : nc;
3234
+ }
3235
+
3236
+ ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3237
+
3238
+ #ifndef NDEBUG
3239
+ for (int k = 0; k < nc; k++) {
3240
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3241
+ GGML_UNUSED(x);
3242
+ assert(!isnan(x));
3243
+ assert(!isinf(x));
3244
+ }
3245
+ #endif
3246
+ }
3247
+ }
3248
+
3249
+ static void ggml_compute_forward_reglu_f16(
3250
+ const ggml_compute_params * params,
3251
+ ggml_tensor * dst) {
3252
+
3253
+ const ggml_tensor * src0 = dst->src[0];
3254
+ const ggml_tensor * src1 = dst->src[1];
3255
+ char * src0_d = (char *) src0->data;
3256
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3257
+ const size_t src0_o = src0->nb[1];
3258
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3259
+
3260
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3261
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3262
+
3263
+ if (src1) {
3264
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3265
+ GGML_ASSERT(src0->type == src1->type);
3266
+ }
3267
+
3268
+ const int ith = params->ith;
3269
+ const int nth = params->nth;
3270
+
3271
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3272
+ const int nr = ggml_nrows(src0);
3273
+
3274
+ GGML_ASSERT(dst->ne[0] == nc);
3275
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3276
+
3277
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3278
+
3279
+ // rows per thread
3280
+ const int dr = (nr + nth - 1)/nth;
3281
+
3282
+ // row range for this thread
3283
+ const int ir0 = dr*ith;
3284
+ const int ir1 = MIN(ir0 + dr, nr);
3285
+
3286
+ for (int i1 = ir0; i1 < ir1; i1++) {
3287
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3288
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3289
+
3290
+ if (!src1) {
3291
+ src0_p += swapped ? nc : 0;
3292
+ src1_p += swapped ? 0 : nc;
3293
+ }
3294
+
3295
+ ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3296
+
3297
+ #ifndef NDEBUG
3298
+ for (int k = 0; k < nc; k++) {
3299
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3300
+ const float v = GGML_FP16_TO_FP32(x);
3301
+ GGML_UNUSED(v);
3302
+ assert(!isnan(v));
3303
+ assert(!isinf(v));
3304
+ }
3305
+ #endif
3306
+ }
3307
+ }
3308
+
3309
+ static void ggml_compute_forward_reglu(
3310
+ const ggml_compute_params * params,
3311
+ ggml_tensor * dst) {
3312
+
3313
+ const ggml_tensor * src0 = dst->src[0];
3314
+
3315
+ switch (src0->type) {
3316
+ case GGML_TYPE_F32:
3317
+ {
3318
+ ggml_compute_forward_reglu_f32(params, dst);
3319
+ } break;
3320
+ case GGML_TYPE_F16:
3321
+ {
3322
+ ggml_compute_forward_reglu_f16(params, dst);
3323
+ } break;
3324
+ default:
3325
+ {
3326
+ GGML_ABORT("fatal error");
3327
+ }
3328
+ }
3329
+ }
3330
+
3331
+ // ggml_compute_forward_geglu
3332
+
3333
+ static void ggml_compute_forward_geglu_f32(
3334
+ const ggml_compute_params * params,
3335
+ ggml_tensor * dst) {
3336
+
3337
+ const ggml_tensor * src0 = dst->src[0];
3338
+ const ggml_tensor * src1 = dst->src[1];
3339
+ char * src0_d = (char *) src0->data;
3340
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3341
+ const size_t src0_o = src0->nb[1];
3342
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3343
+
3344
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3345
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3346
+
3347
+ if (src1) {
3348
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3349
+ GGML_ASSERT(src0->type == src1->type);
3350
+ }
3351
+
3352
+ const int ith = params->ith;
3353
+ const int nth = params->nth;
3354
+
3355
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3356
+ const int nr = ggml_nrows(src0);
3357
+
3358
+ GGML_ASSERT(dst->ne[0] == nc);
3359
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3360
+
3361
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3362
+
3363
+ // rows per thread
3364
+ const int dr = (nr + nth - 1)/nth;
3365
+
3366
+ // row range for this thread
3367
+ const int ir0 = dr*ith;
3368
+ const int ir1 = MIN(ir0 + dr, nr);
3369
+
3370
+ for (int i1 = ir0; i1 < ir1; i1++) {
3371
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3372
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3373
+
3374
+ if (!src1) {
3375
+ src0_p += swapped ? nc : 0;
3376
+ src1_p += swapped ? 0 : nc;
3377
+ }
3378
+
3379
+ ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3380
+
3381
+ #ifndef NDEBUG
3382
+ for (int k = 0; k < nc; k++) {
3383
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3384
+ GGML_UNUSED(x);
3385
+ assert(!isnan(x));
3386
+ assert(!isinf(x));
3387
+ }
3388
+ #endif
3389
+ }
3390
+ }
3391
+
3392
+ static void ggml_compute_forward_geglu_f16(
3393
+ const ggml_compute_params * params,
3394
+ ggml_tensor * dst) {
3395
+
3396
+ const ggml_tensor * src0 = dst->src[0];
3397
+ const ggml_tensor * src1 = dst->src[1];
3398
+ char * src0_d = (char *) src0->data;
3399
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3400
+ const size_t src0_o = src0->nb[1];
3401
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3402
+
3403
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3404
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3405
+
3406
+ if (src1) {
3407
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3408
+ GGML_ASSERT(src0->type == src1->type);
3409
+ }
3410
+
3411
+ const int ith = params->ith;
3412
+ const int nth = params->nth;
3413
+
3414
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3415
+ const int nr = ggml_nrows(src0);
3416
+
3417
+ GGML_ASSERT(dst->ne[0] == nc);
3418
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3419
+
3420
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3421
+
3422
+ // rows per thread
3423
+ const int dr = (nr + nth - 1)/nth;
3424
+
3425
+ // row range for this thread
3426
+ const int ir0 = dr*ith;
3427
+ const int ir1 = MIN(ir0 + dr, nr);
3428
+
3429
+ for (int i1 = ir0; i1 < ir1; i1++) {
3430
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3431
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3432
+
3433
+ if (!src1) {
3434
+ src0_p += swapped ? nc : 0;
3435
+ src1_p += swapped ? 0 : nc;
3436
+ }
3437
+
3438
+ ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3439
+
3440
+ #ifndef NDEBUG
3441
+ for (int k = 0; k < nc; k++) {
3442
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3443
+ const float v = GGML_FP16_TO_FP32(x);
3444
+ GGML_UNUSED(v);
3445
+ assert(!isnan(v));
3446
+ assert(!isinf(v));
3447
+ }
3448
+ #endif
3449
+ }
3450
+ }
3451
+
3452
+ static void ggml_compute_forward_geglu(
3453
+ const ggml_compute_params * params,
3454
+ ggml_tensor * dst) {
3455
+
3456
+ const ggml_tensor * src0 = dst->src[0];
3457
+
3458
+ switch (src0->type) {
3459
+ case GGML_TYPE_F32:
3460
+ {
3461
+ ggml_compute_forward_geglu_f32(params, dst);
3462
+ } break;
3463
+ case GGML_TYPE_F16:
3464
+ {
3465
+ ggml_compute_forward_geglu_f16(params, dst);
3466
+ } break;
3467
+ default:
3468
+ {
3469
+ GGML_ABORT("fatal error");
3470
+ }
3471
+ }
3472
+ }
3473
+
3474
+ // ggml_compute_forward_swiglu
3475
+
3476
+ static void ggml_compute_forward_swiglu_f32(
3477
+ const ggml_compute_params * params,
3478
+ ggml_tensor * dst) {
3479
+
3480
+ const ggml_tensor * src0 = dst->src[0];
3481
+ const ggml_tensor * src1 = dst->src[1];
3482
+ char * src0_d = (char *) src0->data;
3483
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3484
+ const size_t src0_o = src0->nb[1];
3485
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3486
+
3487
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3488
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3489
+
3490
+ if (src1) {
3491
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3492
+ GGML_ASSERT(src0->type == src1->type);
3493
+ }
3494
+
3495
+ const int ith = params->ith;
3496
+ const int nth = params->nth;
3497
+
3498
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3499
+ const int nr = ggml_nrows(src0);
3500
+
3501
+ GGML_ASSERT(dst->ne[0] == nc);
3502
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3503
+
3504
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3505
+
3506
+ // rows per thread
3507
+ const int dr = (nr + nth - 1)/nth;
3508
+
3509
+ // row range for this thread
3510
+ const int ir0 = dr*ith;
3511
+ const int ir1 = MIN(ir0 + dr, nr);
3512
+
3513
+ for (int i1 = ir0; i1 < ir1; i1++) {
3514
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3515
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3516
+
3517
+ if (!src1) {
3518
+ src0_p += swapped ? nc : 0;
3519
+ src1_p += swapped ? 0 : nc;
3520
+ }
3521
+
3522
+ ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3523
+
3524
+ #ifndef NDEBUG
3525
+ for (int k = 0; k < nc; k++) {
3526
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3527
+ GGML_UNUSED(x);
3528
+ assert(!isnan(x));
3529
+ assert(!isinf(x));
3530
+ }
3531
+ #endif
3532
+ }
3533
+ }
3534
+
3535
+ static void ggml_compute_forward_swiglu_f16(
3536
+ const ggml_compute_params * params,
3537
+ ggml_tensor * dst) {
3538
+
3539
+ const ggml_tensor * src0 = dst->src[0];
3540
+ const ggml_tensor * src1 = dst->src[1];
3541
+ char * src0_d = (char *) src0->data;
3542
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3543
+ const size_t src0_o = src0->nb[1];
3544
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3545
+
3546
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3547
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3548
+
3549
+ if (src1) {
3550
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3551
+ GGML_ASSERT(src0->type == src1->type);
3552
+ }
3553
+
3554
+ const int ith = params->ith;
3555
+ const int nth = params->nth;
3556
+
3557
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3558
+ const int nr = ggml_nrows(src0);
3559
+
3560
+ GGML_ASSERT(dst->ne[0] == nc);
3561
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3562
+
3563
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3564
+
3565
+ // rows per thread
3566
+ const int dr = (nr + nth - 1)/nth;
3567
+
3568
+ // row range for this thread
3569
+ const int ir0 = dr*ith;
3570
+ const int ir1 = MIN(ir0 + dr, nr);
3571
+
3572
+ for (int i1 = ir0; i1 < ir1; i1++) {
3573
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3574
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3575
+
3576
+ if (!src1) {
3577
+ src0_p += swapped ? nc : 0;
3578
+ src1_p += swapped ? 0 : nc;
3579
+ }
3580
+
3581
+ ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3582
+
3583
+ #ifndef NDEBUG
3584
+ for (int k = 0; k < nc; k++) {
3585
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3586
+ const float v = GGML_FP16_TO_FP32(x);
3587
+ GGML_UNUSED(v);
3588
+ assert(!isnan(v));
3589
+ assert(!isinf(v));
3590
+ }
3591
+ #endif
3592
+ }
3593
+ }
3594
+
3595
+ static void ggml_compute_forward_swiglu(
3596
+ const ggml_compute_params * params,
3597
+ ggml_tensor * dst) {
3598
+
3599
+ const ggml_tensor * src0 = dst->src[0];
3600
+
3601
+ switch (src0->type) {
3602
+ case GGML_TYPE_F32:
3603
+ {
3604
+ ggml_compute_forward_swiglu_f32(params, dst);
3605
+ } break;
3606
+ case GGML_TYPE_F16:
3607
+ {
3608
+ ggml_compute_forward_swiglu_f16(params, dst);
3609
+ } break;
3610
+ default:
3611
+ {
3612
+ GGML_ABORT("fatal error");
3613
+ }
3614
+ }
3615
+ }
3616
+
3617
+ // ggml_compute_forward_geglu_erf
3618
+
3619
+ static void ggml_compute_forward_geglu_erf_f32(
3620
+ const ggml_compute_params * params,
3621
+ ggml_tensor * dst) {
3622
+
3623
+ const ggml_tensor * src0 = dst->src[0];
3624
+ const ggml_tensor * src1 = dst->src[1];
3625
+ char * src0_d = (char *) src0->data;
3626
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3627
+ const size_t src0_o = src0->nb[1];
3628
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3629
+
3630
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3631
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3632
+
3633
+ if (src1) {
3634
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3635
+ GGML_ASSERT(src0->type == src1->type);
3636
+ }
3637
+
3638
+ const int ith = params->ith;
3639
+ const int nth = params->nth;
3640
+
3641
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3642
+ const int nr = ggml_nrows(src0);
3643
+
3644
+ GGML_ASSERT(dst->ne[0] == nc);
3645
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3646
+
3647
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3648
+
3649
+ // rows per thread
3650
+ const int dr = (nr + nth - 1)/nth;
3651
+
3652
+ // row range for this thread
3653
+ const int ir0 = dr*ith;
3654
+ const int ir1 = MIN(ir0 + dr, nr);
3655
+
3656
+ for (int i1 = ir0; i1 < ir1; i1++) {
3657
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3658
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3659
+
3660
+ if (!src1) {
3661
+ src0_p += swapped ? nc : 0;
3662
+ src1_p += swapped ? 0 : nc;
3663
+ }
3664
+
3665
+ ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3666
+
3667
+ #ifndef NDEBUG
3668
+ for (int k = 0; k < nc; k++) {
3669
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3670
+ GGML_UNUSED(x);
3671
+ assert(!isnan(x));
3672
+ assert(!isinf(x));
3673
+ }
3674
+ #endif
3675
+ }
3676
+ }
3677
+
3678
+ static void ggml_compute_forward_geglu_erf_f16(
3679
+ const ggml_compute_params * params,
3680
+ ggml_tensor * dst) {
3681
+
3682
+ const ggml_tensor * src0 = dst->src[0];
3683
+ const ggml_tensor * src1 = dst->src[1];
3684
+ char * src0_d = (char *) src0->data;
3685
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3686
+ const size_t src0_o = src0->nb[1];
3687
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3688
+
3689
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3690
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3691
+
3692
+ if (src1) {
3693
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3694
+ GGML_ASSERT(src0->type == src1->type);
3695
+ }
3696
+
3697
+ const int ith = params->ith;
3698
+ const int nth = params->nth;
3699
+
3700
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3701
+ const int nr = ggml_nrows(src0);
3702
+
3703
+ GGML_ASSERT(dst->ne[0] == nc);
3704
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3705
+
3706
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3707
+
3708
+ // rows per thread
3709
+ const int dr = (nr + nth - 1)/nth;
3710
+
3711
+ // row range for this thread
3712
+ const int ir0 = dr*ith;
3713
+ const int ir1 = MIN(ir0 + dr, nr);
3714
+
3715
+ for (int i1 = ir0; i1 < ir1; i1++) {
3716
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3717
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3718
+
3719
+ if (!src1) {
3720
+ src0_p += swapped ? nc : 0;
3721
+ src1_p += swapped ? 0 : nc;
3722
+ }
3723
+
3724
+ ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3725
+
3726
+ #ifndef NDEBUG
3727
+ for (int k = 0; k < nc; k++) {
3728
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3729
+ const float v = GGML_FP16_TO_FP32(x);
3730
+ GGML_UNUSED(v);
3731
+ assert(!isnan(v));
3732
+ assert(!isinf(v));
3733
+ }
3734
+ #endif
3735
+ }
3736
+ }
3737
+
3738
+ static void ggml_compute_forward_geglu_erf(
3739
+ const ggml_compute_params * params,
3740
+ ggml_tensor * dst) {
3741
+
3742
+ const ggml_tensor * src0 = dst->src[0];
3743
+
3744
+ switch (src0->type) {
3745
+ case GGML_TYPE_F32:
3746
+ {
3747
+ ggml_compute_forward_geglu_erf_f32(params, dst);
3748
+ } break;
3749
+ case GGML_TYPE_F16:
3750
+ {
3751
+ ggml_compute_forward_geglu_erf_f16(params, dst);
3752
+ } break;
3753
+ default:
3754
+ {
3755
+ GGML_ABORT("fatal error");
3756
+ }
3757
+ }
3758
+ }
3759
+
3760
+ // ggml_compute_forward_geglu_quick
3761
+
3762
+ static void ggml_compute_forward_geglu_quick_f32(
3763
+ const ggml_compute_params * params,
3764
+ ggml_tensor * dst) {
3765
+
3766
+ const ggml_tensor * src0 = dst->src[0];
3767
+ const ggml_tensor * src1 = dst->src[1];
3768
+ char * src0_d = (char *) src0->data;
3769
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3770
+ const size_t src0_o = src0->nb[1];
3771
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3772
+
3773
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3774
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3775
+
3776
+ if (src1) {
3777
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3778
+ GGML_ASSERT(src0->type == src1->type);
3779
+ }
3780
+
3781
+ const int ith = params->ith;
3782
+ const int nth = params->nth;
3783
+
3784
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3785
+ const int nr = ggml_nrows(src0);
3786
+
3787
+ GGML_ASSERT(dst->ne[0] == nc);
3788
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3789
+
3790
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3791
+
3792
+ // rows per thread
3793
+ const int dr = (nr + nth - 1)/nth;
3794
+
3795
+ // row range for this thread
3796
+ const int ir0 = dr*ith;
3797
+ const int ir1 = MIN(ir0 + dr, nr);
3798
+
3799
+ for (int i1 = ir0; i1 < ir1; i1++) {
3800
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3801
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3802
+
3803
+ if (!src1) {
3804
+ src0_p += swapped ? nc : 0;
3805
+ src1_p += swapped ? 0 : nc;
3806
+ }
3807
+
3808
+ ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3809
+
3810
+ #ifndef NDEBUG
3811
+ for (int k = 0; k < nc; k++) {
3812
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3813
+ GGML_UNUSED(x);
3814
+ assert(!isnan(x));
3815
+ assert(!isinf(x));
3816
+ }
3817
+ #endif
3818
+ }
3819
+ }
3820
+
3821
+ static void ggml_compute_forward_geglu_quick_f16(
3822
+ const ggml_compute_params * params,
3823
+ ggml_tensor * dst) {
3824
+
3825
+ const ggml_tensor * src0 = dst->src[0];
3826
+ const ggml_tensor * src1 = dst->src[1];
3827
+ char * src0_d = (char *) src0->data;
3828
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3829
+ const size_t src0_o = src0->nb[1];
3830
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3831
+
3832
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3833
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3834
+
3835
+ if (src1) {
3836
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3837
+ GGML_ASSERT(src0->type == src1->type);
3838
+ }
3839
+
3840
+ const int ith = params->ith;
3841
+ const int nth = params->nth;
3842
+
3843
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3844
+ const int nr = ggml_nrows(src0);
3845
+
3846
+ GGML_ASSERT(dst->ne[0] == nc);
3847
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3848
+
3849
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3149
3850
 
3150
3851
  // rows per thread
3151
3852
  const int dr = (nr + nth - 1)/nth;
@@ -3155,24 +3856,29 @@ static void ggml_compute_forward_silu_back_f16(
3155
3856
  const int ir1 = MIN(ir0 + dr, nr);
3156
3857
 
3157
3858
  for (int i1 = ir0; i1 < ir1; i1++) {
3158
- ggml_vec_silu_backward_f16(nc,
3159
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
3160
- (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
3161
- (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
3859
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3860
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3162
3861
 
3163
- #ifndef NDEBUG
3862
+ if (!src1) {
3863
+ src0_p += swapped ? nc : 0;
3864
+ src1_p += swapped ? 0 : nc;
3865
+ }
3866
+
3867
+ ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3868
+
3869
+ #ifndef NDEBUG
3164
3870
  for (int k = 0; k < nc; k++) {
3165
- const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3871
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3166
3872
  const float v = GGML_FP16_TO_FP32(x);
3167
3873
  GGML_UNUSED(v);
3168
3874
  assert(!isnan(v));
3169
3875
  assert(!isinf(v));
3170
3876
  }
3171
- #endif
3877
+ #endif
3172
3878
  }
3173
3879
  }
3174
3880
 
3175
- void ggml_compute_forward_silu_back(
3881
+ static void ggml_compute_forward_geglu_quick(
3176
3882
  const ggml_compute_params * params,
3177
3883
  ggml_tensor * dst) {
3178
3884
 
@@ -3181,11 +3887,11 @@ void ggml_compute_forward_silu_back(
3181
3887
  switch (src0->type) {
3182
3888
  case GGML_TYPE_F32:
3183
3889
  {
3184
- ggml_compute_forward_silu_back_f32(params, dst);
3890
+ ggml_compute_forward_geglu_quick_f32(params, dst);
3185
3891
  } break;
3186
3892
  case GGML_TYPE_F16:
3187
3893
  {
3188
- ggml_compute_forward_silu_back_f16(params, dst);
3894
+ ggml_compute_forward_geglu_quick_f16(params, dst);
3189
3895
  } break;
3190
3896
  default:
3191
3897
  {
@@ -3937,9 +4643,11 @@ static void ggml_compute_forward_scale_f32(
3937
4643
  GGML_ASSERT(ggml_is_contiguous(dst));
3938
4644
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
3939
4645
 
3940
- // scale factor
3941
- float v;
3942
- memcpy(&v, dst->op_params, sizeof(float));
4646
+ float s; // scale factor
4647
+ float b; // bias
4648
+
4649
+ memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
4650
+ memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
3943
4651
 
3944
4652
  const int ith = params->ith;
3945
4653
  const int nth = params->nth;
@@ -3958,12 +4666,22 @@ static void ggml_compute_forward_scale_f32(
3958
4666
 
3959
4667
  const size_t nb1 = dst->nb[1];
3960
4668
 
3961
- for (int i1 = ir0; i1 < ir1; i1++) {
3962
- if (dst->data != src0->data) {
3963
- // src0 is same shape as dst => same indices
3964
- memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
4669
+ if (b == 0.0f) {
4670
+ for (int i1 = ir0; i1 < ir1; i1++) {
4671
+ if (dst->data != src0->data) {
4672
+ // src0 is same shape as dst => same indices
4673
+ // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
4674
+ memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
4675
+ }
4676
+ ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
4677
+ }
4678
+ } else {
4679
+ for (int i1 = ir0; i1 < ir1; i1++) {
4680
+ ggml_vec_mad1_f32(nc,
4681
+ (float *) ((char *) dst->data + i1*nb1),
4682
+ (float *) ((char *) src0->data + i1*nb1),
4683
+ s, b);
3965
4684
  }
3966
- ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
3967
4685
  }
3968
4686
  }
3969
4687
 
@@ -4470,6 +5188,74 @@ void ggml_compute_forward_get_rows(
4470
5188
  //}
4471
5189
  }
4472
5190
 
5191
+ static void ggml_compute_forward_set_rows_f32(
5192
+ const ggml_compute_params * params,
5193
+ ggml_tensor * dst) {
5194
+
5195
+ const ggml_tensor * src0 = dst->src[0];
5196
+ const ggml_tensor * src1 = dst->src[1];
5197
+
5198
+ GGML_TENSOR_BINARY_OP_LOCALS
5199
+
5200
+ const int64_t nc = ne00;
5201
+ const int64_t nr = ne01;
5202
+
5203
+ assert(ne0 == nc);
5204
+ assert(ne2 == ne02);
5205
+ assert(ne3 == ne03);
5206
+ assert(src0->type == GGML_TYPE_F32);
5207
+ assert(ne02 % ne11 == 0);
5208
+ assert(ne03 % ne12 == 0);
5209
+
5210
+ const int ith = params->ith;
5211
+ const int nth = params->nth;
5212
+
5213
+ // rows per thread
5214
+ const int64_t dr = (nr + nth - 1)/nth;
5215
+
5216
+ // row range for this thread
5217
+ const int64_t ir0 = dr*ith;
5218
+ const int64_t ir1 = std::min(ir0 + dr, nr);
5219
+
5220
+ ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
5221
+
5222
+ for (int64_t i03 = 0; i03 < ne03; ++i03) {
5223
+ for (int64_t i02 = 0; i02 < ne02; ++i02) {
5224
+ for (int64_t i = ir0; i < ir1; ++i) {
5225
+ const int64_t i12 = i03%ne12;
5226
+ const int64_t i11 = i02%ne11;
5227
+ const int64_t i10 = i;
5228
+
5229
+ const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
5230
+
5231
+ GGML_ASSERT(i1 >= 0 && i1 < ne1);
5232
+
5233
+ from_float(
5234
+ (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
5235
+ ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
5236
+ }
5237
+ }
5238
+ }
5239
+ }
5240
+
5241
+ void ggml_compute_forward_set_rows(
5242
+ const ggml_compute_params * params,
5243
+ ggml_tensor * dst) {
5244
+
5245
+ const ggml_tensor * src0 = dst->src[0];
5246
+
5247
+ switch (src0->type) {
5248
+ case GGML_TYPE_F32:
5249
+ {
5250
+ ggml_compute_forward_set_rows_f32(params, dst);
5251
+ } break;
5252
+ default:
5253
+ {
5254
+ GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type));
5255
+ }
5256
+ }
5257
+ }
5258
+
4473
5259
  // ggml_compute_forward_get_rows_back
4474
5260
 
4475
5261
  static void ggml_compute_forward_get_rows_back_f32_f16(
@@ -4500,7 +5286,7 @@ static void ggml_compute_forward_get_rows_back_f32_f16(
4500
5286
 
4501
5287
  for (int j = 0; j < nc; ++j) {
4502
5288
  ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
4503
- ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_FP16_TO_FP32(v);
5289
+ ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_CPU_FP16_TO_FP32(v);
4504
5290
  }
4505
5291
  }
4506
5292
  }
@@ -4744,14 +5530,17 @@ static void ggml_compute_forward_soft_max_f32(
4744
5530
  memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
4745
5531
  memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
4746
5532
 
4747
- // TODO: handle transposed/permuted matrices
4748
-
4749
5533
  const int ith = params->ith;
4750
5534
  const int nth = params->nth;
4751
5535
 
4752
5536
  GGML_TENSOR_UNARY_OP_LOCALS
4753
5537
 
4754
- //const int64_t ne11 = src1 ? src1->ne[1] : 1;
5538
+ const int64_t nb11 = src1 ? src1->nb[1] : 1;
5539
+ const int64_t nb12 = src1 ? src1->nb[2] : 1;
5540
+ const int64_t nb13 = src1 ? src1->nb[3] : 1;
5541
+
5542
+ const int64_t ne12 = src1 ? src1->ne[2] : 1;
5543
+ const int64_t ne13 = src1 ? src1->ne[3] : 1;
4755
5544
 
4756
5545
  // TODO: is this supposed to be ceil instead of floor?
4757
5546
  // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -4761,68 +5550,66 @@ static void ggml_compute_forward_soft_max_f32(
4761
5550
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
4762
5551
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
4763
5552
 
4764
- const int nc = src0->ne[0];
4765
- const int nr = ggml_nrows(src0);
4766
-
4767
- // rows per thread
4768
- const int dr = (nr + nth - 1)/nth;
4769
-
4770
- // row range for this thread
4771
- const int ir0 = dr*ith;
4772
- const int ir1 = MIN(ir0 + dr, nr);
4773
-
4774
- float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
5553
+ float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
4775
5554
 
4776
5555
  const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
4777
5556
 
4778
- for (int i1 = ir0; i1 < ir1; i1++) {
4779
- // ALiBi
4780
- const uint32_t h = (i1/ne01)%ne02; // head
4781
- const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
4782
-
4783
- float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
4784
- float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
4785
-
4786
- // broadcast the mask across rows
4787
- ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
4788
- float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
4789
-
4790
- ggml_vec_cpy_f32 (nc, wp, sp);
4791
- ggml_vec_scale_f32(nc, wp, scale);
4792
- if (mp_f32) {
4793
- if (use_f16) {
4794
- for (int i = 0; i < nc; ++i) {
4795
- wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
4796
- }
4797
- } else {
4798
- for (int i = 0; i < nc; ++i) {
4799
- wp[i] += slope*mp_f32[i];
5557
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5558
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5559
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
5560
+ const int64_t i11 = i01;
5561
+ const int64_t i12 = i02%ne12;
5562
+ const int64_t i13 = i03%ne13;
5563
+
5564
+ // ALiBi
5565
+ const uint32_t h = i02; // head
5566
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
5567
+
5568
+ float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5569
+ float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
5570
+
5571
+ // broadcast the mask across rows
5572
+ ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
5573
+ float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
5574
+
5575
+ ggml_vec_cpy_f32 (ne00, wp, sp);
5576
+ ggml_vec_scale_f32(ne00, wp, scale);
5577
+ if (mp_f32) {
5578
+ if (use_f16) {
5579
+ for (int i = 0; i < ne00; ++i) {
5580
+ wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
5581
+ }
5582
+ } else {
5583
+ for (int i = 0; i < ne00; ++i) {
5584
+ wp[i] += slope*mp_f32[i];
5585
+ }
5586
+ }
4800
5587
  }
4801
- }
4802
- }
4803
5588
 
4804
5589
  #ifndef NDEBUG
4805
- for (int i = 0; i < nc; ++i) {
4806
- //printf("p[%d] = %f\n", i, p[i]);
4807
- assert(!isnan(wp[i]));
4808
- }
5590
+ for (int i = 0; i < ne00; ++i) {
5591
+ //printf("p[%d] = %f\n", i, p[i]);
5592
+ assert(!isnan(wp[i]));
5593
+ }
4809
5594
  #endif
4810
5595
 
4811
- float max = -INFINITY;
4812
- ggml_vec_max_f32(nc, &max, wp);
5596
+ float max = -INFINITY;
5597
+ ggml_vec_max_f32(ne00, &max, wp);
4813
5598
 
4814
- ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
4815
- assert(sum > 0.0);
5599
+ ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
5600
+ assert(sum > 0.0);
4816
5601
 
4817
- sum = 1.0/sum;
4818
- ggml_vec_scale_f32(nc, dp, sum);
5602
+ sum = 1.0/sum;
5603
+ ggml_vec_scale_f32(ne00, dp, sum);
4819
5604
 
4820
5605
  #ifndef NDEBUG
4821
- for (int i = 0; i < nc; ++i) {
4822
- assert(!isnan(dp[i]));
4823
- assert(!isinf(dp[i]));
4824
- }
5606
+ for (int i = 0; i < ne00; ++i) {
5607
+ assert(!isnan(dp[i]));
5608
+ assert(!isinf(dp[i]));
5609
+ }
4825
5610
  #endif
5611
+ }
5612
+ }
4826
5613
  }
4827
5614
  }
4828
5615
 
@@ -5018,8 +5805,8 @@ static void ggml_compute_forward_clamp_f16(
5018
5805
  ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
5019
5806
 
5020
5807
  for (int i = 0; i < nc; i++) {
5021
- float v = GGML_FP16_TO_FP32(src0_ptr[i]);
5022
- dst_ptr[i] = GGML_FP32_TO_FP16(MAX(MIN(v, max), min));
5808
+ float v = GGML_CPU_FP16_TO_FP32(src0_ptr[i]);
5809
+ dst_ptr[i] = GGML_CPU_FP32_TO_FP16(MAX(MIN(v, max), min));
5023
5810
  }
5024
5811
  }
5025
5812
  }
@@ -5476,11 +6263,11 @@ static void ggml_compute_forward_rope_f16(
5476
6263
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5477
6264
  ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5478
6265
 
5479
- const float x0 = GGML_FP16_TO_FP32(src[0]);
5480
- const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
6266
+ const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
6267
+ const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
5481
6268
 
5482
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5483
- dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
6269
+ dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
6270
+ dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5484
6271
  }
5485
6272
  } else {
5486
6273
  for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
@@ -5492,11 +6279,11 @@ static void ggml_compute_forward_rope_f16(
5492
6279
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5493
6280
  ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5494
6281
 
5495
- const float x0 = GGML_FP16_TO_FP32(src[0]);
5496
- const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
6282
+ const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
6283
+ const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
5497
6284
 
5498
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5499
- dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
6285
+ dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
6286
+ dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5500
6287
  }
5501
6288
  }
5502
6289
  } else {
@@ -5507,11 +6294,11 @@ static void ggml_compute_forward_rope_f16(
5507
6294
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5508
6295
  ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5509
6296
 
5510
- const float x0 = GGML_FP16_TO_FP32(src[0]);
5511
- const float x1 = GGML_FP16_TO_FP32(src[1]);
6297
+ const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
6298
+ const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
5512
6299
 
5513
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5514
- dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
6300
+ dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
6301
+ dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5515
6302
  }
5516
6303
  }
5517
6304
 
@@ -5525,11 +6312,11 @@ static void ggml_compute_forward_rope_f16(
5525
6312
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5526
6313
  ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5527
6314
 
5528
- const float x0 = GGML_FP16_TO_FP32(src[0]);
5529
- const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
6315
+ const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
6316
+ const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
5530
6317
 
5531
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5532
- dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
6318
+ dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
6319
+ dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5533
6320
  }
5534
6321
  } else {
5535
6322
  for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
@@ -5640,7 +6427,7 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
5640
6427
  for (int64_t i11 = 0; i11 < ne11; i11++) {
5641
6428
  const float * const src = (float *)((char *) src1->data + i11*nb11);
5642
6429
  for (int64_t i10 = 0; i10 < ne10; i10++) {
5643
- dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
6430
+ dst_data[i10*ne11 + i11] = GGML_CPU_FP32_TO_FP16(src[i10]);
5644
6431
  }
5645
6432
  }
5646
6433
  }
@@ -5933,7 +6720,7 @@ static void ggml_compute_forward_im2col_f16(
5933
6720
  if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
5934
6721
  dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
5935
6722
  } else {
5936
- dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
6723
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
5937
6724
  }
5938
6725
  }
5939
6726
  }
@@ -6058,6 +6845,186 @@ void ggml_compute_forward_im2col_back_f32(
6058
6845
  }
6059
6846
  }
6060
6847
 
6848
+ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
6849
+ void * a, void * b, float * c) {
6850
+ const ggml_type_traits * traits = ggml_get_type_traits(type);
6851
+ struct ggml_tensor src1 = {};
6852
+ src1.type = type;
6853
+ src1.ne[0] = k;
6854
+ src1.ne[1] = m;
6855
+ src1.ne[2] = 1;
6856
+ src1.ne[3] = 1;
6857
+ src1.nb[0] = traits->type_size;
6858
+ src1.nb[1] = k * traits->type_size;
6859
+ src1.nb[2] = src1.nb[1];
6860
+ src1.nb[3] = src1.nb[2];
6861
+ src1.data = a;
6862
+
6863
+ struct ggml_tensor src0 = {};
6864
+ src0.type = type;
6865
+ src0.ne[0] = k;
6866
+ src0.ne[1] = n;
6867
+ src0.ne[2] = 1;
6868
+ src0.ne[3] = 1;
6869
+ src0.nb[0] = traits->type_size;
6870
+ src0.nb[1] = k * traits->type_size;
6871
+ src0.nb[2] = src0.nb[1];
6872
+ src0.nb[3] = src0.nb[2];
6873
+ src0.data = b;
6874
+
6875
+ struct ggml_tensor dst = {};
6876
+ dst.ne[0] = n;
6877
+ dst.ne[1] = m;
6878
+ dst.ne[2] = 1;
6879
+ dst.ne[3] = 1;
6880
+ dst.nb[0] = sizeof(float);
6881
+ dst.nb[1] = n * sizeof(float);
6882
+ dst.nb[2] = dst.nb[1];
6883
+ dst.nb[3] = dst.nb[2];
6884
+ dst.data = c;
6885
+ dst.src[0] = &src0;
6886
+ dst.src[1] = &src1;
6887
+
6888
+ ggml_compute_forward_mul_mat(params, &dst);
6889
+ }
6890
+
6891
+ // ggml_compute_forward_conv_2d
6892
+
6893
+ static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
6894
+ const ggml_tensor * kernel, // [KW, KH, IC, OC]
6895
+ const ggml_tensor * src, // [W, H, C, N]
6896
+ ggml_tensor * dst, // [OW, OH, OC, N]
6897
+ ggml_type kernel_type) {
6898
+
6899
+ GGML_ASSERT(ggml_is_contiguous(kernel));
6900
+ GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
6901
+ GGML_ASSERT(kernel->type == kernel_type);
6902
+
6903
+ const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
6904
+
6905
+ const int32_t stride_x = dst->op_params[0];
6906
+ const int32_t stride_y = dst->op_params[1];
6907
+ const int32_t pad_x = dst->op_params[2];
6908
+ const int32_t pad_y = dst->op_params[3];
6909
+ const int32_t dilation_x = dst->op_params[4];
6910
+ const int32_t dilation_y = dst->op_params[5];
6911
+
6912
+ const int64_t c_in = src->ne[2];
6913
+ const int64_t c_out = kernel->ne[3];
6914
+ GGML_ASSERT(c_in == kernel->ne[2]);
6915
+
6916
+ const int64_t src_w = src->ne[0];
6917
+ const int64_t src_h = src->ne[1];
6918
+ const int64_t knl_w = kernel->ne[0];
6919
+ const int64_t knl_h = kernel->ne[1];
6920
+ const int64_t dst_w = dst->ne[0];
6921
+ const int64_t dst_h = dst->ne[1];
6922
+
6923
+ const float * src_data = (float *) src->data;
6924
+ void * knl_data = kernel->data;
6925
+ float * dst_data = (float *) dst->data;
6926
+
6927
+ const int64_t knl_n = knl_w * knl_h * c_in;
6928
+ const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
6929
+
6930
+ const int64_t space_per_patch = knl_n * traits->type_size + c_out * sizeof(float);
6931
+ const int64_t batch_size = params->wsize / space_per_patch;
6932
+ const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
6933
+ const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
6934
+
6935
+ GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
6936
+
6937
+ void * tmp = params->wdata;
6938
+
6939
+ for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
6940
+
6941
+ const int64_t patch_start_batch = batch_i * patches_per_batch;
6942
+ const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch,
6943
+ patch_total);
6944
+ const int64_t patch_n = patch_end_batch - patch_start_batch;
6945
+
6946
+ const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
6947
+ const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
6948
+ const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
6949
+
6950
+ //im2col for a patch
6951
+ for (int64_t p = patch_start; p < patch_end; ++p) {
6952
+ const int64_t batch_n = p / (dst_w * dst_h);
6953
+ const int64_t src_x = (p / dst_w) % dst_h;
6954
+ const int64_t src_y = p % dst_w;
6955
+
6956
+ const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);
6957
+ char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
6958
+
6959
+ for (int64_t ic = 0; ic < c_in; ++ic) {
6960
+ for (int64_t ky = 0; ky < knl_h; ++ky) {
6961
+ for (int64_t kx = 0; kx < knl_w; ++kx) {
6962
+ const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
6963
+ const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
6964
+
6965
+ int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
6966
+
6967
+ float src_val;
6968
+ if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
6969
+ src_val = 0.0f;
6970
+ } else {
6971
+ const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
6972
+ src_val = *src_ptr;
6973
+ }
6974
+
6975
+ char * element_ptr = dst_row + dst_idx * traits->type_size;
6976
+ if (kernel_type == GGML_TYPE_F32) {
6977
+ *(float *) element_ptr = src_val;
6978
+ } else if (kernel_type == GGML_TYPE_F16) {
6979
+ *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
6980
+ }
6981
+ }
6982
+ }
6983
+ }
6984
+ } // patches handled by this thread
6985
+
6986
+ ggml_barrier(params->threadpool);
6987
+
6988
+ float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
6989
+
6990
+ GGML_ASSERT(gemm_output + patch_n * c_out <= (float*)tmp + params->wsize);
6991
+
6992
+ // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
6993
+ ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
6994
+
6995
+ ggml_barrier(params->threadpool);
6996
+
6997
+
6998
+ //permute back [OC, N, OH, OW] to [N, OC, OH, OW]
6999
+ const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
7000
+ const int64_t permute_start = params->ith * permute_per_thread;
7001
+ const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);
7002
+
7003
+ for (int64_t i = permute_start; i < permute_end; ++i) {
7004
+ const int64_t p = patch_start_batch + i;
7005
+ const int64_t batch_n = p / (dst_w * dst_h);
7006
+ const int64_t dst_y = (p / dst_w) % dst_h;
7007
+ const int64_t dst_x = p % dst_w;
7008
+
7009
+ for (int64_t oc = 0; oc < c_out; ++oc) {
7010
+ const float value = gemm_output[i * c_out + oc];
7011
+ float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
7012
+ *dst_ptr = value;
7013
+ }
7014
+ }
7015
+ }
7016
+ }
7017
+
7018
+ void ggml_compute_forward_conv_2d(
7019
+ const ggml_compute_params * params,
7020
+ ggml_tensor * dst) {
7021
+
7022
+ const ggml_tensor * src0 = dst->src[0];
7023
+ const ggml_tensor * src1 = dst->src[1];
7024
+
7025
+ ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
7026
+ }
7027
+
6061
7028
  // ggml_compute_forward_conv_transpose_2d
6062
7029
 
6063
7030
  void ggml_compute_forward_conv_transpose_2d(
@@ -6109,7 +7076,7 @@ void ggml_compute_forward_conv_transpose_2d(
6109
7076
  const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
6110
7077
  ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
6111
7078
  for (int i10 = 0; i10 < ne10; i10++) {
6112
- dst_data[i10*ne12 + i12] = GGML_FP32_TO_FP16(src[i10]);
7079
+ dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
6113
7080
  }
6114
7081
  }
6115
7082
  }
@@ -6358,7 +7325,7 @@ static void ggml_compute_forward_pool_1d_sk_p0(
6358
7325
  case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
6359
7326
  }
6360
7327
  for (int ki = 0; ki < k; ++ki) {
6361
- const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
7328
+ const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
6362
7329
  switch (op) {
6363
7330
  case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
6364
7331
  case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
@@ -6450,7 +7417,7 @@ void ggml_compute_forward_pool_2d(
6450
7417
  for (int kx = 0; kx < k0; ++kx) {
6451
7418
  int j = ix + kx;
6452
7419
  if (j < 0 || j >= src->ne[0]) continue;
6453
- const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
7420
+ const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
6454
7421
  switch (op) {
6455
7422
  case GGML_OP_POOL_AVG: *out += srow_j; break;
6456
7423
  case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
@@ -6538,7 +7505,7 @@ void ggml_compute_forward_pool_2d_back(
6538
7505
  }
6539
7506
 
6540
7507
  const float val = dst->type == GGML_TYPE_F32 ?
6541
- ((const float *) drowf)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
7508
+ ((const float *) drowf)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
6542
7509
  if (val <= maxval) {
6543
7510
  continue;
6544
7511
  }
@@ -6558,7 +7525,7 @@ void ggml_compute_forward_pool_2d_back(
6558
7525
  if (dst->type == GGML_TYPE_F32) {
6559
7526
  ((float *) drow)[j] += grad0;
6560
7527
  } else {
6561
- ((ggml_fp16_t *) drow)[j] = GGML_FP32_TO_FP16(grad0 + GGML_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
7528
+ ((ggml_fp16_t *) drow)[j] = GGML_CPU_FP32_TO_FP16(grad0 + GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
6562
7529
  }
6563
7530
  } else if (op == GGML_OP_POOL_AVG) {
6564
7531
  const float grad = grad0 / ka;
@@ -6577,7 +7544,7 @@ void ggml_compute_forward_pool_2d_back(
6577
7544
  if (dst->type == GGML_TYPE_F32) {
6578
7545
  ((float *) drow)[j] += grad;
6579
7546
  } else {
6580
- ((ggml_fp16_t *) drow)[j] += GGML_FP32_TO_FP16(grad);
7547
+ ((ggml_fp16_t *) drow)[j] += GGML_CPU_FP32_TO_FP16(grad);
6581
7548
  }
6582
7549
  }
6583
7550
  }
@@ -6608,12 +7575,13 @@ static void ggml_compute_forward_upscale_f32(
6608
7575
 
6609
7576
  GGML_TENSOR_UNARY_OP_LOCALS
6610
7577
 
6611
- const float sf0 = (float)ne0/src0->ne[0];
6612
- const float sf1 = (float)ne1/src0->ne[1];
6613
- const float sf2 = (float)ne2/src0->ne[2];
6614
- const float sf3 = (float)ne3/src0->ne[3];
7578
+ float sf0 = (float)ne0/src0->ne[0];
7579
+ float sf1 = (float)ne1/src0->ne[1];
7580
+ float sf2 = (float)ne2/src0->ne[2];
7581
+ float sf3 = (float)ne3/src0->ne[3];
6615
7582
 
6616
- const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);
7583
+ const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
7584
+ const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
6617
7585
 
6618
7586
  if (mode == GGML_SCALE_MODE_NEAREST) {
6619
7587
  for (int64_t i3 = 0; i3 < ne3; i3++) {
@@ -6634,8 +7602,12 @@ static void ggml_compute_forward_upscale_f32(
6634
7602
  }
6635
7603
  }
6636
7604
  } else if (mode == GGML_SCALE_MODE_BILINEAR) {
6637
- // setting a pixel offset of 0 would replicate the behavior of pytorch interpolate with align_corners=True
6638
- const float pixel_offset = 0.5f;
7605
+ float pixel_offset = 0.5f;
7606
+ if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7607
+ pixel_offset = 0.0f;
7608
+ sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
7609
+ sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
7610
+ }
6639
7611
 
6640
7612
  for (int64_t i3 = 0; i3 < ne3; i3++) {
6641
7613
  const int64_t i03 = i3 / sf3;
@@ -6793,6 +7765,73 @@ void ggml_compute_forward_pad_reflect_1d(
6793
7765
  }
6794
7766
  }
6795
7767
 
7768
+ // ggml_compute_forward_roll
7769
+
7770
+ static int64_t ggml_wrap_index(int64_t i, int64_t ne) {
7771
+ if (i < 0) {
7772
+ return i + ne;
7773
+ } else if (i >= ne) {
7774
+ return i - ne;
7775
+ }
7776
+ return i;
7777
+ }
7778
+
7779
+ static void ggml_compute_forward_roll_f32(
7780
+ const ggml_compute_params * params,
7781
+ ggml_tensor * dst) {
7782
+
7783
+ const ggml_tensor * src0 = dst->src[0];
7784
+ const float * src_data = (const float *) src0->data;
7785
+ float * dst_data = (float *) dst->data;
7786
+
7787
+ GGML_TENSOR_UNARY_OP_LOCALS
7788
+
7789
+ const int s0 = ggml_get_op_params_i32(dst, 0);
7790
+ const int s1 = ggml_get_op_params_i32(dst, 1);
7791
+ const int s2 = ggml_get_op_params_i32(dst, 2);
7792
+ const int s3 = ggml_get_op_params_i32(dst, 3);
7793
+
7794
+ const int64_t total = ne1 * ne2 * ne3;
7795
+ const int64_t per_thread = (total + params->nth) / params->nth;
7796
+ const int64_t start = params->ith * per_thread;
7797
+ const int64_t end = std::min(start + per_thread, total);
7798
+
7799
+ for (int64_t i = start; i < end; ++i) {
7800
+ const int64_t i1 = i % ne1;
7801
+ const int64_t i2 = (i / ne1) % ne2;
7802
+ const int64_t i3 = i / (ne2 * ne1);
7803
+ float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);
7804
+
7805
+ const int64_t i01 = ggml_wrap_index(i1 - s1, ne01);
7806
+ const int64_t i02 = ggml_wrap_index(i2 - s2, ne02);
7807
+ const int64_t i03 = ggml_wrap_index(i3 - s3, ne03);
7808
+ const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);
7809
+
7810
+ const int64_t s = ggml_wrap_index(-s0, ne00);
7811
+ const int64_t n = ne00 - s;
7812
+ ggml_vec_cpy_f32(n, dst_row, src_row + s);
7813
+ ggml_vec_cpy_f32(s, dst_row + n, src_row);
7814
+ }
7815
+ }
7816
+
7817
+ void ggml_compute_forward_roll(
7818
+ const ggml_compute_params * params,
7819
+ ggml_tensor * dst) {
7820
+
7821
+ const ggml_tensor * src0 = dst->src[0];
7822
+
7823
+ switch (src0->type) {
7824
+ case GGML_TYPE_F32:
7825
+ {
7826
+ ggml_compute_forward_roll_f32(params, dst);
7827
+ } break;
7828
+ default:
7829
+ {
7830
+ GGML_ABORT("fatal error");
7831
+ }
7832
+ }
7833
+ }
7834
+
6796
7835
  // ggml_compute_forward_arange
6797
7836
 
6798
7837
  static void ggml_compute_forward_arange_f32(
@@ -7026,7 +8065,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
7026
8065
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
7027
8066
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
7028
8067
 
7029
- ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
8068
+ ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
7030
8069
  ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
7031
8070
  ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
7032
8071
  ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
@@ -7058,7 +8097,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
7058
8097
  memset(VKQ32, 0, DV*sizeof(float));
7059
8098
  }
7060
8099
 
7061
- const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
8100
+ const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
7062
8101
 
7063
8102
  // k indices
7064
8103
  const int ik3 = iq3 / rk3;
@@ -7075,7 +8114,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
7075
8114
  // loop over n_kv and n_head_kv
7076
8115
  // ref: https://arxiv.org/pdf/2112.05682.pdf
7077
8116
  for (int64_t ic = 0; ic < nek1; ++ic) {
7078
- const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
8117
+ const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
7079
8118
  if (mv == -INFINITY) {
7080
8119
  continue;
7081
8120
  }
@@ -7143,7 +8182,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
7143
8182
 
7144
8183
  if (v->type == GGML_TYPE_F16) {
7145
8184
  for (int64_t d = 0; d < DV; ++d) {
7146
- VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
8185
+ VKQ32[d] = GGML_CPU_FP16_TO_FP32(VKQ16[d]);
7147
8186
  }
7148
8187
  }
7149
8188
 
@@ -7596,120 +8635,210 @@ void ggml_compute_forward_ssm_conv(
7596
8635
  static void ggml_compute_forward_ssm_scan_f32(
7597
8636
  const ggml_compute_params * params,
7598
8637
  ggml_tensor * dst) {
7599
- const ggml_tensor * src0 = dst->src[0]; // s
7600
- const ggml_tensor * src1 = dst->src[1]; // x
7601
- const ggml_tensor * src2 = dst->src[2]; // dt
7602
- const ggml_tensor * src3 = dst->src[3]; // A
7603
- const ggml_tensor * src4 = dst->src[4]; // B
7604
- const ggml_tensor * src5 = dst->src[5]; // C
8638
+ const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
8639
+ const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
8640
+ const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
8641
+ const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
8642
+ const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
8643
+ const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
8644
+ const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
7605
8645
 
7606
8646
  const int ith = params->ith;
7607
8647
  const int nth = params->nth;
7608
8648
 
7609
- const int64_t nc = src0->ne[0]; // d_state
7610
- const int64_t nr = src0->ne[1]; // d_inner
7611
- const int64_t n_t = src1->ne[1]; // number of tokens per sequence
7612
- const int64_t n_s = src0->ne[2]; // number of sequences in the batch
8649
+ const int64_t nc = src0->ne[0]; // d_state
8650
+ const int64_t nr = src0->ne[1]; // dim
8651
+ const int64_t nh = src1->ne[1]; // n_head
8652
+ const int64_t ng = src4->ne[1];
8653
+ const int64_t nt = src1->ne[2]; // number of tokens per sequence
8654
+ const int64_t ns = src1->ne[3]; // number of sequences in the batch
8655
+
8656
+ // can't use ggml_nbytes because src1 is not necessarily contiguous
8657
+ const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
7613
8658
 
7614
- GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
8659
+ GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
7615
8660
  GGML_ASSERT(src0->nb[0] == sizeof(float));
7616
8661
  GGML_ASSERT(src1->nb[0] == sizeof(float));
7617
8662
  GGML_ASSERT(src2->nb[0] == sizeof(float));
7618
8663
  GGML_ASSERT(src3->nb[0] == sizeof(float));
7619
8664
  GGML_ASSERT(src4->nb[0] == sizeof(float));
7620
8665
  GGML_ASSERT(src5->nb[0] == sizeof(float));
7621
- // required for the dot product between s and C
7622
- GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
7623
- // required for per-sequence offsets for states
7624
- GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
7625
- // required to get correct offset for state destination (i.e. src1->nb[3])
7626
- GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
8666
+ GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
8667
+ // allows optimizing the modulo since n_group should be a power of 2
8668
+ GGML_ASSERT((ng & -ng) == ng);
7627
8669
 
7628
- // rows per thread
7629
- const int dr = (nr + nth - 1)/nth;
8670
+ // heads per thread
8671
+ const int dh = (nh + nth - 1)/nth;
7630
8672
 
7631
- // row range for this thread
7632
- const int ir0 = dr*ith;
7633
- const int ir1 = MIN(ir0 + dr, nr);
7634
- const int ir = ir1 - ir0;
8673
+ // head range for this thread
8674
+ const int ih0 = dh*ith;
8675
+ const int ih1 = MIN(ih0 + dh, nh);
8676
+
8677
+ const int32_t * ids = (const int32_t *) src6->data;
7635
8678
 
7636
- #ifdef __ARM_FEATURE_SVE
7637
- for (int i3 = 0; i3 < n_s; ++i3) {
7638
- for (int i2 = 0; i2 < n_t; ++i2) {
7639
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
7640
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7641
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
7642
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
7643
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
7644
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
7645
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7646
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
7647
-
7648
- // use the output as the source for the next token-wise iterations
7649
- if (i2 > 0) { s0 = s; }
7650
-
7651
- // d_inner
7652
- for (int i1 = 0; i1 < ir; ++i1) {
7653
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
7654
- float x_dt = x[i1] * dt_soft_plus;
7655
- svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
7656
- svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
7657
- svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
7658
-
7659
- for (int64_t k = 0; k < nc; k += svcntw()) {
7660
- svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
7661
- svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
7662
- svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
7663
- svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
7664
-
7665
- svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
7666
- t1 = exp_ps_sve(svptrue_b32(), t1);
7667
- svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
7668
-
7669
- vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
7670
- r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
7671
-
7672
- GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
8679
+ for (int i3 = 0; i3 < ns; ++i3) {
8680
+ const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
8681
+ float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
8682
+
8683
+ for (int i2 = 0; i2 < nt; ++i2) {
8684
+ const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
8685
+ const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
8686
+ const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
8687
+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
8688
+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
8689
+ float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
8690
+
8691
+ if (src3->ne[0] == 1) {
8692
+ // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
8693
+
8694
+ // n_head
8695
+ for (int h = ih0; h < ih1; ++h) {
8696
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8697
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
8698
+ const float dA = expf(dt_soft_plus * A[h]);
8699
+
8700
+ // dim
8701
+ for (int i1 = 0; i1 < nr; ++i1) {
8702
+ const int ii = i1 + h*nr;
8703
+ const float x_dt = x[ii] * dt_soft_plus;
8704
+ float sumf = 0.0f;
8705
+ #if defined(GGML_SIMD)
8706
+ #if defined(__ARM_FEATURE_SVE)
8707
+ const int ggml_f32_epr = svcntw();
8708
+ const int ggml_f32_step = 1 * ggml_f32_epr;
8709
+
8710
+ const int np = (nc & ~(ggml_f32_step - 1));
8711
+
8712
+ GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
8713
+
8714
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
8715
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
8716
+
8717
+ for (int i = 0; i < np; i += ggml_f32_step) {
8718
+ // TODO: maybe unroll more?
8719
+ for (int j = 0; j < 1; j++) {
8720
+ GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
8721
+ GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
8722
+ GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
8723
+
8724
+ t0 = GGML_F32_VEC_MUL(t0, adA);
8725
+ t1 = GGML_F32_VEC_MUL(t1, axdt);
8726
+
8727
+ t0 = GGML_F32_VEC_ADD(t0, t1);
8728
+
8729
+ sum = GGML_F32_VEC_FMA(sum, t0, t2);
8730
+
8731
+ GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
8732
+ }
8733
+ }
8734
+
8735
+ sumf = GGML_F32xt_REDUCE_ONE(sum);
8736
+ #else
8737
+ const int np = (nc & ~(GGML_F32_STEP - 1));
8738
+
8739
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8740
+
8741
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
8742
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
8743
+
8744
+ GGML_F32_VEC ax[GGML_F32_ARR];
8745
+ GGML_F32_VEC ay[GGML_F32_ARR];
8746
+ GGML_F32_VEC az[GGML_F32_ARR];
8747
+
8748
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
8749
+ for (int j = 0; j < GGML_F32_ARR; j++) {
8750
+ ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
8751
+ ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
8752
+ az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
8753
+
8754
+ ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
8755
+ ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
8756
+
8757
+ ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
8758
+
8759
+ sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
8760
+
8761
+ GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
8762
+ }
8763
+ }
8764
+
8765
+ // reduce sum0..sum3 to sum0
8766
+ GGML_F32_VEC_REDUCE(sumf, sum);
8767
+ #endif
8768
+ #else
8769
+ const int np = 0;
8770
+ #endif
8771
+ // d_state
8772
+ for (int i0 = np; i0 < nc; ++i0) {
8773
+ const int i = i0 + ii*nc;
8774
+ const int ig = i0 + (h & (ng - 1))*nc;
8775
+ // state = prev_state * dA + dB * x
8776
+ const float state = (s0[i] * dA) + (B[ig] * x_dt);
8777
+ // y = rowwise_dotprod(state, C)
8778
+ sumf += state * C[ig];
8779
+ s[i] = state;
8780
+ }
8781
+ y[ii] = sumf;
7673
8782
  }
7674
- y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
7675
8783
  }
7676
- }
7677
- }
7678
- #else
7679
- for (int i3 = 0; i3 < n_s; ++i3) {
7680
- for (int i2 = 0; i2 < n_t; ++i2) {
7681
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
7682
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7683
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
7684
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
7685
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
7686
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
7687
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7688
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
7689
-
7690
- // use the output as the source for the next token-wise iterations
7691
- if (i2 > 0) { s0 = s; }
7692
-
7693
- // d_inner
7694
- for (int i1 = 0; i1 < ir; ++i1) {
7695
- // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
7696
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
7697
- float x_dt = x[i1] * dt_soft_plus;
7698
- float sumf = 0.0f;
7699
- // d_state
7700
- for (int i0 = 0; i0 < nc; ++i0) {
7701
- int i = i0 + i1*nc;
7702
- // state = prev_state * dA + dB * x
7703
- float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
7704
- // y = rowwise_dotprod(state, C)
7705
- sumf += state * C[i0];
7706
- s[i] = state;
8784
+ } else {
8785
+ // Mamba-1 has an element-wise decay factor for the states
8786
+
8787
+ // n_head
8788
+ for (int h = ih0; h < ih1; ++h) {
8789
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8790
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
8791
+
8792
+ // dim
8793
+ for (int i1 = 0; i1 < nr; ++i1) {
8794
+ const int ii = i1 + h*nr;
8795
+ const float x_dt = x[ii] * dt_soft_plus;
8796
+ #if defined(__ARM_FEATURE_SVE)
8797
+ svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
8798
+ svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
8799
+ svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
8800
+
8801
+ // d_state
8802
+ // TODO: what happens when (d_state % svcntw()) != 0?
8803
+ for (int64_t k = 0; k < nc; k += svcntw()) {
8804
+ svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
8805
+ svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
8806
+ svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
8807
+ svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
8808
+
8809
+ svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
8810
+ t1 = exp_ps_sve(svptrue_b32(), t1);
8811
+ svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
8812
+
8813
+ vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
8814
+ r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
8815
+
8816
+ GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
8817
+ }
8818
+ y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
8819
+ #else
8820
+ float sumf = 0.0f;
8821
+ // NOTE: can't really use GGML_SIMD here because d_state is usually 16
8822
+ // and also because expf is used within the loop.
8823
+ // d_state
8824
+ for (int i0 = 0; i0 < nc; ++i0) {
8825
+ const int i = i0 + ii*nc;
8826
+ const int ig = i0 + (h & (ng - 1))*nc;
8827
+ // state = prev_state * dA + dB * x
8828
+ const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
8829
+ // y = rowwise_dotprod(state, C)
8830
+ sumf += state * C[ig];
8831
+ s[i] = state;
8832
+ }
8833
+ y[ii] = sumf;
8834
+ #endif
7707
8835
  }
7708
- y[i1] = sumf;
7709
8836
  }
7710
8837
  }
8838
+ // use the output as the source when it's not the first token-wise iteration
8839
+ s0 = s;
7711
8840
  }
7712
- #endif
8841
+ }
7713
8842
  }
7714
8843
 
7715
8844
  void ggml_compute_forward_ssm_scan(
@@ -7927,6 +9056,42 @@ void ggml_compute_forward_unary(
7927
9056
  }
7928
9057
  }
7929
9058
 
9059
+ //ggml_compute_forward_glu
9060
+
9061
+ void ggml_compute_forward_glu(
9062
+ const ggml_compute_params * params,
9063
+ ggml_tensor * dst) {
9064
+
9065
+ const ggml_glu_op op = ggml_get_glu_op(dst);
9066
+
9067
+ switch (op) {
9068
+ case GGML_GLU_OP_REGLU:
9069
+ {
9070
+ ggml_compute_forward_reglu(params, dst);
9071
+ } break;
9072
+ case GGML_GLU_OP_GEGLU:
9073
+ {
9074
+ ggml_compute_forward_geglu(params, dst);
9075
+ } break;
9076
+ case GGML_GLU_OP_SWIGLU:
9077
+ {
9078
+ ggml_compute_forward_swiglu(params, dst);
9079
+ } break;
9080
+ case GGML_GLU_OP_GEGLU_ERF:
9081
+ {
9082
+ ggml_compute_forward_geglu_erf(params, dst);
9083
+ } break;
9084
+ case GGML_GLU_OP_GEGLU_QUICK:
9085
+ {
9086
+ ggml_compute_forward_geglu_quick(params, dst);
9087
+ } break;
9088
+ default:
9089
+ {
9090
+ GGML_ABORT("fatal error");
9091
+ }
9092
+ }
9093
+ }
9094
+
7930
9095
  // ggml_compute_forward_get_rel_pos
7931
9096
 
7932
9097
  static void ggml_compute_forward_get_rel_pos_f16(