@novastera-oss/llamarn 0.4.0 → 0.4.3-beta4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (979)
  1. package/RNLlamaCpp.podspec +4 -1
  2. package/android/CMakeLists.txt +13 -3
  3. package/android/src/main/cpp/include/llama.h +44 -21
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/LlamaCppModel.cpp +2 -10
  21. package/cpp/SystemUtils.cpp +3 -7
  22. package/cpp/build-info.cpp +2 -2
  23. package/cpp/llama.cpp/CMakeLists.txt +12 -0
  24. package/cpp/llama.cpp/CODEOWNERS +116 -10
  25. package/cpp/llama.cpp/CONTRIBUTING.md +30 -3
  26. package/cpp/llama.cpp/README.md +13 -5
  27. package/cpp/llama.cpp/build-xcframework.sh +5 -0
  28. package/cpp/llama.cpp/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  29. package/cpp/llama.cpp/common/CMakeLists.txt +12 -2
  30. package/cpp/llama.cpp/common/arg.cpp +303 -795
  31. package/cpp/llama.cpp/common/arg.h +2 -3
  32. package/cpp/llama.cpp/common/chat-parser-xml-toolcall.cpp +861 -0
  33. package/cpp/llama.cpp/common/chat-parser-xml-toolcall.h +45 -0
  34. package/cpp/llama.cpp/common/chat-parser.cpp +156 -15
  35. package/cpp/llama.cpp/common/chat-parser.h +13 -0
  36. package/cpp/llama.cpp/common/chat.cpp +1147 -88
  37. package/cpp/llama.cpp/common/chat.h +16 -3
  38. package/cpp/llama.cpp/common/common.cpp +70 -15
  39. package/cpp/llama.cpp/common/common.h +57 -19
  40. package/cpp/llama.cpp/common/download.cpp +1072 -0
  41. package/cpp/llama.cpp/common/download.h +55 -0
  42. package/cpp/llama.cpp/common/http.h +73 -0
  43. package/cpp/llama.cpp/common/json-partial.cpp +70 -2
  44. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +61 -22
  45. package/cpp/llama.cpp/common/json-schema-to-grammar.h +2 -0
  46. package/cpp/llama.cpp/common/log.cpp +59 -2
  47. package/cpp/llama.cpp/common/log.h +12 -4
  48. package/cpp/llama.cpp/common/sampling.cpp +84 -8
  49. package/cpp/llama.cpp/common/sampling.h +3 -1
  50. package/cpp/llama.cpp/common/speculative.cpp +1 -1
  51. package/cpp/llama.cpp/convert_hf_to_gguf.py +1608 -233
  52. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +6 -1
  53. package/cpp/llama.cpp/convert_lora_to_gguf.py +37 -5
  54. package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -28
  55. package/cpp/llama.cpp/ggml/include/ggml-backend.h +19 -1
  56. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +1 -1
  57. package/cpp/llama.cpp/ggml/include/ggml-hexagon.h +19 -0
  58. package/cpp/llama.cpp/ggml/include/ggml-metal.h +1 -6
  59. package/cpp/llama.cpp/ggml/include/ggml-rpc.h +7 -9
  60. package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +2 -1
  61. package/cpp/llama.cpp/ggml/include/ggml.h +199 -6
  62. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +38 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +299 -130
  64. package/cpp/llama.cpp/ggml/src/ggml-backend-impl.h +4 -4
  65. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +21 -5
  66. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +99 -2
  67. package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  68. package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +1 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  70. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +138 -47
  71. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +1584 -1773
  72. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +201 -317
  73. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +146 -187
  74. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +771 -713
  75. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +135 -77
  76. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  77. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  78. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +16 -17
  79. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +318 -145
  80. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +155 -60
  82. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +8 -8
  83. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +0 -1
  84. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +14 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -9
  86. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +108 -64
  87. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +14 -4
  88. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +530 -87
  89. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +37 -45
  90. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +349 -127
  91. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +947 -1218
  92. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +5 -4
  93. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +143 -29
  94. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +82 -76
  95. package/cpp/llama.cpp/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +7 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +233 -28
  102. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +326 -66
  103. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +12 -3
  104. package/cpp/llama.cpp/ggml/src/ggml-cuda/argsort.cu +102 -6
  105. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +110 -76
  106. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +167 -38
  107. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +6 -11
  108. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +12 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  110. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +245 -151
  111. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cuh +1 -5
  112. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +341 -289
  113. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  114. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cuh +1233 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +6 -6
  117. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  118. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +123 -220
  119. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +41 -39
  120. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +715 -45
  121. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +150 -0
  122. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cuh +1 -0
  123. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +321 -24
  124. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +93 -351
  125. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +828 -1
  126. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmid.cu +164 -0
  127. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmid.cuh +5 -0
  128. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +3 -166
  129. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1 -1
  130. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvf.cu +371 -78
  131. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  132. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +279 -147
  133. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  134. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +97 -85
  135. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad.cu +46 -23
  136. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +63 -54
  137. package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +12 -10
  138. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +192 -77
  139. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cuh +2 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +10 -9
  141. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +137 -75
  142. package/cpp/llama.cpp/ggml/src/ggml-cuda/set.cu +39 -0
  143. package/cpp/llama.cpp/ggml/src/ggml-cuda/set.cuh +7 -0
  144. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  145. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  147. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  149. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  150. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  152. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  153. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  154. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  161. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  164. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  166. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  167. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  168. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  169. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  170. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  173. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  174. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  175. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  176. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  177. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  178. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  179. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  180. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  181. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  182. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  183. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  184. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  185. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  186. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  187. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  188. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  189. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  190. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  191. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  192. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  193. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  194. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  195. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  196. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  197. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  198. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  199. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  200. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  201. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  202. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  203. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  204. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  205. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  206. package/cpp/llama.cpp/ggml/src/ggml-cuda/topk-moe.cu +336 -0
  207. package/cpp/llama.cpp/ggml/src/ggml-cuda/topk-moe.cuh +16 -0
  208. package/cpp/llama.cpp/ggml/src/ggml-cuda/tsembd.cu +3 -3
  209. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +105 -11
  210. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +36 -0
  211. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +87 -6
  212. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +28 -12
  213. package/cpp/llama.cpp/ggml/src/ggml-hexagon/CMakeLists.txt +68 -0
  214. package/cpp/llama.cpp/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3807 -0
  215. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/CMakeLists.txt +40 -0
  216. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/act-ops.c +442 -0
  217. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  218. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  219. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ctx.h +40 -0
  220. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-dma.c +69 -0
  221. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-dma.h +119 -0
  222. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-msg.h +156 -0
  223. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ops.h +64 -0
  224. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  225. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-exp.c +93 -0
  226. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-inverse.c +60 -0
  227. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  228. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-utils.c +960 -0
  229. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h +1032 -0
  230. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/main.c +829 -0
  231. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c +2223 -0
  232. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  233. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/rope-ops.c +418 -0
  234. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  235. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c +255 -0
  236. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  237. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  238. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp-utils.c +448 -0
  239. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp-utils.h +220 -0
  240. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  241. package/cpp/llama.cpp/ggml/src/ggml-impl.h +110 -12
  242. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +6 -5
  243. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  244. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  245. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  246. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m +599 -0
  247. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.cpp +1662 -0
  248. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.h +251 -0
  249. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m +1527 -0
  250. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +244 -39
  251. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp +3844 -0
  252. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.h +90 -0
  253. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.cpp +723 -0
  254. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +3453 -1907
  255. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  256. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +10 -0
  257. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1331 -109
  258. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/cvt.cl +126 -0
  259. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +31 -4
  260. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +35 -7
  261. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +31 -4
  262. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  263. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  264. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  265. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  266. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  267. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  268. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  269. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  270. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  271. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  272. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  273. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  274. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  275. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  276. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  277. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  278. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +123 -10
  279. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  280. package/cpp/llama.cpp/ggml/src/ggml-quants.c +1 -0
  281. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +341 -161
  282. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
  283. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  284. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +74 -15
  285. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +50 -30
  286. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +10 -4
  287. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +166 -99
  288. package/cpp/llama.cpp/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  289. package/cpp/llama.cpp/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  290. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +72 -94
  291. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  292. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +21 -31
  293. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +252 -316
  294. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +6 -2
  295. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +9 -6
  296. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +359 -142
  297. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  298. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  299. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +80 -60
  300. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +201 -132
  301. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +230 -55
  302. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.hpp +2 -0
  303. package/cpp/llama.cpp/ggml/src/ggml-sycl/pad.cpp +97 -0
  304. package/cpp/llama.cpp/ggml/src/ggml-sycl/pad.hpp +24 -0
  305. package/cpp/llama.cpp/ggml/src/ggml-sycl/pad_reflect_1d.cpp +72 -0
  306. package/cpp/llama.cpp/ggml/src/ggml-sycl/pad_reflect_1d.hpp +8 -0
  307. package/cpp/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
  308. package/cpp/llama.cpp/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  309. package/cpp/llama.cpp/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  310. package/cpp/llama.cpp/ggml/src/ggml-sycl/roll.cpp +122 -0
  311. package/cpp/llama.cpp/ggml/src/ggml-sycl/roll.hpp +20 -0
  312. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +50 -41
  313. package/cpp/llama.cpp/ggml/src/ggml-sycl/set.cpp +73 -0
  314. package/cpp/llama.cpp/ggml/src/ggml-sycl/set.hpp +5 -0
  315. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +45 -36
  316. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +330 -165
  317. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.hpp +4 -0
  318. package/cpp/llama.cpp/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  319. package/cpp/llama.cpp/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  320. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  321. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +16 -12
  322. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  323. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4184 -2159
  324. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  325. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  326. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  327. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  328. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  329. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  330. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  331. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  332. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  333. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  334. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  335. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  336. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  337. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  338. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +53 -30
  339. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  340. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  341. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  342. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +13 -6
  343. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  344. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  345. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  346. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  347. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +138 -2
  348. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  349. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  350. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  351. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  352. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  353. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  354. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  355. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  356. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  357. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  358. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  359. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  360. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  361. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  362. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  363. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  364. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  365. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  366. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  367. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  368. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  369. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  370. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  371. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  372. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -2
  373. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  374. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +52 -14
  375. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +50 -12
  376. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +61 -12
  377. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +54 -12
  378. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +5 -1
  379. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  380. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  381. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  382. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  383. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  384. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  385. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  386. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +10 -2
  387. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  388. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  389. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  390. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  391. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  392. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  393. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +15 -7
  394. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  395. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  396. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  397. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  398. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  399. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +1 -1
  400. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +229 -0
  401. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +33 -0
  402. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +1 -1
  403. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +1 -1
  404. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +1 -1
  405. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +1 -1
  406. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +1 -1
  407. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +1 -1
  408. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +1 -1
  409. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  410. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  411. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +3 -5
  412. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +1 -1
  413. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +3 -5
  414. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +3 -5
  415. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +1 -1
  416. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  417. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +106 -634
  418. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +118 -9
  419. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +556 -0
  420. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +70 -0
  421. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +77 -214
  422. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +589 -0
  423. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  424. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  425. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  426. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  427. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  428. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  429. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +25 -4
  430. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  431. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +55 -5
  432. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  433. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  434. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  435. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  436. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +45 -3
  437. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  438. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  439. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  440. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +227 -0
  441. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  442. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +5 -52
  443. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +5 -35
  444. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +5 -35
  445. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +27 -0
  446. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +5 -41
  447. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  448. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  449. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  450. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  451. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  452. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  453. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  454. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  455. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  456. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  457. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  458. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  459. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +140 -0
  460. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  461. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  462. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +1 -1
  463. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  464. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  465. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  466. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  467. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +171 -0
  468. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  469. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +79 -29
  470. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -12
  471. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +471 -196
  472. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +8 -0
  473. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1690 -383
  474. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  475. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  476. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  477. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  478. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +57 -10
  479. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  480. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  481. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +25 -912
  482. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  483. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  484. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  485. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  486. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  487. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  488. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  489. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/{set_rows.wgsl → set_rows.tmpl.wgsl} +38 -8
  490. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  491. package/cpp/llama.cpp/ggml/src/ggml-zdnn/common.hpp +59 -0
  492. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +96 -314
  493. package/cpp/llama.cpp/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  494. package/cpp/llama.cpp/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  495. package/cpp/llama.cpp/ggml/src/ggml-zdnn/utils.cpp +79 -0
  496. package/cpp/llama.cpp/ggml/src/ggml-zdnn/utils.hpp +19 -0
  497. package/cpp/llama.cpp/ggml/src/ggml.c +440 -17
  498. package/cpp/llama.cpp/ggml/src/gguf.cpp +104 -29
  499. package/cpp/llama.cpp/gguf-py/gguf/constants.py +363 -13
  500. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +64 -0
  501. package/cpp/llama.cpp/gguf-py/gguf/lazy.py +8 -3
  502. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +6 -0
  503. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +156 -18
  504. package/cpp/llama.cpp/gguf-py/gguf/utility.py +80 -0
  505. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +4 -4
  506. package/cpp/llama.cpp/include/llama.h +44 -21
  507. package/cpp/llama.cpp/media/llama1-icon-transparent.png +0 -0
  508. package/cpp/llama.cpp/media/llama1-icon-transparent.svg +77 -0
  509. package/cpp/llama.cpp/media/llama1-icon.png +0 -0
  510. package/cpp/llama.cpp/media/llama1-icon.svg +87 -0
  511. package/cpp/llama.cpp/requirements/requirements-all.txt +2 -0
  512. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -3
  513. package/cpp/llama.cpp/requirements/requirements-convert_legacy_llama.txt +3 -1
  514. package/cpp/llama.cpp/requirements/requirements-tool_bench.txt +1 -1
  515. package/cpp/llama.cpp/src/CMakeLists.txt +101 -0
  516. package/cpp/llama.cpp/src/llama-adapter.cpp +33 -0
  517. package/cpp/llama.cpp/src/llama-adapter.h +3 -0
  518. package/cpp/llama.cpp/src/llama-arch.cpp +344 -14
  519. package/cpp/llama.cpp/src/llama-arch.h +50 -0
  520. package/cpp/llama.cpp/src/llama-batch.cpp +63 -31
  521. package/cpp/llama.cpp/src/llama-batch.h +13 -2
  522. package/cpp/llama.cpp/src/llama-chat.cpp +85 -3
  523. package/cpp/llama.cpp/src/llama-chat.h +4 -0
  524. package/cpp/llama.cpp/src/llama-context.cpp +300 -45
  525. package/cpp/llama.cpp/src/llama-context.h +16 -6
  526. package/cpp/llama.cpp/src/llama-cparams.h +2 -1
  527. package/cpp/llama.cpp/src/llama-grammar.cpp +17 -9
  528. package/cpp/llama.cpp/src/llama-graph.cpp +226 -64
  529. package/cpp/llama.cpp/src/llama-graph.h +27 -5
  530. package/cpp/llama.cpp/src/llama-hparams.cpp +53 -2
  531. package/cpp/llama.cpp/src/llama-hparams.h +48 -8
  532. package/cpp/llama.cpp/src/llama-impl.cpp +3 -3
  533. package/cpp/llama.cpp/src/llama-impl.h +2 -0
  534. package/cpp/llama.cpp/src/llama-kv-cache-iswa.cpp +13 -3
  535. package/cpp/llama.cpp/src/llama-kv-cache-iswa.h +2 -0
  536. package/cpp/llama.cpp/src/llama-kv-cache.cpp +120 -62
  537. package/cpp/llama.cpp/src/llama-kv-cache.h +13 -4
  538. package/cpp/llama.cpp/src/llama-kv-cells.h +44 -2
  539. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +19 -9
  540. package/cpp/llama.cpp/src/llama-memory-hybrid.h +2 -0
  541. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +38 -17
  542. package/cpp/llama.cpp/src/llama-memory-recurrent.h +5 -2
  543. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  544. package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
  545. package/cpp/llama.cpp/src/llama-model.cpp +1070 -12614
  546. package/cpp/llama.cpp/src/llama-model.h +40 -4
  547. package/cpp/llama.cpp/src/llama-quant.cpp +14 -6
  548. package/cpp/llama.cpp/src/llama-sampling.cpp +243 -136
  549. package/cpp/llama.cpp/src/llama-vocab.cpp +43 -3
  550. package/cpp/llama.cpp/src/llama-vocab.h +43 -39
  551. package/cpp/llama.cpp/src/llama.cpp +69 -10
  552. package/cpp/llama.cpp/src/models/afmoe.cpp +187 -0
  553. package/cpp/llama.cpp/src/models/apertus.cpp +125 -0
  554. package/cpp/llama.cpp/src/models/arcee.cpp +135 -0
  555. package/cpp/llama.cpp/src/models/arctic.cpp +138 -0
  556. package/cpp/llama.cpp/src/models/arwkv7.cpp +86 -0
  557. package/cpp/llama.cpp/src/models/baichuan.cpp +122 -0
  558. package/cpp/llama.cpp/src/models/bailingmoe.cpp +144 -0
  559. package/cpp/llama.cpp/src/models/bailingmoe2.cpp +135 -0
  560. package/cpp/llama.cpp/src/models/bert.cpp +176 -0
  561. package/cpp/llama.cpp/src/models/bitnet.cpp +160 -0
  562. package/cpp/llama.cpp/src/models/bloom.cpp +101 -0
  563. package/cpp/llama.cpp/src/models/chameleon.cpp +178 -0
  564. package/cpp/llama.cpp/src/models/chatglm.cpp +132 -0
  565. package/cpp/llama.cpp/src/models/codeshell.cpp +111 -0
  566. package/cpp/llama.cpp/src/models/cogvlm.cpp +100 -0
  567. package/cpp/llama.cpp/src/models/cohere2-iswa.cpp +131 -0
  568. package/cpp/llama.cpp/src/models/command-r.cpp +122 -0
  569. package/cpp/llama.cpp/src/models/dbrx.cpp +123 -0
  570. package/cpp/llama.cpp/src/models/deci.cpp +135 -0
  571. package/cpp/llama.cpp/src/models/deepseek.cpp +144 -0
  572. package/cpp/llama.cpp/src/models/deepseek2.cpp +237 -0
  573. package/cpp/llama.cpp/src/models/dots1.cpp +134 -0
  574. package/cpp/llama.cpp/src/models/dream.cpp +105 -0
  575. package/cpp/llama.cpp/src/models/ernie4-5-moe.cpp +150 -0
  576. package/cpp/llama.cpp/src/models/ernie4-5.cpp +110 -0
  577. package/cpp/llama.cpp/src/models/exaone.cpp +114 -0
  578. package/cpp/llama.cpp/src/models/exaone4.cpp +123 -0
  579. package/cpp/llama.cpp/src/models/falcon-h1.cpp +113 -0
  580. package/cpp/llama.cpp/src/models/falcon.cpp +120 -0
  581. package/cpp/llama.cpp/src/models/gemma-embedding.cpp +120 -0
  582. package/cpp/llama.cpp/src/models/gemma.cpp +112 -0
  583. package/cpp/llama.cpp/src/models/gemma2-iswa.cpp +125 -0
  584. package/cpp/llama.cpp/src/models/gemma3-iswa.cpp +131 -0
  585. package/cpp/llama.cpp/src/models/gemma3n-iswa.cpp +377 -0
  586. package/cpp/llama.cpp/src/models/glm4-moe.cpp +153 -0
  587. package/cpp/llama.cpp/src/models/glm4.cpp +127 -0
  588. package/cpp/llama.cpp/src/models/gpt2.cpp +105 -0
  589. package/cpp/llama.cpp/src/models/gptneox.cpp +144 -0
  590. package/cpp/llama.cpp/src/models/granite-hybrid.cpp +196 -0
  591. package/cpp/llama.cpp/src/models/granite.cpp +211 -0
  592. package/cpp/llama.cpp/src/models/graph-context-mamba.cpp +283 -0
  593. package/cpp/llama.cpp/src/models/grok.cpp +159 -0
  594. package/cpp/llama.cpp/src/models/grovemoe.cpp +141 -0
  595. package/cpp/llama.cpp/src/models/hunyuan-dense.cpp +132 -0
  596. package/cpp/llama.cpp/src/models/hunyuan-moe.cpp +154 -0
  597. package/cpp/llama.cpp/src/models/internlm2.cpp +120 -0
  598. package/cpp/llama.cpp/src/models/jais.cpp +86 -0
  599. package/cpp/llama.cpp/src/models/jamba.cpp +106 -0
  600. package/cpp/llama.cpp/src/models/lfm2.cpp +173 -0
  601. package/cpp/llama.cpp/src/models/llada-moe.cpp +122 -0
  602. package/cpp/llama.cpp/src/models/llada.cpp +99 -0
  603. package/cpp/llama.cpp/src/models/llama-iswa.cpp +174 -0
  604. package/cpp/llama.cpp/src/models/llama.cpp +155 -0
  605. package/cpp/llama.cpp/src/models/mamba.cpp +55 -0
  606. package/cpp/llama.cpp/src/models/minicpm3.cpp +199 -0
  607. package/cpp/llama.cpp/src/models/minimax-m2.cpp +124 -0
  608. package/cpp/llama.cpp/src/models/models.h +485 -0
  609. package/cpp/llama.cpp/src/models/mpt.cpp +126 -0
  610. package/cpp/llama.cpp/src/models/nemotron-h.cpp +121 -0
  611. package/cpp/llama.cpp/src/models/nemotron.cpp +122 -0
  612. package/cpp/llama.cpp/src/models/neo-bert.cpp +104 -0
  613. package/cpp/llama.cpp/src/models/olmo.cpp +121 -0
  614. package/cpp/llama.cpp/src/models/olmo2.cpp +150 -0
  615. package/cpp/llama.cpp/src/models/olmoe.cpp +124 -0
  616. package/cpp/llama.cpp/src/models/openai-moe-iswa.cpp +124 -0
  617. package/cpp/llama.cpp/src/models/openelm.cpp +124 -0
  618. package/cpp/llama.cpp/src/models/orion.cpp +123 -0
  619. package/cpp/llama.cpp/src/models/pangu-embedded.cpp +121 -0
  620. package/cpp/llama.cpp/src/models/phi2.cpp +121 -0
  621. package/cpp/llama.cpp/src/models/phi3.cpp +152 -0
  622. package/cpp/llama.cpp/src/models/plamo.cpp +110 -0
  623. package/cpp/llama.cpp/src/models/plamo2.cpp +316 -0
  624. package/cpp/llama.cpp/src/models/plm.cpp +168 -0
  625. package/cpp/llama.cpp/src/models/qwen.cpp +108 -0
  626. package/cpp/llama.cpp/src/models/qwen2.cpp +117 -0
  627. package/cpp/llama.cpp/src/models/qwen2moe.cpp +151 -0
  628. package/cpp/llama.cpp/src/models/qwen2vl.cpp +117 -0
  629. package/cpp/llama.cpp/src/models/qwen3.cpp +117 -0
  630. package/cpp/llama.cpp/src/models/qwen3moe.cpp +124 -0
  631. package/cpp/llama.cpp/src/models/qwen3vl-moe.cpp +149 -0
  632. package/cpp/llama.cpp/src/models/qwen3vl.cpp +141 -0
  633. package/cpp/llama.cpp/src/models/refact.cpp +94 -0
  634. package/cpp/llama.cpp/src/models/rwkv6-base.cpp +162 -0
  635. package/cpp/llama.cpp/src/models/rwkv6.cpp +94 -0
  636. package/cpp/llama.cpp/src/models/rwkv6qwen2.cpp +86 -0
  637. package/cpp/llama.cpp/src/models/rwkv7-base.cpp +135 -0
  638. package/cpp/llama.cpp/src/models/rwkv7.cpp +90 -0
  639. package/cpp/llama.cpp/src/models/seed-oss.cpp +124 -0
  640. package/cpp/llama.cpp/src/models/smallthinker.cpp +120 -0
  641. package/cpp/llama.cpp/src/models/smollm3.cpp +128 -0
  642. package/cpp/llama.cpp/src/models/stablelm.cpp +146 -0
  643. package/cpp/llama.cpp/src/models/starcoder.cpp +100 -0
  644. package/cpp/llama.cpp/src/models/starcoder2.cpp +121 -0
  645. package/cpp/llama.cpp/src/models/t5-dec.cpp +166 -0
  646. package/cpp/llama.cpp/src/models/t5-enc.cpp +96 -0
  647. package/cpp/llama.cpp/src/models/wavtokenizer-dec.cpp +149 -0
  648. package/cpp/llama.cpp/src/models/xverse.cpp +108 -0
  649. package/cpp/llama.cpp/src/unicode.cpp +77 -0
  650. package/cpp/llama.cpp/src/unicode.h +43 -0
  651. package/cpp/llama.cpp/vendor/cpp-httplib/CMakeLists.txt +94 -0
  652. package/cpp/llama.cpp/vendor/cpp-httplib/httplib.cpp +9339 -0
  653. package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +433 -8222
  654. package/cpp/llama.cpp/vendor/cpp-httplib/patch-boringssl.cmake +6 -0
  655. package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +4179 -1900
  656. package/cpp/llama.cpp/vendor/minja/chat-template.hpp +9 -2
  657. package/cpp/llama.cpp/vendor/minja/minja.hpp +101 -22
  658. package/cpp/rn-completion.cpp +3 -27
  659. package/ios/include/chat.h +16 -3
  660. package/ios/include/common/minja/chat-template.hpp +9 -2
  661. package/ios/include/common/minja/minja.hpp +101 -22
  662. package/ios/include/common.h +57 -19
  663. package/ios/include/json-schema-to-grammar.h +2 -0
  664. package/ios/include/llama.h +44 -21
  665. package/ios/include/log.h +12 -4
  666. package/ios/include/sampling.h +3 -1
  667. package/ios/libs/llama.xcframework/Info.plist +20 -20
  668. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  669. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +6399 -5557
  670. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +19 -1
  671. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +1 -1
  672. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-metal.h +1 -6
  673. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +199 -6
  674. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +44 -21
  675. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  676. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  677. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +6362 -5520
  678. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4813 -4241
  679. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +19 -1
  680. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +1 -1
  681. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-metal.h +1 -6
  682. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +199 -6
  683. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +44 -21
  684. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  685. package/package.json +10 -4
  686. package/cpp/llama.cpp/ggml/src/ggml-cann/Doxyfile +0 -2579
  687. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -371
  688. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  689. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -379
  690. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  691. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -495
  692. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -486
  693. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  694. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  695. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  696. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  697. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  698. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  699. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  700. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  701. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  702. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  703. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  704. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  705. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  706. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  707. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  708. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  709. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  710. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  711. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  712. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  713. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  714. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  715. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  716. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  717. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  718. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  719. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  720. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  721. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  722. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  723. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  724. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  725. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  726. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  727. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  728. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  729. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  730. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  731. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  732. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  733. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  734. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  735. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  736. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  737. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  738. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  739. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  740. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  741. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  742. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  743. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  744. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  745. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  746. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  747. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  748. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  749. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  750. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  751. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  752. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  753. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  754. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  755. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  756. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  757. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  758. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  759. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  760. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  761. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  762. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  763. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  764. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  765. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  766. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  767. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  768. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  769. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  770. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  771. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  772. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  773. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  774. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  775. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  776. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  777. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  778. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  779. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +0 -6886
  780. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -154
  781. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  782. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  783. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  784. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +0 -97
  785. package/cpp/llama.cpp/models/ggml-vocab-aquila.gguf +0 -0
  786. package/cpp/llama.cpp/models/ggml-vocab-baichuan.gguf +0 -0
  787. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf +0 -0
  788. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +0 -112
  789. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +0 -46
  790. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf +0 -0
  791. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +0 -112
  792. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +0 -46
  793. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf +0 -0
  794. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +0 -112
  795. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +0 -46
  796. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf +0 -0
  797. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +0 -112
  798. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +0 -46
  799. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf +0 -0
  800. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +0 -112
  801. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +0 -46
  802. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf +0 -0
  803. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +0 -112
  804. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +0 -46
  805. package/cpp/llama.cpp/models/ggml-vocab-gpt-neox.gguf +0 -0
  806. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf +0 -0
  807. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +0 -112
  808. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +0 -46
  809. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf +0 -0
  810. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +0 -112
  811. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +0 -46
  812. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf +0 -0
  813. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +0 -112
  814. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +0 -46
  815. package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
  816. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf +0 -0
  817. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +0 -112
  818. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +0 -46
  819. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf +0 -0
  820. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +0 -112
  821. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +0 -46
  822. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf +0 -0
  823. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +0 -112
  824. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +0 -46
  825. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf +0 -0
  826. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +0 -112
  827. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +0 -46
  828. package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +0 -171
  829. package/cpp/llama.cpp/models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja +0 -202
  830. package/cpp/llama.cpp/models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja +0 -156
  831. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +0 -124
  832. package/cpp/llama.cpp/models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja +0 -152
  833. package/cpp/llama.cpp/models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja +0 -152
  834. package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +0 -62
  835. package/cpp/llama.cpp/models/templates/Qwen-Qwen2.5-7B-Instruct.jinja +0 -54
  836. package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +0 -85
  837. package/cpp/llama.cpp/models/templates/README.md +0 -25
  838. package/cpp/llama.cpp/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja +0 -1
  839. package/cpp/llama.cpp/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja +0 -1
  840. package/cpp/llama.cpp/models/templates/fireworks-ai-llama-3-firefunction-v2.jinja +0 -57
  841. package/cpp/llama.cpp/models/templates/google-gemma-2-2b-it.jinja +0 -4
  842. package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +0 -59
  843. package/cpp/llama.cpp/models/templates/llama-cpp-deepseek-r1.jinja +0 -76
  844. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +0 -34
  845. package/cpp/llama.cpp/models/templates/meetkai-functionary-medium-v3.1.jinja +0 -58
  846. package/cpp/llama.cpp/models/templates/meetkai-functionary-medium-v3.2.jinja +0 -287
  847. package/cpp/llama.cpp/models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja +0 -109
  848. package/cpp/llama.cpp/models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja +0 -93
  849. package/cpp/llama.cpp/models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja +0 -109
  850. package/cpp/llama.cpp/models/templates/microsoft-Phi-3.5-mini-instruct.jinja +0 -8
  851. package/cpp/llama.cpp/models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja +0 -87
  852. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +0 -43
  853. package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +0 -331
  854. package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +0 -105
  855. package/cpp/llama.cpp/prompts/LLM-questions.txt +0 -49
  856. package/cpp/llama.cpp/prompts/alpaca.txt +0 -1
  857. package/cpp/llama.cpp/prompts/assistant.txt +0 -31
  858. package/cpp/llama.cpp/prompts/chat-with-baichuan.txt +0 -4
  859. package/cpp/llama.cpp/prompts/chat-with-bob.txt +0 -7
  860. package/cpp/llama.cpp/prompts/chat-with-qwen.txt +0 -1
  861. package/cpp/llama.cpp/prompts/chat-with-vicuna-v0.txt +0 -7
  862. package/cpp/llama.cpp/prompts/chat-with-vicuna-v1.txt +0 -7
  863. package/cpp/llama.cpp/prompts/chat.txt +0 -28
  864. package/cpp/llama.cpp/prompts/dan-modified.txt +0 -1
  865. package/cpp/llama.cpp/prompts/dan.txt +0 -1
  866. package/cpp/llama.cpp/prompts/mnemonics.txt +0 -93
  867. package/cpp/llama.cpp/prompts/parallel-questions.txt +0 -43
  868. package/cpp/llama.cpp/prompts/reason-act.txt +0 -18
  869. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
  870. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  871. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5524
  872. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +0 -4247
  873. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-alloc.h +0 -76
  874. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +0 -354
  875. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-blas.h +0 -25
  876. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +0 -145
  877. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-metal.h +0 -66
  878. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +0 -256
  879. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +0 -2492
  880. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/gguf.h +0 -202
  881. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +0 -1391
  882. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Modules/module.modulemap +0 -17
  883. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Resources/Info.plist +0 -32
  884. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-alloc.h +0 -76
  885. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +0 -354
  886. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-blas.h +0 -25
  887. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +0 -145
  888. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-metal.h +0 -66
  889. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +0 -256
  890. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +0 -2492
  891. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/gguf.h +0 -202
  892. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +0 -1391
  893. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Modules/module.modulemap +0 -17
  894. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Resources/Info.plist +0 -32
  895. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  896. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-alloc.h +0 -76
  897. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +0 -354
  898. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-blas.h +0 -25
  899. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +0 -145
  900. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-metal.h +0 -66
  901. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +0 -256
  902. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +0 -2492
  903. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/gguf.h +0 -202
  904. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +0 -1391
  905. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Modules/module.modulemap +0 -17
  906. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Resources/Info.plist +0 -32
  907. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  908. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  909. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
  910. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  911. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5561
  912. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-alloc.h +0 -76
  913. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +0 -354
  914. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-blas.h +0 -25
  915. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +0 -145
  916. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-metal.h +0 -66
  917. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +0 -256
  918. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +0 -2492
  919. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/gguf.h +0 -202
  920. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +0 -1391
  921. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Info.plist +0 -35
  922. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Modules/module.modulemap +0 -17
  923. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  924. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
  925. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  926. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5524
  927. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +0 -4246
  928. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-alloc.h +0 -76
  929. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +0 -354
  930. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-blas.h +0 -25
  931. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +0 -145
  932. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-metal.h +0 -66
  933. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +0 -256
  934. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +0 -2492
  935. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/gguf.h +0 -202
  936. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -1391
  937. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Info.plist +0 -35
  938. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Modules/module.modulemap +0 -17
  939. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  940. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
  941. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  942. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5558
  943. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-alloc.h +0 -76
  944. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +0 -354
  945. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-blas.h +0 -25
  946. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +0 -145
  947. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-metal.h +0 -66
  948. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +0 -256
  949. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +0 -2492
  950. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/gguf.h +0 -202
  951. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +0 -1391
  952. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Info.plist +0 -32
  953. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Modules/module.modulemap +0 -17
  954. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  955. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
  956. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  957. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5520
  958. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +0 -4243
  959. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-alloc.h +0 -76
  960. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +0 -354
  961. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-blas.h +0 -25
  962. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +0 -145
  963. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-metal.h +0 -66
  964. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +0 -256
  965. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +0 -2492
  966. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/gguf.h +0 -202
  967. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -1391
  968. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Info.plist +0 -32
  969. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Modules/module.modulemap +0 -17
  970. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  971. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  972. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  973. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  974. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  975. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +0 -0
  976. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +0 -0
  977. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  978. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  979. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -7,8 +7,10 @@
 #include "unary-ops.h"
 #include "vec.h"
 
-#include <float.h>
+#include <cfloat>
 #include <algorithm>
+#include <cmath>
+#include <functional>
 
 // ggml_compute_forward_dup
 
@@ -41,628 +43,15 @@ static void ggml_compute_forward_dup_same_cont(
     }
 }
 
-static void ggml_compute_forward_dup_f16(
-        const ggml_compute_params * params,
-        ggml_tensor * dst) {
-
-    const ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    const int ith = params->ith; // thread index
-    const int nth = params->nth; // number of threads
-
-    // parallelize by rows
-    const int nr = ne01;
-    // number of rows per thread
-    const int dr = (nr + nth - 1) / nth;
-    // row range for this thread
-    const int ir0 = dr * ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (src0->type == dst->type &&
-        ne00 == ne0 &&
-        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
-        // copy by rows
-        const size_t rs = ne00*nb00;
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    memcpy(
-                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
-                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
-                        rs);
-                }
-            }
-        }
-        return;
-    }
-
-    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
-    if (ggml_is_contiguous(dst)) {
-        if (nb00 == sizeof(ggml_fp16_t)) {
-            if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                const size_t rs = ne00 * nb00;
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr + id, src0_ptr, rs);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                dst_ptr[id] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
-                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
-                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
-
-                size_t id = 0;
-                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                src0_f32[i00] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
-                            }
-
-                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        } else {
-            //printf("%s: this is not optimal - fix me\n", __func__);
-
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_CPU_FP16_TO_FP32(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = *src0_ptr;
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        }
-        return;
-    }
-
-    // dst counters
-    int64_t i10 = 0;
-    int64_t i11 = 0;
-    int64_t i12 = 0;
-    int64_t i13 = 0;
-
-    if (dst->type == GGML_TYPE_F16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
-                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
-
-                        if (++i10 == ne00) {
-                            i10 = 0;
-                            if (++i11 == ne01) {
-                                i11 = 0;
-                                if (++i12 == ne02) {
-                                    i12 = 0;
-                                    if (++i13 == ne03) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_F32) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
-                        *(float *) dst_ptr = GGML_CPU_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else {
-        GGML_ABORT("fatal error"); // TODO: implement
-    }
-}
-
-static void ggml_compute_forward_dup_bf16(
-        const ggml_compute_params * params,
-        ggml_tensor * dst) {
-
-    const ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    const int ith = params->ith; // thread index
-    const int nth = params->nth; // number of threads
-
-    // parallelize by rows
-    const int nr = ne01;
-    // number of rows per thread
-    const int dr = (nr + nth - 1) / nth;
-    // row range for this thread
-    const int ir0 = dr * ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (src0->type == dst->type &&
-        ne00 == ne0 &&
-        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
-        // copy by rows
-        const size_t rs = ne00*nb00;
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    memcpy(
-                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
-                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
-                        rs);
-                }
-            }
-        }
-        return;
-    }
-
-    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
-    if (ggml_is_contiguous(dst)) {
-        if (nb00 == sizeof(ggml_bf16_t)) {
-            if (dst->type == GGML_TYPE_BF16) {
-                size_t id = 0;
-                const size_t rs = ne00 * nb00;
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr + id, src0_ptr, rs);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
-                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
-                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
-
-                size_t id = 0;
-                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
-                            }
-
-                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        } else {
-            //printf("%s: this is not optimal - fix me\n", __func__);
-
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_BF16) {
-                size_t id = 0;
-                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = *src0_ptr;
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        }
-        return;
-    }
-
-    // dst counters
-    int64_t i10 = 0;
-    int64_t i11 = 0;
-    int64_t i12 = 0;
-    int64_t i13 = 0;
-
-    if (dst->type == GGML_TYPE_BF16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
-                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
-
-                        if (++i10 == ne00) {
-                            i10 = 0;
-                            if (++i11 == ne01) {
-                                i11 = 0;
-                                if (++i12 == ne02) {
-                                    i12 = 0;
-                                    if (++i13 == ne03) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_F16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
-                        *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_F32) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
-                        *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else {
-        GGML_ABORT("fatal error"); // TODO: implement
-    }
-}
-
-static void ggml_compute_forward_dup_f32(
+template<typename src_t, typename dst_t>
+static void ggml_compute_forward_dup_flt(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
 
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type));
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
@@ -677,6 +66,7 @@ static void ggml_compute_forward_dup_f32(
     const int ir0 = dr * ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
+    // case: type & row size equal
    if (src0->type == dst->type &&
        ne00 == ne0 &&
        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
@@ -695,103 +85,78 @@ static void ggml_compute_forward_dup_f32(
         return;
     }
 
+    // case: dst tensor is contiguous
     if (ggml_is_contiguous(dst)) {
-        // TODO: simplify
-        if (nb00 == sizeof(float)) {
-            if (ggml_get_type_traits_cpu(dst->type)->from_float) {
-                ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
-
+        if (nb00 == sizeof(src_t)) {
+            if constexpr (std::is_same_v<dst_t, src_t>) {
+                // same type
                 size_t id = 0;
-                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+                const size_t rs = ne00 * nb00;
                 char * dst_ptr = (char *) dst->data;
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
                         id += rs * ir0;
                         for (int i01 = ir0; i01 < ir1; i01++) {
-                            const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            from_float(src0_ptr, dst_ptr + id, ne00);
+                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                            memcpy(dst_ptr + id, src0_ptr, rs);
                             id += rs;
                         }
                         id += rs * (ne01 - ir1);
                     }
                 }
             } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        } else {
-            //printf("%s: this is not optimal - fix me\n", __func__);
-
-            if (dst->type == GGML_TYPE_F32) {
+                // casting between non-quantized types
                 size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
+                dst_t * dst_ptr = (dst_t *) dst->data;
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
                         id += ne00 * ir0;
                         for (int i01 = ir0; i01 < ir1; i01++) {
+                            const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                             for (int i00 = 0; i00 < ne00; i00++) {
-                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = *src0_ptr;
+                                float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
+                                dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
                                 id++;
                             }
                         }
                         id += ne00 * (ne01 - ir1);
                     }
                 }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
 
-                                dst_ptr[id] = GGML_CPU_FP32_TO_FP16(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_BF16) {
-                size_t id = 0;
-                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
+            size_t id = 0;
+            dst_t * dst_ptr = (dst_t *) dst->data;
 
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    id += ne00 * ir0;
+                    for (int i01 = ir0; i01 < ir1; i01++) {
+                        for (int i00 = 0; i00 < ne00; i00++) {
+                            const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
-                                dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
-                                id++;
-                            }
+                            float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
+                            dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
+                            id++;
                         }
-                        id += ne00 * (ne01 - ir1);
                     }
+                    id += ne00 * (ne01 - ir1);
                 }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
            }
        }
-
        return;
    }
 
    // dst counters
-
    int64_t i10 = 0;
    int64_t i11 = 0;
    int64_t i12 = 0;
    int64_t i13 = 0;
 
-    if (dst->type == GGML_TYPE_F32) {
+    if constexpr (std::is_same_v<dst_t, src_t>) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                i10 += ne00 * ir0;
@@ -812,15 +177,15 @@ static void ggml_compute_forward_dup_f32(
                         const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                         char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
 
-                        memcpy(dst_ptr, src0_ptr, sizeof(float));
+                        memcpy(dst_ptr, src0_ptr, sizeof(dst_t));
 
-                        if (++i10 == ne0) {
+                        if (++i10 == ne00) {
                             i10 = 0;
-                            if (++i11 == ne1) {
+                            if (++i11 == ne01) {
                                 i11 = 0;
-                                if (++i12 == ne2) {
+                                if (++i12 == ne02) {
                                     i12 = 0;
-                                    if (++i13 == ne3) {
+                                    if (++i13 == ne03) {
                                         i13 = 0;
                                     }
                                 }
@@ -843,7 +208,8 @@ static void ggml_compute_forward_dup_f32(
                 }
             }
         }
-    } else if (dst->type == GGML_TYPE_F16) {
+
+    } else {
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 i10 += ne00 * ir0;
@@ -864,7 +230,8 @@ static void ggml_compute_forward_dup_f32(
                         const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                         char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
 
-                        *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(*(const float *) src0_ptr);
+                        float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
+                        *(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);
 
                         if (++i10 == ne0) {
                             i10 = 0;
@@ -895,60 +262,63 @@ static void ggml_compute_forward_dup_f32(
                 }
             }
         }
-    } else if (dst->type == GGML_TYPE_BF16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+    }
+}
 
-                        *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
 
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
+template<typename src_t>
+static void ggml_compute_forward_dup_to_q(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(!ggml_is_quantized(src0->type));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (ggml_is_contiguous(dst) &&
+        nb00 == sizeof(src_t) &&
+        ggml_get_type_traits_cpu(dst->type)->from_float) {
+        // casting non-quantized types --> intermediate f32 --> quantized
+        ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+        float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+        size_t id = 0;
+        size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+        char * dst_ptr = (char *) dst->data;
+
+        for (int i03 = 0; i03 < ne03; i03++) {
+            for (int i02 = 0; i02 < ne02; i02++) {
+                id += rs * ir0;
+                for (int i01 = ir0; i01 < ir1; i01++) {
+                    const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                    for (int i00 = 0; i00 < ne00; i00++) {
+                        src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
                     }
+
+                    quantize_row_q(src0_f32, dst_ptr + id, ne00);
+                    id += rs;
                 }
+                id += rs * (ne01 - ir1);
             }
         }
     } else {
-        GGML_ABORT("fatal error"); // TODO: implement
+        // printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type));
+        GGML_ABORT("not implemented");
     }
 }
 
@@ -1102,7 +472,7 @@ static void ggml_compute_forward_dup_bytes(
     }
 }
 
-static void ggml_compute_forward_dup_q(
+static void ggml_compute_forward_dup_from_q(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
 
@@ -1167,20 +537,35 @@ void ggml_compute_forward_dup(
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_dup_f16(params, dst);
+                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_fp16_t>(params, dst);
+                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_bf16_t>(params, dst);
+                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_fp16_t, float      >(params, dst);
+                else ggml_compute_forward_dup_to_q<ggml_fp16_t>(params, dst);
             } break;
         case GGML_TYPE_BF16:
             {
-                ggml_compute_forward_dup_bf16(params, dst);
+                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_fp16_t>(params, dst);
+                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(params, dst);
+                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_bf16_t, float      >(params, dst);
+                else ggml_compute_forward_dup_to_q<ggml_bf16_t>(params, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_dup_f32(params, dst);
+                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<float, ggml_fp16_t>(params, dst);
+                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<float, ggml_bf16_t>(params, dst);
+                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<float, float      >(params, dst);
+                else if (dst->type == GGML_TYPE_I32)  ggml_compute_forward_dup_flt<float, int32_t    >(params, dst);
+                else ggml_compute_forward_dup_to_q<float>(params, dst);
+            } break;
+        case GGML_TYPE_I32:
+            {
+                if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
+                else GGML_ABORT("not implemented");
             } break;
         default:
             {
                 if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
-                    ggml_compute_forward_dup_q(params, dst);
+                    ggml_compute_forward_dup_from_q(params, dst);
                     break;
                 }
                 GGML_ABORT("fatal error");
@@ -2002,7 +1387,57 @@ void ggml_compute_forward_sum(
             } break;
         case GGML_TYPE_BF16:
             {
-                ggml_compute_forward_sum_bf16(params, dst);
+                ggml_compute_forward_sum_bf16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_cumsum
+
+static void ggml_compute_forward_cumsum_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(dst->nb[0] == sizeof(float));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne01);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    const auto [ir0, ir1] = get_thread_range(params, src0);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+        float * dst_row = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+        ggml_vec_cumsum_f32(ne00, dst_row, src_row);
+    }
+}
+
+void ggml_compute_forward_cumsum(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_cumsum_f32(params, dst);
             } break;
         default:
             {
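
Note: the new cumsum op delegates the per-row work to `ggml_vec_cumsum_f32`, which is not part of this hunk. Judging from the call site, an inclusive prefix sum over one contiguous row is all it needs; a plain-C sketch under that assumption:

// Hypothetical sketch of the vec helper assumed by the call site above:
// y[i] = x[0] + x[1] + ... + x[i] (inclusive prefix sum over one row).
inline static void ggml_vec_cumsum_f32(const int n, float * y, const float * x) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum += x[i];
        y[i] = sum;
    }
}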
@@ -2757,6 +2192,83 @@ static void ggml_compute_forward_gelu(
  }
  }

+ // ggml_compute_fill
+
+ static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+ const float c = ggml_get_op_params_f32(dst, 0);
+
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
+
+ const auto [ir0, ir1] = get_thread_range(params, dst);
+
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
+ const int64_t i03 = ir/(ne2*ne1);
+ const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
+ const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
+
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
+
+ ggml_vec_set_f32(ne0, dst_ptr, c);
+ }
+ }
+
+ void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
+ ggml_compute_forward_fill_f32(params, dst);
+ }
+
+ // ggml_compute_tri
+
+ static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+
+ const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ const auto [ir0, ir1] = get_thread_range(params, src0);
+
+ bool (*bipred)(int, int);
+
+ switch (ttype) {
+ case GGML_TRI_TYPE_LOWER: bipred = [](int i, int r) { return i < r; }; break;
+ case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
+ case GGML_TRI_TYPE_UPPER: bipred = [](int i, int r) { return i > r; }; break;
+ case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
+ default: GGML_ABORT("invalid tri type");
+ }
+
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+ float * dst_ptr = ( float *) (( char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
+
+ for (int i0 = 0; i0 < ne0; ++i0) {
+ dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
+ }
+ }
+ }
+
+ void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_tri_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+ }
+
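The four `ggml_tri_type` predicates above compare the column index `i0` with the row index `i01` and zero out the excluded triangle. A quick standalone check of the LOWER case, using the same lambda shape (illustrative only):

    #include <cstdio>

    int main() {
        auto lower = [](int i, int r) { return i < r; }; // same predicate as GGML_TRI_TYPE_LOWER
        const float row[3] = {1.f, 2.f, 3.f};            // row r = 1 of a 3x3 matrix
        for (int i = 0; i < 3; ++i) {
            printf("%g ", lower(i, 1) ? row[i] : 0.0f);  // prints: 1 0 0
        }
        return 0;
    }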
  // ggml_compute_forward_gelu_erf

  static void ggml_compute_forward_gelu_erf_f32(
@@ -4084,31 +3596,27 @@ static void ggml_compute_forward_norm_f32(

  GGML_ASSERT(eps >= 0.0f);

- // TODO: optimize
  for (int64_t i03 = 0; i03 < ne03; i03++) {
  for (int64_t i02 = 0; i02 < ne02; i02++) {
  for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

- ggml_float sum = 0.0;
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- sum += (ggml_float)x[i00];
- }
-
+ float sum = 0.0;
+ ggml_vec_sum_f32(ne00, &sum, x);
  float mean = sum/ne00;

  float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+ float variance = 0;

- ggml_float sum2 = 0.0;
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- float v = x[i00] - mean;
- y[i00] = v;
- sum2 += (ggml_float)(v*v);
- }
+ #ifdef GGML_USE_ACCELERATE
+ mean = -mean;
+ vDSP_vsadd(x, 1, &mean, y, 1, ne00);
+ vDSP_measqv(y, 1, &variance, ne00);
+ #else
+ variance = ggml_vec_cvar_f32(ne00, y, x, mean);
+ #endif //GGML_USE_ACCELERATE

- float variance = sum2/ne00;
  const float scale = 1.0f/sqrtf(variance + eps);
-
  ggml_vec_scale_f32(ne00, y, scale);
  }
  }
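The norm rewrite above replaces two explicit accumulation loops with `ggml_vec_sum_f32` plus either Accelerate's vDSP routines or `ggml_vec_cvar_f32`. Judging from the loop it replaces, `ggml_vec_cvar_f32` centers `x` around `mean` into `y` and returns the population variance in a single pass; a scalar sketch of that contract (`vec_cvar_ref` is a hypothetical name):

    // Center x into y and return sum((x - mean)^2) / n.
    static float vec_cvar_ref(int n, float * y, const float * x, float mean) {
        float sum2 = 0.0f;
        for (int i = 0; i < n; ++i) {
            const float v = x[i] - mean;
            y[i]  = v;      // centered values are reused by the caller for scaling
            sum2 += v*v;
        }
        return sum2/n;
    }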
@@ -5076,46 +4584,6 @@ void ggml_compute_forward_cont(
  ggml_compute_forward_dup(params, dst);
  }

- // ggml_compute_forward_reshape
-
- void ggml_compute_forward_reshape(
- const ggml_compute_params * params,
- ggml_tensor * dst) {
- // NOP
- GGML_UNUSED(params);
- GGML_UNUSED(dst);
- }
-
- // ggml_compute_forward_view
-
- void ggml_compute_forward_view(
- const ggml_compute_params * params,
- ggml_tensor * dst) {
- // NOP
- GGML_UNUSED(params);
- GGML_UNUSED(dst);
- }
-
- // ggml_compute_forward_permute
-
- void ggml_compute_forward_permute(
- const ggml_compute_params * params,
- ggml_tensor * dst) {
- // NOP
- GGML_UNUSED(params);
- GGML_UNUSED(dst);
- }
-
- // ggml_compute_forward_transpose
-
- void ggml_compute_forward_transpose(
- const ggml_compute_params * params,
- ggml_tensor * dst) {
- // NOP
- GGML_UNUSED(params);
- GGML_UNUSED(dst);
- }
-
  // ggml_compute_forward_get_rows

  static void ggml_compute_forward_get_rows_q(
@@ -5356,6 +4824,7 @@ void ggml_compute_forward_get_rows(
  //}
  }

+ template<typename idx_t>
  static void ggml_compute_forward_set_rows_f32(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
@@ -5394,7 +4863,7 @@ static void ggml_compute_forward_set_rows_f32(
  const int64_t i11 = i02%ne11;
  const int64_t i10 = i;

- const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+ const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

  GGML_ASSERT(i1 >= 0 && i1 < ne1);

@@ -5411,11 +4880,18 @@ void ggml_compute_forward_set_rows(
  ggml_tensor * dst) {

  const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];

  switch (src0->type) {
  case GGML_TYPE_F32:
  {
- ggml_compute_forward_set_rows_f32(params, dst);
+ if (src1->type == GGML_TYPE_I64) {
+ ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
+ } else if (src1->type == GGML_TYPE_I32) {
+ ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
+ } else {
+ GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
+ }
  } break;
  default:
  {
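`set_rows` now accepts row indices as either I64 or I32: the index width is picked once in the dispatch above, and the templated kernel dereferences `src1` with the matching type. A reduced sketch of the idea (`read_row_index` is a hypothetical helper, not library code):

    #include <cstdint>

    // Widen whichever index type src1 carries to a common int64_t row id.
    template <typename idx_t>
    static int64_t read_row_index(const void * p) {
        return (int64_t) *(const idx_t *) p;
    }
    // read_row_index<int64_t>(p) and read_row_index<int32_t>(p) mirror the two branches above.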
@@ -6067,270 +5543,117 @@ static void rope_yarn(
  mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
  }
  *cos_theta = cosf(theta) * mscale;
- *sin_theta = sinf(theta) * mscale;
- }
-
- static void ggml_rope_cache_init(
- float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
- float * cache, float sin_sign, float theta_scale) {
- // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
- float theta = theta_base;
- for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
- const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
- rope_yarn(
- theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
- );
- cache[i0 + 1] *= sin_sign;
-
- theta *= theta_scale;
- }
- }
-
- static void ggml_mrope_cache_init(
- float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
- float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
- float * cache, float sin_sign, float theta_scale) {
- // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
- float theta_t = theta_base_t;
- float theta_h = theta_base_h;
- float theta_w = theta_base_w;
- float theta_e = theta_base_e; // extra position id for vision encoder
- int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
- int sec_w = sections[1] + sections[0];
- int sec_e = sections[2] + sec_w;
- GGML_ASSERT(sect_dims <= ne0);
-
- for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
- const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
-
- int sector = (i0 / 2) % sect_dims;
- if (indep_sects) {
- // compute theta independently for each dim sections
- // (i.e. reset corresponding theta when `i0` go from one section to another)
- if (sector == 0) {
- theta_t = theta_base_t;
- }
- else if (sector == sections[0]) {
- theta_h = theta_base_h;;
- }
- else if (sector == sec_w) {
- theta_w = theta_base_w;
- }
- else if (sector == sec_e) {
- theta_e = theta_base_e;
- }
- }
-
- float theta = theta_t;
- if (sector >= sections[0] && sector < sec_w) {
- theta = theta_h;
- }
- else if (sector >= sec_w && sector < sec_w + sections[2]) {
- theta = theta_w;
- }
- else if (sector >= sec_w + sections[2]) {
- theta = theta_e;
- }
-
- rope_yarn(
- theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
- );
- cache[i0 + 1] *= sin_sign;
-
- theta_t *= theta_scale;
- theta_w *= theta_scale;
- theta_h *= theta_scale;
- theta_e *= theta_scale;
- }
- }
-
- static void ggml_compute_forward_rope_f32(
- const ggml_compute_params * params,
- ggml_tensor * dst,
- const bool forward) {
-
- const ggml_tensor * src0 = dst->src[0];
- const ggml_tensor * src1 = dst->src[1];
- const ggml_tensor * src2 = dst->src[2];
-
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
- int sections[4];
-
- //const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_dims = ((int32_t *) dst->op_params)[1];
- const int mode = ((int32_t *) dst->op_params)[2];
- //const int n_ctx = ((int32_t *) dst->op_params)[3];
- const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
-
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
- memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
- //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
-
- GGML_ASSERT(nb00 == sizeof(float));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nr = ggml_nrows(dst);
-
- GGML_ASSERT(n_dims <= ne0);
- GGML_ASSERT(n_dims % 2 == 0);
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- // row index used to determine which thread to use
- int ir = 0;
-
- const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
- float corr_dims[2];
- ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
- const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
- const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
-
- if (is_mrope) {
- GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
- }
-
- if (is_vision) {
- GGML_ASSERT(n_dims == ne0/2);
- }
-
- const float * freq_factors = NULL;
- if (src2 != NULL) {
- GGML_ASSERT(src2->type == GGML_TYPE_F32);
- GGML_ASSERT(src2->ne[0] >= n_dims / 2);
- freq_factors = (const float *) src2->data;
- }
-
- // backward process uses inverse rotation by cos and sin.
- // cos and sin build a rotation matrix, where the inverse is the transpose.
- // this essentially just switches the sign of sin.
- const float sin_sign = forward ? 1.0f : -1.0f;
-
- const int32_t * pos = (const int32_t *) src1->data;
-
- for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
- for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
-
- float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
- if (!is_mrope) {
- const int64_t p = pos[i2];
- ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
- }
- else {
- const int64_t p_t = pos[i2];
- const int64_t p_h = pos[i2 + ne2];
- const int64_t p_w = pos[i2 + ne2 * 2];
- const int64_t p_e = pos[i2 + ne2 * 3];
- ggml_mrope_cache_init(
- p_t, p_h, p_w, p_e, sections, is_vision,
- freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
- }
-
- for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
- if (ir++ < ir0) continue;
- if (ir > ir1) break;
-
- if (is_neox || is_mrope) {
- if (is_vision){
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
- const int64_t ic = i0/2;
-
- const float cos_theta = cache[i0 + 0];
- const float sin_theta = cache[i0 + 1];
-
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
- const float x0 = src[0];
- const float x1 = src[n_dims];
+ *sin_theta = sinf(theta) * mscale;
+ }

- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
- }
- } else {
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
- const int64_t ic = i0/2;
+ static void ggml_rope_cache_init(
+ float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+ float * cache, float sin_sign, float theta_scale) {
+ // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
+ float theta = theta_base;
+ for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+ const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
+ rope_yarn(
+ theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+ );
+ cache[i0 + 1] *= sin_sign;

- const float cos_theta = cache[i0 + 0];
- const float sin_theta = cache[i0 + 1];
+ theta *= theta_scale;
+ }
+ }

- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+ static void ggml_mrope_cache_init(
+ float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
+ float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+ float * cache, float sin_sign, float theta_scale) {
+ // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
+ float theta_t = theta_base_t;
+ float theta_h = theta_base_h;
+ float theta_w = theta_base_w;
+ float theta_e = theta_base_e; // extra position id for vision encoder
+ int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
+ int sec_w = sections[1] + sections[0];
+ int sec_e = sections[2] + sec_w;
+ GGML_ASSERT(sect_dims <= ne0);

- const float x0 = src[0];
- const float x1 = src[n_dims/2];
+ for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+ const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;

- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
- }
- }
- } else {
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
- const float cos_theta = cache[i0 + 0];
- const float sin_theta = cache[i0 + 1];
+ int sector = (i0 / 2) % sect_dims;
+ if (indep_sects) {
+ // compute theta independently for each dim sections
+ // (i.e. reset corresponding theta when `i0` goes from one section to another)
+ if (sector == 0) {
+ theta_t = theta_base_t;
+ }
+ else if (sector == sections[0]) {
+ theta_h = theta_base_h;;
+ }
+ else if (sector == sec_w) {
+ theta_w = theta_base_w;
+ }
+ else if (sector == sec_e) {
+ theta_e = theta_base_e;
+ }
+ }

- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ float theta = theta_t;
+ if (is_imrope) { // qwen3vl applies interleaved mrope
+ if (sector % 3 == 1 && sector < 3 * sections[1]) {
+ theta = theta_h;
+ } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
+ theta = theta_w;
+ } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
+ theta = theta_t;
+ } else {
+ theta = theta_e;
+ }
+ } else {
+ if (sector >= sections[0] && sector < sec_w) {
+ theta = theta_h;
+ }
+ else if (sector >= sec_w && sector < sec_w + sections[2]) {
+ theta = theta_w;
+ }
+ else if (sector >= sec_w + sections[2]) {
+ theta = theta_e;
+ }
+ }

- const float x0 = src[0];
- const float x1 = src[1];
+ rope_yarn(
+ theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+ );
+ cache[i0 + 1] *= sin_sign;

- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[1] = x0*sin_theta + x1*cos_theta;
- }
- }
+ theta_t *= theta_scale;
+ theta_w *= theta_scale;
+ theta_h *= theta_scale;
+ theta_e *= theta_scale;
+ }
+ }

- if (is_vision) {
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
- const int64_t ic = i0/2;

- const float cos_theta = cache[i0 + 0];
- const float sin_theta = cache[i0 + 1];
+ template<typename T>
+ static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
+ for (int64_t i0 = 0; i0 < n; i0 += 2) {
+ const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2

- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+ const float cos_theta = cache[i0 + 0];
+ const float sin_theta = cache[i0 + 1];

- const float x0 = src[0];
- const float x1 = src[n_dims];
+ const T * const src = src_data + ic;
+ T * dst = dst_data + ic;

- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
- }
- } else {
- // fill the remain channels with data from src tensor
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ const float x0 = type_conversion_table<T>::to_f32(src[0]);
+ const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);

- dst_data[0] = src[0];
- dst_data[1] = src[1];
- }
- }
- }
- }
- }
+ dst[0] = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
+ dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
+ }
  }

- // TODO: deduplicate f16/f32 code
- static void ggml_compute_forward_rope_f16(
+ template<typename T> //float or ggml_fp16_t
+ static void ggml_compute_forward_rope_flt(
  const ggml_compute_params * params,
  ggml_tensor * dst,
  const bool forward) {
@@ -6339,6 +5662,9 @@ static void ggml_compute_forward_rope_f16(
  const ggml_tensor * src1 = dst->src[1];
  const ggml_tensor * src2 = dst->src[2];

+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
+
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
  int sections[4];

@@ -6347,6 +5673,7 @@ static void ggml_compute_forward_rope_f16(
  const int mode = ((int32_t *) dst->op_params)[2];
  //const int n_ctx = ((int32_t *) dst->op_params)[3];
  const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
  memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
  memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -6355,13 +5682,13 @@ static void ggml_compute_forward_rope_f16(
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
  memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);

-
  GGML_TENSOR_UNARY_OP_LOCALS

  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);

- GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb0 == nb00);
+ GGML_ASSERT(nb0 == sizeof(T));

  const int ith = params->ith;
  const int nth = params->nth;
@@ -6386,11 +5713,11 @@ static void ggml_compute_forward_rope_f16(

  float corr_dims[2];
  ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);

- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
- const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+ const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl applies interleaved mrope
+ const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
  const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

- if (is_mrope) {
+ if (mrope_used) {
  GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
  }
@@ -6412,11 +5739,11 @@ static void ggml_compute_forward_rope_f16(

  const int32_t * pos = (const int32_t *) src1->data;

- for (int64_t i3 = 0; i3 < ne3; i3++) {
- for (int64_t i2 = 0; i2 < ne2; i2++) {
+ for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
+ for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len

  float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
- if (!is_mrope) {
+ if (!mrope_used) {
  const int64_t p = pos[i2];
  ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
  }
@@ -6426,90 +5753,44 @@ static void ggml_compute_forward_rope_f16(
  const int64_t p_w = pos[i2 + ne2 * 2];
  const int64_t p_e = pos[i2 + ne2 * 3];
  ggml_mrope_cache_init(
- p_t, p_h, p_w, p_e, sections, is_vision,
+ p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
  freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
  }

- for (int64_t i1 = 0; i1 < ne1; i1++) {
+ for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
  if (ir++ < ir0) continue;
  if (ir > ir1) break;

- if (is_neox || is_mrope) {
- if (is_vision) {
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
- const int64_t ic = i0/2;
-
- const float cos_theta = cache[i0 + 0];
- const float sin_theta = cache[i0 + 1];
-
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
- const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
-
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
- dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
- }
- } else {
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
- const int64_t ic = i0/2;
-
- const float cos_theta = cache[i0 + 0];
- const float sin_theta = cache[i0 + 1];
-
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
- const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
-
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
- dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
- }
- }
- } else {
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
- const float cos_theta = cache[i0 + 0];
- const float sin_theta = cache[i0 + 1];
-
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
- const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
-
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
- dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
- }
- }
-
- if (is_vision) {
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
- const int64_t ic = i0/2;
-
- const float cos_theta = cache[i0 + 0];
- const float sin_theta = cache[i0 + 1];
-
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
- const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
-
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
- dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
- }
- } else {
+ T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+ T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+
+ switch (mode) {
+ case GGML_ROPE_TYPE_NORMAL:
+ rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
+ break;
+ case GGML_ROPE_TYPE_NEOX:
+ case GGML_ROPE_TYPE_MROPE:
+ case GGML_ROPE_TYPE_IMROPE:
+ rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
+ break;
+ case GGML_ROPE_TYPE_VISION:
+ rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
+ break;
+ default:
+ GGML_ABORT("rope type not supported");
+ }
+
+ if (!is_vision) {
+ // fill the remaining channels with data from src tensor
  for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

  dst_data[0] = src[0];
  dst_data[1] = src[1];
  }
  }
- }
+ } //attn-heads
  }
  }
  }
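All the former per-mode inner loops now funnel into `rotate_pairs`, parameterized by which element each lane is paired with. Working through the arguments above for `n_dims = 8`: NORMAL (`n_offset = 1`, `scale = 1`) rotates adjacent pairs (x0,x1), (x2,x3), ...; NEOX/MROPE/IMROPE (`n_offset = n_dims/2 = 4`) rotate split halves (x0,x4), (x1,x5), ...; VISION (`n = ne0`, `n_offset = n_dims`) pairs each element with its counterpart `n_dims` away. Every pair then goes through the same 2-D rotation, e.g.:

    #include <cmath>
    #include <cstdio>

    int main() {
        float x0 = 1.0f, x1 = 0.0f;                  // one (even, odd) pair
        const float theta = 0.5f;
        const float c = cosf(theta), s = sinf(theta);
        // same arithmetic as rotate_pairs, one cached (cos, sin) per pair
        printf("%f %f\n", x0*c - x1*s, x0*s + x1*c); // 0.877583 0.479426
        return 0;
    }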
@@ -6523,11 +5804,11 @@ void ggml_compute_forward_rope(
  switch (src0->type) {
  case GGML_TYPE_F16:
  {
- ggml_compute_forward_rope_f16(params, dst, true);
+ ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
  } break;
  case GGML_TYPE_F32:
  {
- ggml_compute_forward_rope_f32(params, dst, true);
+ ggml_compute_forward_rope_flt<float>(params, dst, true);
  } break;
  default:
  {
@@ -6547,11 +5828,11 @@ void ggml_compute_forward_rope_back(
  switch (src0->type) {
  case GGML_TYPE_F16:
  {
- ggml_compute_forward_rope_f16(params, dst, false);
+ ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
  } break;
  case GGML_TYPE_F32:
  {
- ggml_compute_forward_rope_f32(params, dst, false);
+ ggml_compute_forward_rope_flt<float>(params, dst, false);
  } break;
  default:
  {
@@ -6938,10 +6219,198 @@ void ggml_compute_forward_im2col_back_f32(
  const ggml_compute_params * params,
  ggml_tensor * dst) {

- const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
- const ggml_tensor * src1 = dst->src[1]; // convolution kernel
+ const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
+ const ggml_tensor * src1 = dst->src[1]; // convolution kernel
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t N = is_2D ? ne3 : ne2;
+ const int64_t IC = is_2D ? ne2 : ne1;
+ const int64_t IH = is_2D ? ne1 : 1;
+ const int64_t IW = ne0;
+
+ const int64_t KH = is_2D ? ne11 : 1;
+ const int64_t KW = ne10;
+
+ const int64_t OH = is_2D ? ne02 : 1;
+ const int64_t OW = ne01;
+
+ int ofs0 = is_2D ? nb3 : nb2;
+ int ofs1 = is_2D ? nb2 : nb1;
+
+ GGML_ASSERT(nb0 == sizeof(float));
+
+ // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+ {
+ float * const wdata = (float *) dst->data;
+
+ for (int64_t in = 0; in < N; in++) {
+ for (int64_t iic = ith; iic < IC; iic += nth) {
+ for (int64_t iih = 0; iih < IH; iih++) {
+ for (int64_t iiw = 0; iiw < IW; iiw++) {
+
+ // micro kernel
+ float grad = 0.0f;
+ for (int64_t ikh = 0; ikh < KH; ikh++) {
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
+ // For s0 > 1 some values were skipped over in the forward pass.
+ // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
+ const int64_t tmpw = (iiw + p0 - ikw*d0);
+ if (tmpw % s0 != 0) {
+ continue;
+ }
+ const int64_t iow = tmpw / s0;
+
+ // Equivalent logic as above except for s1.
+ int64_t ioh;
+ if (is_2D) {
+ const int64_t tmph = iih + p1 - ikh*d1;
+
+ if (tmph % s1 != 0) {
+ continue;
+ }
+
+ ioh = tmph / s1;
+ } else {
+ ioh = 0;
+ }
+
+ if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
+ continue;
+ }
+
+ const float * const grad_in = (const float *) src0->data
+ + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+ grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
+ }
+ }
+ float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
+ dst_data[iih*IW + iiw] = grad;
+ }
+ }
+ }
+ }
+ }
+ }
+
+
+ // ggml_compute_forward_im2col_3d_f16
+ // src0: kernel [OC*IC, KD, KH, KW]
+ // src1: image [N*IC, ID, IH, IW]
+ // dst: result [N*OD, OH, OW, IC * KD * KH * KW]
+ static void ggml_compute_forward_im2col_3d_f16(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+ const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+ const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t N = ne13 / IC;
+ const int64_t ID = ne12;
+ const int64_t IH = ne11;
+ const int64_t IW = ne10;
+
+ const int64_t OC = ne03 / IC;
+ GGML_UNUSED(OC);
+ const int64_t KD = ne02;
+ const int64_t KH = ne01;
+ const int64_t KW = ne00;
+
+ const int64_t OD = ne3 / N;
+ const int64_t OH = ne2;
+ const int64_t OW = ne1;
+ const int64_t OH_OW = OH*OW;
+ const int64_t KD_KH_KW = KD*KH*KW;
+ const int64_t KH_KW = KH*KW;
+ const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
+
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+ {
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
+
+ for (int64_t in = 0; in < N; in++) {
+ for (int64_t iod = 0; iod < OD; iod++) {
+ for (int64_t ioh = 0; ioh < OH; ioh++) {
+ for (int64_t iow = 0; iow < OW; iow++) {
+ for (int64_t iic = ith; iic < IC; iic += nth) {
+
+ // micro kernel
+ ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+ const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+ for (int64_t ikd = 0; ikd < KD; ikd++) {
+ for (int64_t ikh = 0; ikh < KH; ikh++) {
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
+ const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+ if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
+ } else {
+ const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // ggml_compute_forward_im2col_3d_f32
+ // src0: kernel [OC*IC, KD, KH, KW]
+ // src1: image [N*IC, ID, IH, IW]
+ // dst: result [N*OD, OH, OW, IC * KD * KH * KW]
+ static void ggml_compute_forward_im2col_3d_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];

- GGML_ASSERT(src0->type == GGML_TYPE_F32);
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_F32);

@@ -6949,77 +6418,72 @@ void ggml_compute_forward_im2col_back_f32(

  const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
  const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
- const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
- const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
- const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
- const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
- const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+ const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+ const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+ const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+

  const int ith = params->ith;
  const int nth = params->nth;

- const int64_t N = is_2D ? ne3 : ne2;
- const int64_t IC = is_2D ? ne2 : ne1;
- const int64_t IH = is_2D ? ne1 : 1;
- const int64_t IW = ne0;
+ const int64_t N = ne13 / IC;
+ const int64_t ID = ne12;
+ const int64_t IH = ne11;
+ const int64_t IW = ne10;

- const int64_t KH = is_2D ? ne11 : 1;
- const int64_t KW = ne10;
+ const int64_t OC = ne03 / IC;
+ GGML_UNUSED(OC);
+ const int64_t KD = ne02;
+ const int64_t KH = ne01;
+ const int64_t KW = ne00;

- const int64_t OH = is_2D ? ne02 : 1;
- const int64_t OW = ne01;
+ const int64_t OD = ne3 / N;
+ const int64_t OH = ne2;
+ const int64_t OW = ne1;

- int ofs0 = is_2D ? nb3 : nb2;
- int ofs1 = is_2D ? nb2 : nb1;
+ const int64_t OH_OW = OH*OW;
+ const int64_t KD_KH_KW = KD*KH*KW;
+ const int64_t KH_KW = KH*KW;
+ const int64_t IC_KD_KH_KW = IC*KD*KH*KW;

- GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb10 == sizeof(float));

- // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+ // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
  {
  float * const wdata = (float *) dst->data;

  for (int64_t in = 0; in < N; in++) {
- for (int64_t iic = ith; iic < IC; iic += nth) {
- for (int64_t iih = 0; iih < IH; iih++) {
- for (int64_t iiw = 0; iiw < IW; iiw++) {
-
- // micro kernel
- float grad = 0.0f;
- for (int64_t ikh = 0; ikh < KH; ikh++) {
- for (int64_t ikw = 0; ikw < KW; ikw++) {
- // For s0 > 1 some values were skipped over in the forward pass.
- // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
- const int64_t tmpw = (iiw + p0 - ikw*d0);
- if (tmpw % s0 != 0) {
- continue;
- }
- const int64_t iow = tmpw / s0;
-
- // Equivalent logic as above except for s1.
- int64_t ioh;
- if (is_2D) {
- const int64_t tmph = iih + p1 - ikh*d1;
-
- if (tmph % s1 != 0) {
- continue;
+ for (int64_t iod = 0; iod < OD; iod++) {
+ for (int64_t ioh = 0; ioh < OH; ioh++) {
+ for (int64_t iow = 0; iow < OW; iow++) {
+ for (int64_t iic = ith; iic < IC; iic += nth) {
+
+ // micro kernel
+ float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+ const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+ for (int64_t ikd = 0; ikd < KD; ikd++) {
+ for (int64_t ikh = 0; ikh < KH; ikh++) {
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
+ const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+ if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
+ } else {
+ const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
+ }
  }
-
- ioh = tmph / s1;
- } else {
- ioh = 0;
- }
-
- if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
- continue;
  }
-
- const float * const grad_in = (const float *) src0->data
- + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
- grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
  }
  }
- float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
- dst_data[iih*IW + iiw] = grad;
  }
  }
  }
@@ -7027,6 +6491,26 @@ void ggml_compute_forward_im2col_back_f32(
  }
  }

+
+ void ggml_compute_forward_im2col_3d(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+ switch (dst->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_im2col_3d_f16(params, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_im2col_3d_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+ }
+
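For every output voxel and kernel tap, the 3-D im2col above samples the input at `ii = io*s + ik*d - p` per axis and writes zero whenever the coordinate lands in the padding. A worked 1-D instance of that mapping, derived from the expressions in the kernel:

    // With s0 = 1, d0 = 1, p0 = 1 and output position iow = 0:
    //   ikw = 0 -> iiw = 0*1 + 0*1 - 1 = -1 -> out of range, write 0 (left padding)
    //   ikw = 1 -> iiw = 0*1 + 1*1 - 1 =  0 -> copy src[0]
    //   ikw = 2 -> iiw = 0*1 + 2*1 - 1 =  1 -> copy src[1]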
  static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
  void * a, void * b, float * c) {
  const ggml_type_traits * traits = ggml_get_type_traits(type);
@@ -7480,7 +6964,11 @@ static void ggml_compute_forward_conv_2d_dw_cwhn(
  const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);

  #ifdef GGML_SIMD
- const int64_t pkg_size = GGML_F32_EPR;
+ #if defined(__ARM_FEATURE_SVE)
+ const int64_t pkg_size = svcntw();
+ #else
+ const int64_t pkg_size = GGML_F32_EPR;
+ #endif
  const int64_t pkg_count = c / pkg_size;
  const int64_t c_pkg_end = pkg_count * pkg_size;
  #else
@@ -7903,10 +7391,17 @@ static void ggml_compute_forward_upscale_f32(
  float sf1 = (float)ne1/src0->ne[1];
  float sf2 = (float)ne2/src0->ne[2];
  float sf3 = (float)ne3/src0->ne[3];
+ float pixel_offset = 0.5f;

  const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
  const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);

+ if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
+ pixel_offset = 0.0f;
+ sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
+ sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
+ }
+
  if (mode == GGML_SCALE_MODE_NEAREST) {
  for (int64_t i3 = 0; i3 < ne3; i3++) {
  const int64_t i03 = i3 / sf3;
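The align-corners branch above recomputes the scale factors so that the first and last samples of source and destination coincide: upscaling a width-4 axis to width 8 gives sf0 = (8 - 1)/(4 - 1) = 7/3, so destination index 7 maps exactly onto source index 3. The added `ne > 1` guards also avoid the division by zero that the old bilinear-only version (removed in the next hunk) could hit on singleton axes.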
@@ -7926,13 +7421,6 @@ static void ggml_compute_forward_upscale_f32(
  }
  }
  } else if (mode == GGML_SCALE_MODE_BILINEAR) {
- float pixel_offset = 0.5f;
- if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
- pixel_offset = 0.0f;
- sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
- sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
- }
-
  for (int64_t i3 = 0; i3 < ne3; i3++) {
  const int64_t i03 = i3 / sf3;
  for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
@@ -7967,6 +7455,51 @@ static void ggml_compute_forward_upscale_f32(

  const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;

+ float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+ *y_dst = val;
+ }
+ }
+ }
+ }
+ } else if (mode == GGML_SCALE_MODE_BICUBIC) {
+ // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
+ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
+ auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
+ auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
+ auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
+ const float w0 = weight2(x + 1);
+ const float w1 = weight1(x + 0);
+ const float w2 = weight1(1 - x);
+ const float w3 = weight2(2 - x);
+ return p0*w0 + p1*w1 + p2*w2 + p3*w3;
+ };
+
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
+ const int64_t i03 = i3 / sf3;
+ for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+ const int64_t i02 = i2 / sf2;
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
+ const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
+ const int64_t y0 = (int64_t)floorf(y);
+ const float dy = y - (float)y0;
+
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
+ const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
+ const int64_t x0 = (int64_t)floorf(x);
+ const float dx = x - (float)x0;
+
+ auto p = [=](int64_t x_off, int64_t y_off) -> float {
+ int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
+ int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
+ return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ };
+
+ const float val = bicubic(
+ bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
+ bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
+ bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
+ bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
+
  float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  *y_dst = val;
  }
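The two lambdas in the bicubic branch implement the cubic convolution kernel with alpha = -0.75. For any fractional offset the four weights sum to one, which is what keeps constant regions constant after resampling; a quick standalone check using the same polynomials:

    #include <cstdio>

    int main() {
        const float a = -0.75f;
        auto w_near = [a](float x) { return ((a + 2)*x - (a + 3))*x*x + 1; };
        auto w_far  = [a](float x) { return ((a*x - 5*a)*x + 8*a)*x - 4*a; };
        for (float dx : {0.0f, 0.25f, 0.5f, 0.75f}) {
            const float sum = w_far(dx + 1) + w_near(dx) + w_near(1 - dx) + w_far(2 - dx);
            printf("dx=%.2f sum=%f\n", dx, sum); // sum stays 1.0 (partition of unity)
        }
        return 0;
    }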
@@ -8014,6 +7547,15 @@ static void ggml_compute_forward_pad_f32(
  GGML_TENSOR_UNARY_OP_LOCALS

  float * dst_ptr = (float *) dst->data;
+ const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
+ const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
+ const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
+ const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
+ const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
+ const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
+ const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
+ const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
+

  // TODO: optimize

@@ -8022,10 +7564,12 @@ static void ggml_compute_forward_pad_f32(
  for (int64_t i0 = 0; i0 < ne0; ++i0) {
  for (int64_t i3 = 0; i3 < ne3; ++i3) {
  const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-
- const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-
- if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ if ((i0 >= lp0 && i0 < ne0 - rp0) \
+ && (i1 >= lp1 && i1 < ne1 - rp1) \
+ && (i2 >= lp2 && i2 < ne2 - rp2) \
+ && (i3 >= lp3 && i3 < ne3 - rp3)) {
+ const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
+ const float * src_ptr = (const float *)((char *) src0->data + src_idx);
  dst_ptr[dst_idx] = *src_ptr;
  } else {
  dst_ptr[dst_idx] = 0;
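A worked 1-D instance of the new asymmetric padding: with `lp0 = 2`, `rp0 = 1` and a source row [a, b, c] (`ne00 = 3`, hence `ne0 = 6`), the output row is [0, 0, a, b, c, 0] — destination indices in [`lp0`, `ne0 - rp0`) copy `src[i0 - lp0]`, everything else is zeroed. The condition it replaces (`i0 < ne00 && ...`) could only pad on the right.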
@@ -8224,7 +7768,7 @@ static void ggml_compute_forward_timestep_embedding_f32(
  embed_data[j + half] = sinf(arg);
  }
  if (dim % 2 != 0 && ith == 0) {
- embed_data[dim] = 0.f;
+ embed_data[2 * half] = 0.f;
  }
  }
  }
@@ -8249,6 +7793,18 @@ void ggml_compute_forward_timestep_embedding(

  // ggml_compute_forward_argsort

+ template<enum ggml_sort_order order>
+ struct argsort_cmp {
+ const float * data;
+ bool operator()(int32_t a, int32_t b) const {
+ if constexpr (order == GGML_SORT_ORDER_ASC) {
+ return data[a] < data[b];
+ } else {
+ return data[a] > data[b];
+ }
+ }
+ };
+
  static void ggml_compute_forward_argsort_f32(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
@@ -8267,23 +7823,25 @@ static void ggml_compute_forward_argsort_f32(
  ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);

  for (int64_t i = ith; i < nr; i += nth) {
- int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
  const float * src_data = (float *)((char *) src0->data + i*nb01);

+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+
  for (int64_t j = 0; j < ne0; j++) {
  dst_data[j] = j;
  }

- // C doesn't have a functional sort, so we do a bubble sort instead
- for (int64_t j = 0; j < ne0; j++) {
- for (int64_t k = j + 1; k < ne0; k++) {
- if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
- (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
- int32_t tmp = dst_data[j];
- dst_data[j] = dst_data[k];
- dst_data[k] = tmp;
- }
- }
+ switch (order) {
+ case GGML_SORT_ORDER_ASC:
+ std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
+ break;
+
+ case GGML_SORT_ORDER_DESC:
+ std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
+ break;
+
+ default:
+ GGML_ABORT("invalid sort order");
  }
  }
  }
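Replacing the quadratic bubble sort with `std::sort` over the index row drops per-row cost from O(n^2) to O(n log n), and the comparator template hoists the ASC/DESC branch out of the hot loop. The same idiom in miniature:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        const float vals[4] = {0.3f, -1.0f, 2.5f, 0.0f};
        std::vector<int32_t> idx = {0, 1, 2, 3};
        std::sort(idx.begin(), idx.end(),
                  [&](int32_t a, int32_t b) { return vals[a] < vals[b]; }); // argsort, ascending
        for (int32_t i : idx) printf("%d ", i); // prints: 1 3 0 2
        return 0;
    }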
@@ -8308,10 +7866,10 @@ void ggml_compute_forward_argsort(

  // ggml_compute_forward_flash_attn_ext

- static void ggml_compute_forward_flash_attn_ext_f16(
+ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
  const ggml_compute_params * params,
- ggml_tensor * dst) {
-
+ ggml_tensor * dst,
+ int ir0, int ir1) {
  const ggml_tensor * q = dst->src[0];
  const ggml_tensor * k = dst->src[1];
  const ggml_tensor * v = dst->src[2];
@@ -8327,9 +7885,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
  GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
  GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

- const int ith = params->ith;
- const int nth = params->nth;
-
  const int64_t DK = nek0;
  const int64_t DV = nev0;
  const int64_t N = neq1;
@@ -8363,16 +7918,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(

  // parallelize by q rows using ggml_vec_dot_f32

- // total rows in q
- const int nr = neq1*neq2*neq3;
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
  float scale = 1.0f;
  float max_bias = 0.0f;
  float logit_softcap = 0.0f;
@@ -8399,6 +7944,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
  GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
  GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");

+ int ith = params->ith;
+
  // loop over n_batch and n_head
  for (int ir = ir0; ir < ir1; ++ir) {
  // q indices
@@ -8530,7 +8077,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
  }

  // V /= S
- const float S_inv = 1.0f/S;
+ const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
  ggml_vec_scale_f32(DV, VKQ32, S_inv);

  // dst indices
@@ -8546,6 +8093,91 @@ static void ggml_compute_forward_flash_attn_ext_f16(
  }
  }

+ static void ggml_compute_forward_flash_attn_ext_f16(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * q = dst->src[0];
+ const ggml_tensor * k = dst->src[1];
+ const ggml_tensor * v = dst->src[2];
+
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+ const int64_t DK = nek0;
+ const int64_t DV = nev0;
+ const int64_t N = neq1;
+
+ GGML_ASSERT(ne0 == DV);
+ GGML_ASSERT(ne2 == N);
+
+ // input tensor rows must be contiguous
+ GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+ GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+ GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+ GGML_ASSERT(neq0 == DK);
+ GGML_ASSERT(nek0 == DK);
+ GGML_ASSERT(nev0 == DV);
+
+ GGML_ASSERT(neq1 == N);
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb0 <= nb1);
+ GGML_ASSERT(nb1 <= nb2);
+ GGML_ASSERT(nb2 <= nb3);
+
+ // parallelize by q rows using ggml_vec_dot_f32
+
+ // total rows in q
+ const int64_t nr = neq1*neq2*neq3;
+
+ // rows per thread
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ // disable for NUMA
+ const bool disable_chunking = ggml_is_numa();
+
+ // 4x chunks per thread
+ int nth_scaled = nth * 4;
+ int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+ int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
+
+ if (nth == 1 || nchunk < nth || disable_chunking) {
+ nchunk = nth;
+ }
+
+ if (ith == 0) {
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+ ggml_threadpool_chunk_set(params->threadpool, nth);
+ }
+
+ ggml_barrier(params->threadpool);
+
+ // The number of elements in each chunk
+ const int64_t dr = (nr + nchunk - 1) / nchunk;
+
+ // The first chunk comes from our thread_id, the rest will get auto-assigned.
+ int current_chunk = ith;
+
+ while (current_chunk < nchunk) {
+ const int64_t ir0 = dr * current_chunk;
+ const int64_t ir1 = MIN(ir0 + dr, nr);
+
+ ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+
+ current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+ }
+ }
+
  void ggml_compute_forward_flash_attn_ext(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
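The split into `_one_chunk` plus the driver above changes the scheduling model: instead of one static row range per thread, the q rows are cut into roughly 4 x `nth` chunks and idle threads claim the next chunk from a shared counter, so one slow thread no longer stalls the whole op (chunking is disabled on NUMA machines, where the static split's locality wins). The driver loop reduces to this pattern — a sketch assuming `ggml_threadpool_chunk_add` behaves like an atomic fetch-add, with the counter pre-set to `nth` as thread 0 does above:

    #include <atomic>
    #include <cstdint>

    static std::atomic<int> next_chunk; // stands in for the threadpool's chunk counter,
                                        // pre-set to nth before the workers start

    static void worker(int ith, int nchunk, int64_t dr, int64_t nr) {
        int current = ith;                         // first chunk = own thread id
        while (current < nchunk) {
            const int64_t ir0 = dr * current;
            const int64_t ir1 = ir0 + dr < nr ? ir0 + dr : nr;
            // ... process rows [ir0, ir1) ...
            current = next_chunk.fetch_add(1);     // claim the next unprocessed chunk
        }
    }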
@@ -9032,7 +8664,7 @@ static void ggml_compute_forward_ssm_scan_f32(
  // n_head
  for (int h = ih0; h < ih1; ++h) {
  // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
- const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+ const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
  const float dA = expf(dt_soft_plus * A[h]);
  const int g = h / (nh / ng); // repeat_interleave

@@ -9129,7 +8761,7 @@ static void ggml_compute_forward_ssm_scan_f32(
             // n_head
             for (int h = ih0; h < ih1; ++h) {
                 // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-                const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+                const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
                 const int g = h / (nh / ng); // repeat_interleave
 
                 // dim
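Both ssm_scan hunks fold the inline softplus into a shared `ggml_compute_softplus_f32` helper. The helper's body is not part of this diff, but judging from the removed expression, a plausible implementation keeps the same numerically safe form:

```cpp
#include <cmath>

// Plausible shape of the shared helper (assumption: body not shown in this
// diff). softplus(x) = log(1 + e^x); for x > 20, expf(x) heads toward float
// overflow while log1pf(expf(x)) already equals x to float precision, so the
// identity branch is returned instead.
static inline float softplus_f32(float x) {
    return x <= 20.0f ? log1pf(expf(x)) : x;
}
```

At the cutoff, log1p(e^20) = 20 + log(1 + e^-20) ≈ 20 + 2e-9, which is far below one float ulp of 20 (about 1.9e-6), so the two branches agree to float precision and the switch is seamless.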
@@ -9392,6 +9024,34 @@ void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_exp(params, dst);
             } break;
+        case GGML_UNARY_OP_FLOOR:
+            {
+                ggml_compute_forward_floor(params, dst);
+            } break;
+        case GGML_UNARY_OP_CEIL:
+            {
+                ggml_compute_forward_ceil(params, dst);
+            } break;
+        case GGML_UNARY_OP_ROUND:
+            {
+                ggml_compute_forward_round(params, dst);
+            } break;
+        case GGML_UNARY_OP_TRUNC:
+            {
+                ggml_compute_forward_trunc(params, dst);
+            } break;
+        case GGML_UNARY_OP_XIELU:
+            {
+                ggml_compute_forward_xielu(params, dst);
+            } break;
+        case GGML_UNARY_OP_EXPM1:
+            {
+                ggml_compute_forward_expm1(params, dst);
+            } break;
+        case GGML_UNARY_OP_SOFTPLUS:
+            {
+                ggml_compute_forward_softplus(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
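The new `GGML_UNARY_OP_*` cases dispatch to per-op kernels defined elsewhere in the file. Most of these ops are plain elementwise maps over f32 data; here is a hedged sketch of that pattern (`apply_unary_f32` is an illustrative name, not the actual kernel signature, and the real kernels also handle strides, threading, and other types):

```cpp
#include <cmath>
#include <cstdint>

// Illustrative elementwise apply over one contiguous f32 row.
static void apply_unary_f32(const float * src, float * dst, int64_t n,
                            float (*op)(float)) {
    for (int64_t i = 0; i < n; ++i) {
        dst[i] = op(src[i]);
    }
}

// The scalar math for most of the new ops maps directly onto libm:
// FLOOR -> floorf, CEIL -> ceilf, ROUND -> roundf, TRUNC -> truncf,
// EXPM1 -> expm1f (accurate for small |x|), SOFTPLUS -> log1pf/expf as above.
// XIELU is the outlier: it takes extra per-op parameters upstream, so it
// cannot be expressed as a bare float(float) map like the others.

int main() {
    float src[4] = { -1.5f, -0.5f, 0.5f, 1.5f };
    float dst[4];
    apply_unary_f32(src, dst, 4, [](float x) { return std::trunc(x); });
}
```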
@@ -9988,6 +9648,75 @@ void ggml_compute_forward_gla(
     }
 }
 
+static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0]; // A (lower triangular)
+    const struct ggml_tensor * src1 = dst->src[1]; // B (RHS)
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    GGML_ASSERT(ne00 == ne01); // A must be square
+    GGML_ASSERT(ne0 == ne10);  // solution cols == B cols
+    GGML_ASSERT(ne1 == ne11);  // solution rows == B rows
+
+    GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
+    GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t k  = ne10;            // number of RHS columns
+    const int64_t n  = ne11;            // A is n×n
+    const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
+
+    // chunks per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // chunk range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    const float * A = (const float *) src0->data; // [n, n, B1, B2]
+    const float * B = (const float *) src1->data; // [n, k, B1, B2]
+    float       * X = (      float *) dst->data;  // [n, k, B1, B2]
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*k);
+        const int64_t i02 = (ir - i03*ne02*k)/k;
+        const int64_t i01 = (ir - i03*ne02*k - i02*k);
+
+        const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
+        const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
+
+        float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
+
+        for (int64_t i00 = 0; i00 < n; ++i00) {
+            float sum = 0.0f;
+            for (int64_t t = 0; t < i00; ++t) {
+                sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
+            }
+
+            const float diag = A_batch[i00 * n + i00];
+            GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
+            X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
+        }
+    }
+}
+
+void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        ggml_compute_forward_solve_tri_f32(params, dst);
+    } else {
+        GGML_ABORT("fatal error");
+    }
+}
+
 // ggml_compute_forward_rwkv_wkv7
 
 static void ggml_compute_forward_rwkv_wkv7_f32(
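The `ggml_compute_forward_solve_tri_f32` kernel above is batched forward substitution: for each right-hand-side column it computes x_i = (b_i - Σ_{t<i} A[i,t]·x_t) / A[i,i], row by row. A standalone numeric check of that inner loop on a single 3×3 lower-triangular system (values chosen so the expected solution is obvious; this is not the ggml API, just the bare recurrence):

```cpp
#include <cassert>
#include <cstdio>

int main() {
    const int n = 3;
    // A is lower triangular with a nonzero diagonal (row-major)
    const float A[n * n] = {
        2.0f, 0.0f, 0.0f,
        1.0f, 3.0f, 0.0f,
        4.0f, 5.0f, 6.0f,
    };
    // b was built as A * (1, 1, 4), so the solve should recover x = (1, 1, 4)
    const float b[n] = { 2.0f, 4.0f, 33.0f };
    float x[n];

    // forward substitution, mirroring the kernel's inner loop
    for (int i = 0; i < n; ++i) {
        float sum = 0.0f;
        for (int t = 0; t < i; ++t) {
            sum += A[i * n + t] * x[t];
        }
        assert(A[i * n + i] != 0.0f && "zero diagonal");
        x[i] = (b[i] - sum) / A[i * n + i];
    }

    printf("x = %.1f %.1f %.1f\n", x[0], x[1], x[2]); // expect 1.0 1.0 4.0
}
```

Parallelizing over batch × column (the `ne02 * ne03 * k` unit in the kernel) is safe because each column's solve is independent; the only sequential dependency runs down the rows within a single column.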