@novastera-oss/llamarn 0.2.7 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (319) hide show
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  13. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  18. package/cpp/LlamaCppModel.cpp +56 -22
  19. package/cpp/build-info.cpp +2 -2
  20. package/cpp/llama.cpp/CMakeLists.txt +1 -2
  21. package/cpp/llama.cpp/README.md +4 -5
  22. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  23. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  24. package/cpp/llama.cpp/common/arg.cpp +24 -0
  25. package/cpp/llama.cpp/common/chat.cpp +37 -20
  26. package/cpp/llama.cpp/common/chat.h +2 -0
  27. package/cpp/llama.cpp/common/common.cpp +3 -0
  28. package/cpp/llama.cpp/common/common.h +5 -0
  29. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  30. package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
  31. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  32. package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
  33. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  34. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  35. package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
  36. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  100. package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
  101. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
  103. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
  104. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
  105. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  133. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  134. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  135. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  136. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  137. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
  138. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  141. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  142. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  143. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  144. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  168. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  169. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  170. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  172. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  173. package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
  177. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
  178. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
  179. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  180. package/cpp/llama.cpp/include/llama.h +8 -43
  181. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  182. package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
  183. package/cpp/llama.cpp/src/llama-arch.h +36 -1
  184. package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
  185. package/cpp/llama.cpp/src/llama-batch.h +105 -70
  186. package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
  187. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  188. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  189. package/cpp/llama.cpp/src/llama-context.h +13 -13
  190. package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
  191. package/cpp/llama.cpp/src/llama-graph.h +78 -79
  192. package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
  193. package/cpp/llama.cpp/src/llama-hparams.h +11 -0
  194. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
  195. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  196. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
  197. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
  198. package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
  199. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
  200. package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
  201. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
  202. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  203. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  204. package/cpp/llama.cpp/src/llama-memory.h +21 -22
  205. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  206. package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
  207. package/cpp/llama.cpp/src/llama-model.h +40 -0
  208. package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
  209. package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
  210. package/cpp/llama.cpp/src/llama-vocab.h +42 -0
  211. package/cpp/rn-utils.h +3 -0
  212. package/ios/include/chat.h +2 -0
  213. package/ios/include/common.h +5 -0
  214. package/ios/include/llama.h +8 -43
  215. package/ios/libs/llama.xcframework/Info.plist +19 -19
  216. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  217. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  218. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  219. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  220. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
  221. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  223. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  224. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  225. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  231. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  232. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  248. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  250. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
  254. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
  255. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  261. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  262. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  263. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  264. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  265. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  267. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  268. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
  269. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
  270. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
  274. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  275. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  276. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  277. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  278. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  279. package/package.json +1 -1
  280. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  315. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  316. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  317. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  318. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  319. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -61,9 +61,6 @@
61
61
  #define m512i(p) (__m512i)(p)
62
62
  #endif
63
63
 
64
- // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
65
- float ggml_table_f32_f16[1 << 16];
66
-
67
64
  #if defined(__linux__) || \
68
65
  defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \
69
66
  (defined(__APPLE__) && !TARGET_OS_TV && !TARGET_OS_WATCH)
@@ -205,19 +202,34 @@ void ggml_print_backtrace(void) {
205
202
  }
206
203
  #endif
207
204
 
205
+ static ggml_abort_callback_t g_abort_callback = NULL;
206
+
207
+ // Set the abort callback (passing null will restore original abort functionality: printing a message to stdout)
208
+ GGML_API ggml_abort_callback_t ggml_set_abort_callback(ggml_abort_callback_t callback) {
209
+ ggml_abort_callback_t ret_val = g_abort_callback;
210
+ g_abort_callback = callback;
211
+ return ret_val;
212
+ }
213
+
208
214
  void ggml_abort(const char * file, int line, const char * fmt, ...) {
209
215
  fflush(stdout);
210
216
 
211
- fprintf(stderr, "%s:%d: ", file, line);
217
+ char message[2048];
218
+ int offset = snprintf(message, sizeof(message), "%s:%d: ", file, line);
212
219
 
213
220
  va_list args;
214
221
  va_start(args, fmt);
215
- vfprintf(stderr, fmt, args);
222
+ vsnprintf(message + offset, sizeof(message) - offset, fmt, args);
216
223
  va_end(args);
217
224
 
218
- fprintf(stderr, "\n");
225
+ if (g_abort_callback) {
226
+ g_abort_callback(message);
227
+ } else {
228
+ // default: print error and backtrace to stderr
229
+ fprintf(stderr, "%s\n", message);
230
+ ggml_print_backtrace();
231
+ }
219
232
 
220
- ggml_print_backtrace();
221
233
  abort();
222
234
  }
223
235
 
@@ -461,6 +473,14 @@ bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
461
473
  return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
462
474
  }
463
475
 
476
+ const char * ggml_version(void) {
477
+ return GGML_VERSION;
478
+ }
479
+
480
+ const char * ggml_commit(void) {
481
+ return GGML_COMMIT;
482
+ }
483
+
464
484
  //
465
485
  // timing
466
486
  //
@@ -936,6 +956,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
936
956
  "TRANSPOSE",
937
957
  "GET_ROWS",
938
958
  "GET_ROWS_BACK",
959
+ "SET_ROWS",
939
960
  "DIAG",
940
961
  "DIAG_MASK_INF",
941
962
  "DIAG_MASK_ZERO",
@@ -947,6 +968,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
947
968
  "CONV_TRANSPOSE_1D",
948
969
  "IM2COL",
949
970
  "IM2COL_BACK",
971
+ "CONV_2D",
950
972
  "CONV_2D_DW",
951
973
  "CONV_TRANSPOSE_2D",
952
974
  "POOL_1D",
@@ -955,6 +977,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
955
977
  "UPSCALE",
956
978
  "PAD",
957
979
  "PAD_REFLECT_1D",
980
+ "ROLL",
958
981
  "ARANGE",
959
982
  "TIMESTEP_EMBEDDING",
960
983
  "ARGSORT",
@@ -983,9 +1006,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
983
1006
  "CROSS_ENTROPY_LOSS",
984
1007
  "CROSS_ENTROPY_LOSS_BACK",
985
1008
  "OPT_STEP_ADAMW",
1009
+
1010
+ "GLU",
986
1011
  };
987
1012
 
988
- static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
1013
+ static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
989
1014
 
990
1015
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
991
1016
  "none",
@@ -1031,6 +1056,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1031
1056
  "transpose(x)",
1032
1057
  "get_rows(x)",
1033
1058
  "get_rows_back(x)",
1059
+ "set_rows(x)",
1034
1060
  "diag(x)",
1035
1061
  "diag_mask_inf(x)",
1036
1062
  "diag_mask_zero(x)",
@@ -1042,6 +1068,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1042
1068
  "conv_transpose_1d(x)",
1043
1069
  "im2col(x)",
1044
1070
  "im2col_back(x)",
1071
+ "conv_2d(x)",
1045
1072
  "conv_2d_dw(x)",
1046
1073
  "conv_transpose_2d(x)",
1047
1074
  "pool_1d(x)",
@@ -1050,6 +1077,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1050
1077
  "upscale(x)",
1051
1078
  "pad(x)",
1052
1079
  "pad_reflect_1d(x)",
1080
+ "roll(x)",
1053
1081
  "arange(start, stop, step)",
1054
1082
  "timestep_embedding(timesteps, dim, max_period)",
1055
1083
  "argsort(x)",
@@ -1078,9 +1106,11 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1078
1106
  "cross_entropy_loss(x,y)",
1079
1107
  "cross_entropy_loss_back(x,y)",
1080
1108
  "adamw(x)",
1109
+
1110
+ "glu(x)",
1081
1111
  };
1082
1112
 
1083
- static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
1113
+ static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
1084
1114
 
1085
1115
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
1086
1116
 
@@ -1106,6 +1136,17 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
1106
1136
  static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
1107
1137
 
1108
1138
 
1139
+ static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
1140
+ "REGLU",
1141
+ "GEGLU",
1142
+ "SWIGLU",
1143
+ "GEGLU_ERF",
1144
+ "GEGLU_QUICK",
1145
+ };
1146
+
1147
+ static_assert(GGML_GLU_OP_COUNT == 5, "GGML_GLU_OP_COUNT != 5");
1148
+
1149
+
1109
1150
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
1110
1151
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
1111
1152
 
@@ -1208,11 +1249,19 @@ const char * ggml_unary_op_name(enum ggml_unary_op op) {
1208
1249
  return GGML_UNARY_OP_NAME[op];
1209
1250
  }
1210
1251
 
1252
+ const char * ggml_glu_op_name(enum ggml_glu_op op) {
1253
+ return GGML_GLU_OP_NAME[op];
1254
+ }
1255
+
1211
1256
  const char * ggml_op_desc(const struct ggml_tensor * t) {
1212
1257
  if (t->op == GGML_OP_UNARY) {
1213
1258
  enum ggml_unary_op uop = ggml_get_unary_op(t);
1214
1259
  return ggml_unary_op_name(uop);
1215
1260
  }
1261
+ if (t->op == GGML_OP_GLU) {
1262
+ enum ggml_glu_op gop = ggml_get_glu_op(t);
1263
+ return ggml_glu_op_name(gop);
1264
+ }
1216
1265
  return ggml_op_name(t->op);
1217
1266
  }
1218
1267
 
@@ -1349,6 +1398,12 @@ bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
1349
1398
  tensor->nb[2] == ggml_type_size(tensor->type);
1350
1399
  }
1351
1400
 
1401
+ bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor) {
1402
+ return
1403
+ tensor->ne[0] == ggml_blck_size(tensor->type) ||
1404
+ tensor->nb[0] == ggml_type_size(tensor->type);
1405
+ }
1406
+
1352
1407
  static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
1353
1408
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
1354
1409
 
@@ -1420,14 +1475,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
1420
1475
  // initialize time system (required on Windows)
1421
1476
  ggml_time_init();
1422
1477
 
1423
- for (int i = 0; i < (1 << 16); ++i) {
1424
- union {
1425
- uint16_t u16;
1426
- ggml_fp16_t fp16;
1427
- } u = {i};
1428
- ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
1429
- }
1430
-
1431
1478
  is_first_call = false;
1432
1479
  }
1433
1480
 
@@ -1731,6 +1778,11 @@ enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
1731
1778
  return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
1732
1779
  }
1733
1780
 
1781
+ enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) {
1782
+ GGML_ASSERT(tensor->op == GGML_OP_GLU);
1783
+ return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0);
1784
+ }
1785
+
1734
1786
  const char * ggml_get_name(const struct ggml_tensor * tensor) {
1735
1787
  return tensor->name;
1736
1788
  }
@@ -2610,6 +2662,156 @@ struct ggml_tensor * ggml_exp_inplace(
2610
2662
  return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
2611
2663
  }
2612
2664
 
2665
+ // ggml_glu
2666
+
2667
+ static struct ggml_tensor * ggml_glu_impl(
2668
+ struct ggml_context * ctx,
2669
+ struct ggml_tensor * a,
2670
+ struct ggml_tensor * b,
2671
+ enum ggml_glu_op op,
2672
+ bool swapped) {
2673
+ GGML_ASSERT(ggml_is_contiguous_1(a));
2674
+
2675
+ if (b) {
2676
+ GGML_ASSERT(ggml_is_contiguous_1(b));
2677
+ GGML_ASSERT(ggml_are_same_shape(a, b));
2678
+ GGML_ASSERT(a->type == b->type);
2679
+ }
2680
+
2681
+ int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
2682
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);
2683
+
2684
+ ggml_set_op_params_i32(result, 0, (int32_t) op);
2685
+ ggml_set_op_params_i32(result, 1, (int32_t) swapped);
2686
+
2687
+ result->op = GGML_OP_GLU;
2688
+ result->src[0] = a;
2689
+ result->src[1] = b;
2690
+
2691
+ return result;
2692
+ }
2693
+
2694
+ struct ggml_tensor * ggml_glu(
2695
+ struct ggml_context * ctx,
2696
+ struct ggml_tensor * a,
2697
+ enum ggml_glu_op op,
2698
+ bool swapped) {
2699
+ return ggml_glu_impl(ctx, a, NULL, op, swapped);
2700
+ }
2701
+
2702
+ struct ggml_tensor * ggml_glu_split(
2703
+ struct ggml_context * ctx,
2704
+ struct ggml_tensor * a,
2705
+ struct ggml_tensor * b,
2706
+ enum ggml_glu_op op) {
2707
+ return ggml_glu_impl(ctx, a, b, op, false);
2708
+ }
2709
+
2710
+ // ggml_reglu
2711
+
2712
+ struct ggml_tensor * ggml_reglu(
2713
+ struct ggml_context * ctx,
2714
+ struct ggml_tensor * a) {
2715
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, false);
2716
+ }
2717
+
2718
+ struct ggml_tensor * ggml_reglu_swapped(
2719
+ struct ggml_context * ctx,
2720
+ struct ggml_tensor * a) {
2721
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, true);
2722
+ }
2723
+
2724
+ struct ggml_tensor * ggml_reglu_split(
2725
+ struct ggml_context * ctx,
2726
+ struct ggml_tensor * a,
2727
+ struct ggml_tensor * b) {
2728
+ return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_REGLU, false);
2729
+ }
2730
+
2731
+ // ggml_geglu
2732
+
2733
+ struct ggml_tensor * ggml_geglu(
2734
+ struct ggml_context * ctx,
2735
+ struct ggml_tensor * a) {
2736
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, false);
2737
+ }
2738
+
2739
+ struct ggml_tensor * ggml_geglu_swapped(
2740
+ struct ggml_context * ctx,
2741
+ struct ggml_tensor * a) {
2742
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, true);
2743
+ }
2744
+
2745
+ struct ggml_tensor * ggml_geglu_split(
2746
+ struct ggml_context * ctx,
2747
+ struct ggml_tensor * a,
2748
+ struct ggml_tensor * b) {
2749
+ return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU, false);
2750
+ }
2751
+
2752
+ // ggml_swiglu
2753
+
2754
+ struct ggml_tensor * ggml_swiglu(
2755
+ struct ggml_context * ctx,
2756
+ struct ggml_tensor * a) {
2757
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, false);
2758
+ }
2759
+
2760
+ struct ggml_tensor * ggml_swiglu_swapped(
2761
+ struct ggml_context * ctx,
2762
+ struct ggml_tensor * a) {
2763
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, true);
2764
+ }
2765
+
2766
+ struct ggml_tensor * ggml_swiglu_split(
2767
+ struct ggml_context * ctx,
2768
+ struct ggml_tensor * a,
2769
+ struct ggml_tensor * b) {
2770
+ return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
2771
+ }
2772
+
2773
+ // ggml_geglu_erf
2774
+
2775
+ struct ggml_tensor * ggml_geglu_erf(
2776
+ struct ggml_context * ctx,
2777
+ struct ggml_tensor * a) {
2778
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, false);
2779
+ }
2780
+
2781
+ struct ggml_tensor * ggml_geglu_erf_swapped(
2782
+ struct ggml_context * ctx,
2783
+ struct ggml_tensor * a) {
2784
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, true);
2785
+ }
2786
+
2787
+ struct ggml_tensor * ggml_geglu_erf_split(
2788
+ struct ggml_context * ctx,
2789
+ struct ggml_tensor * a,
2790
+ struct ggml_tensor * b) {
2791
+ return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_ERF, false);
2792
+ }
2793
+
2794
+ // ggml_geglu_quick
2795
+
2796
+ struct ggml_tensor * ggml_geglu_quick(
2797
+ struct ggml_context * ctx,
2798
+ struct ggml_tensor * a) {
2799
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, false);
2800
+ }
2801
+
2802
+ struct ggml_tensor * ggml_geglu_quick_swapped(
2803
+ struct ggml_context * ctx,
2804
+ struct ggml_tensor * a) {
2805
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, true);
2806
+ }
2807
+
2808
+ struct ggml_tensor * ggml_geglu_quick_split(
2809
+ struct ggml_context * ctx,
2810
+ struct ggml_tensor * a,
2811
+ struct ggml_tensor * b) {
2812
+ return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false);
2813
+ }
2814
+
2613
2815
  // ggml_norm
2614
2816
 
2615
2817
  static struct ggml_tensor * ggml_norm_impl(
@@ -2867,12 +3069,14 @@ static struct ggml_tensor * ggml_scale_impl(
2867
3069
  struct ggml_context * ctx,
2868
3070
  struct ggml_tensor * a,
2869
3071
  float s,
3072
+ float b,
2870
3073
  bool inplace) {
2871
3074
  GGML_ASSERT(ggml_is_padded_1d(a));
2872
3075
 
2873
3076
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
2874
3077
 
2875
- ggml_set_op_params(result, &s, sizeof(s));
3078
+ float params[2] = { s, b };
3079
+ ggml_set_op_params(result, &params, sizeof(params));
2876
3080
 
2877
3081
  result->op = GGML_OP_SCALE;
2878
3082
  result->src[0] = a;
@@ -2884,14 +3088,30 @@ struct ggml_tensor * ggml_scale(
2884
3088
  struct ggml_context * ctx,
2885
3089
  struct ggml_tensor * a,
2886
3090
  float s) {
2887
- return ggml_scale_impl(ctx, a, s, false);
3091
+ return ggml_scale_impl(ctx, a, s, 0.0, false);
2888
3092
  }
2889
3093
 
2890
3094
  struct ggml_tensor * ggml_scale_inplace(
2891
3095
  struct ggml_context * ctx,
2892
3096
  struct ggml_tensor * a,
2893
3097
  float s) {
2894
- return ggml_scale_impl(ctx, a, s, true);
3098
+ return ggml_scale_impl(ctx, a, s, 0.0, true);
3099
+ }
3100
+
3101
+ struct ggml_tensor * ggml_scale_bias(
3102
+ struct ggml_context * ctx,
3103
+ struct ggml_tensor * a,
3104
+ float s,
3105
+ float b) {
3106
+ return ggml_scale_impl(ctx, a, s, b, false);
3107
+ }
3108
+
3109
+ struct ggml_tensor * ggml_scale_bias_inplace(
3110
+ struct ggml_context * ctx,
3111
+ struct ggml_tensor * a,
3112
+ float s,
3113
+ float b) {
3114
+ return ggml_scale_impl(ctx, a, s, b, true);
2895
3115
  }
2896
3116
 
2897
3117
  // ggml_set
@@ -3393,6 +3613,35 @@ struct ggml_tensor * ggml_get_rows_back(
3393
3613
  return result;
3394
3614
  }
3395
3615
 
3616
+ // ggml_set_rows
3617
+
3618
+ struct ggml_tensor * ggml_set_rows(
3619
+ struct ggml_context * ctx,
3620
+ struct ggml_tensor * a,
3621
+ struct ggml_tensor * b,
3622
+ struct ggml_tensor * c) {
3623
+ GGML_ASSERT(a->ne[0] == b->ne[0]);
3624
+ GGML_ASSERT(a->ne[2] == b->ne[2]);
3625
+ GGML_ASSERT(a->ne[3] == b->ne[3]);
3626
+ GGML_ASSERT(b->ne[1] == c->ne[0]);
3627
+ GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
3628
+ GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
3629
+ GGML_ASSERT(c->ne[3] == 1);
3630
+ GGML_ASSERT(b->type == GGML_TYPE_F32);
3631
+ GGML_ASSERT(c->type == GGML_TYPE_I64);
3632
+
3633
+ GGML_ASSERT(ggml_is_contiguous_rows(a));
3634
+ GGML_ASSERT(ggml_is_contiguous_rows(b));
3635
+
3636
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
3637
+
3638
+ result->op = GGML_OP_SET_ROWS;
3639
+ result->src[0] = b;
3640
+ result->src[1] = c;
3641
+
3642
+ return result;
3643
+ }
3644
+
3396
3645
  // ggml_diag
3397
3646
 
3398
3647
  struct ggml_tensor * ggml_diag(
@@ -3487,9 +3736,10 @@ static struct ggml_tensor * ggml_soft_max_impl(
3487
3736
  if (mask) {
3488
3737
  GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
3489
3738
  GGML_ASSERT(ggml_is_contiguous(mask));
3490
- GGML_ASSERT(ggml_is_matrix(mask));
3491
3739
  GGML_ASSERT(mask->ne[0] == a->ne[0]);
3492
3740
  GGML_ASSERT(mask->ne[1] >= a->ne[1]);
3741
+ GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
3742
+ GGML_ASSERT(a->ne[3]%mask->ne[3] == 0);
3493
3743
  }
3494
3744
 
3495
3745
  if (max_bias > 0.0f) {
@@ -4129,6 +4379,44 @@ struct ggml_tensor * ggml_conv_2d_dw_direct(
4129
4379
  return result;
4130
4380
  }
4131
4381
 
4382
+ // ggml_conv_2d_direct
4383
+
4384
+ struct ggml_tensor * ggml_conv_2d_direct(
4385
+ struct ggml_context * ctx,
4386
+ struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC]
4387
+ struct ggml_tensor * b, // input data [W, H, C, N]
4388
+ int s0, // stride dimension 0
4389
+ int s1, // stride dimension 1
4390
+ int p0, // padding dimension 0
4391
+ int p1, // padding dimension 1
4392
+ int d0, // dilation dimension 0
4393
+ int d1) {// dilation dimension 1
4394
+
4395
+ GGML_ASSERT(a->ne[2] == b->ne[2]);
4396
+ //GGML_ASSERT(a->type == b->type);
4397
+
4398
+ int64_t ne[4];
4399
+ ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
4400
+ ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
4401
+ ne[2] = a->ne[3];
4402
+ ne[3] = b->ne[3];
4403
+
4404
+ struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
4405
+
4406
+ ggml_set_op_params_i32(result, 0, s0);
4407
+ ggml_set_op_params_i32(result, 1, s1);
4408
+ ggml_set_op_params_i32(result, 2, p0);
4409
+ ggml_set_op_params_i32(result, 3, p1);
4410
+ ggml_set_op_params_i32(result, 4, d0);
4411
+ ggml_set_op_params_i32(result, 5, d1);
4412
+
4413
+ result->op = GGML_OP_CONV_2D;
4414
+ result->src[0] = a;
4415
+ result->src[1] = b;
4416
+
4417
+ return result;
4418
+ }
4419
+
4132
4420
  // ggml_conv_transpose_2d_p0
4133
4421
 
4134
4422
  static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
@@ -4245,24 +4533,21 @@ struct ggml_tensor * ggml_pool_2d_back(
4245
4533
  return result;
4246
4534
  }
4247
4535
 
4248
- // ggml_upscale
4536
+ // ggml_upscale / ggml_interpolate
4249
4537
 
4250
- static struct ggml_tensor * ggml_upscale_impl(
4538
+ static struct ggml_tensor * ggml_interpolate_impl(
4251
4539
  struct ggml_context * ctx,
4252
4540
  struct ggml_tensor * a,
4253
- int ne0,
4254
- int ne1,
4255
- int ne2,
4256
- int ne3,
4257
- enum ggml_scale_mode mode) {
4258
- GGML_ASSERT(a->ne[0] <= ne0);
4259
- GGML_ASSERT(a->ne[1] <= ne1);
4260
- GGML_ASSERT(a->ne[2] <= ne2);
4261
- GGML_ASSERT(a->ne[3] <= ne3);
4541
+ int64_t ne0,
4542
+ int64_t ne1,
4543
+ int64_t ne2,
4544
+ int64_t ne3,
4545
+ uint32_t mode) {
4546
+ GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT);
4262
4547
 
4263
4548
  struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
4264
4549
 
4265
- ggml_set_op_params_i32(result, 0, mode);
4550
+ ggml_set_op_params_i32(result, 0, (int32_t)mode);
4266
4551
 
4267
4552
  result->op = GGML_OP_UPSCALE;
4268
4553
  result->src[0] = a;
@@ -4275,7 +4560,8 @@ struct ggml_tensor * ggml_upscale(
4275
4560
  struct ggml_tensor * a,
4276
4561
  int scale_factor,
4277
4562
  enum ggml_scale_mode mode) {
4278
- return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
4563
+ GGML_ASSERT(scale_factor > 1);
4564
+ return ggml_interpolate_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
4279
4565
  }
4280
4566
 
4281
4567
  struct ggml_tensor * ggml_upscale_ext(
@@ -4286,7 +4572,18 @@ struct ggml_tensor * ggml_upscale_ext(
4286
4572
  int ne2,
4287
4573
  int ne3,
4288
4574
  enum ggml_scale_mode mode) {
4289
- return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
4575
+ return ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
4576
+ }
4577
+
4578
// Resizes `a` to the explicit target shape [ne0, ne1, ne2, ne3].
// The low byte of `mode` must be a valid ggml_scale_mode, which is checked in
// ggml_interpolate_impl. The remaining bits are stored unchanged in the op
// params; presumably they are mode flags, to be confirmed.
struct ggml_tensor * ggml_interpolate(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int64_t               ne0,
        int64_t               ne1,
        int64_t               ne2,
        int64_t               ne3,
        uint32_t              mode) {
    struct ggml_tensor * resized = ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
    return resized;
}
4291
4588
 
4292
4589
  // ggml_pad
@@ -4341,6 +4638,34 @@ struct ggml_tensor * ggml_pad_reflect_1d(
4341
4638
  return result;
4342
4639
  }
4343
4640
 
4641
+ // ggml_roll
4642
+
4643
+ struct ggml_tensor * ggml_roll(
4644
+ struct ggml_context * ctx,
4645
+ struct ggml_tensor * a,
4646
+ int shift0,
4647
+ int shift1,
4648
+ int shift2,
4649
+ int shift3) {
4650
+ GGML_ASSERT(a->nb[0] == ggml_type_size(a->type));
4651
+ GGML_ASSERT(abs(shift0) < a->ne[0]);
4652
+ GGML_ASSERT(abs(shift1) < a->ne[1]);
4653
+ GGML_ASSERT(abs(shift2) < a->ne[2]);
4654
+ GGML_ASSERT(abs(shift3) < a->ne[3]);
4655
+
4656
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
4657
+
4658
+ ggml_set_op_params_i32(result, 0, shift0);
4659
+ ggml_set_op_params_i32(result, 1, shift1);
4660
+ ggml_set_op_params_i32(result, 2, shift2);
4661
+ ggml_set_op_params_i32(result, 3, shift3);
4662
+
4663
+ result->op = GGML_OP_ROLL;
4664
+ result->src[0] = a;
4665
+
4666
+ return result;
4667
+ }
4668
+
4344
4669
  // ggml_arange
4345
4670
 
4346
4671
  struct ggml_tensor * ggml_arange(
@@ -4435,13 +4760,17 @@ struct ggml_tensor * ggml_flash_attn_ext(
4435
4760
  GGML_ASSERT(ggml_can_mul_mat(k, q));
4436
4761
  // TODO: check if vT can be multiplied by (k*qT)
4437
4762
 
4763
+ GGML_ASSERT(q->ne[3] == k->ne[3]);
4764
+ GGML_ASSERT(q->ne[3] == v->ne[3]);
4765
+
4438
4766
  if (mask) {
4439
4767
  GGML_ASSERT(ggml_is_contiguous(mask));
4440
- GGML_ASSERT(mask->ne[2] == 1);
4441
- GGML_ASSERT(mask->ne[3] == 1);
4442
4768
  GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
4443
4769
  "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
4444
4770
  //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
4771
+
4772
+ GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
4773
+ GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
4445
4774
  }
4446
4775
 
4447
4776
  if (max_bias > 0.0f) {
@@ -4569,7 +4898,6 @@ struct ggml_tensor * ggml_ssm_conv(
4569
4898
  const int64_t n_s = sx->ne[2];
4570
4899
 
4571
4900
  // TODO: maybe support other strides than 1?
4572
- // FIXME: this is always true?
4573
4901
  GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
4574
4902
  GGML_ASSERT(sx->ne[1] == d_inner);
4575
4903
  GGML_ASSERT(n_t >= 0);
@@ -4592,36 +4920,49 @@ struct ggml_tensor * ggml_ssm_scan(
4592
4920
  struct ggml_tensor * dt,
4593
4921
  struct ggml_tensor * A,
4594
4922
  struct ggml_tensor * B,
4595
- struct ggml_tensor * C) {
4923
+ struct ggml_tensor * C,
4924
+ struct ggml_tensor * ids) {
4596
4925
  GGML_ASSERT(ggml_is_contiguous(s));
4597
- GGML_ASSERT(ggml_is_contiguous(x));
4598
4926
  GGML_ASSERT(ggml_is_contiguous(dt));
4599
4927
  GGML_ASSERT(ggml_is_contiguous(A));
4600
- GGML_ASSERT(ggml_is_matrix(A));
4601
- GGML_ASSERT(ggml_is_3d(B));
4602
- GGML_ASSERT(ggml_is_3d(s));
4928
+ GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
4603
4929
  GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
4604
4930
  GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
4605
- GGML_ASSERT(ggml_are_same_shape(x, dt));
4931
+ GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
4932
+ GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
4933
+ GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
4606
4934
  GGML_ASSERT(ggml_are_same_shape(B, C));
4935
+ GGML_ASSERT(ids->type == GGML_TYPE_I32);
4607
4936
 
4608
4937
  {
4609
4938
  const int64_t d_state = s->ne[0];
4610
- const int64_t d_inner = s->ne[1];
4611
- const int64_t n_seq_tokens = x->ne[1];
4612
- const int64_t n_seqs = x->ne[2];
4613
-
4614
- GGML_ASSERT(s->ne[2] == n_seqs);
4615
- GGML_ASSERT(x->ne[0] == d_inner);
4616
- GGML_ASSERT(A->ne[0] == d_state);
4617
- GGML_ASSERT(A->ne[1] == d_inner);
4939
+ const int64_t head_dim = x->ne[0];
4940
+ const int64_t n_head = x->ne[1];
4941
+ const int64_t n_seq_tokens = x->ne[2];
4942
+ const int64_t n_seqs = x->ne[3];
4943
+
4944
+ GGML_ASSERT(dt->ne[0] == n_head);
4945
+ GGML_ASSERT(dt->ne[1] == n_seq_tokens);
4946
+ GGML_ASSERT(dt->ne[2] == n_seqs);
4947
+ GGML_ASSERT(ggml_is_3d(dt));
4948
+ GGML_ASSERT(s->ne[1] == head_dim);
4949
+ GGML_ASSERT(s->ne[2] == n_head);
4618
4950
  GGML_ASSERT(B->ne[0] == d_state);
4619
- GGML_ASSERT(B->ne[1] == n_seq_tokens);
4620
- GGML_ASSERT(B->ne[2] == n_seqs);
4951
+ GGML_ASSERT(B->ne[2] == n_seq_tokens);
4952
+ GGML_ASSERT(B->ne[3] == n_seqs);
4953
+ GGML_ASSERT(ids->ne[0] == n_seqs);
4954
+ GGML_ASSERT(ggml_is_vector(ids));
4955
+ GGML_ASSERT(A->ne[1] == n_head);
4956
+ GGML_ASSERT(ggml_is_matrix(A));
4957
+
4958
+ if (A->ne[0] != 1) {
4959
+ // Mamba-1 has more granular decay factors
4960
+ GGML_ASSERT(A->ne[0] == d_state);
4961
+ }
4621
4962
  }
4622
4963
 
4623
4964
  // concatenated y + ssm_states
4624
- struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
4965
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);
4625
4966
 
4626
4967
  result->op = GGML_OP_SSM_SCAN;
4627
4968
  result->src[0] = s;
@@ -4630,6 +4971,7 @@ struct ggml_tensor * ggml_ssm_scan(
4630
4971
  result->src[3] = A;
4631
4972
  result->src[4] = B;
4632
4973
  result->src[5] = C;
4974
+ result->src[6] = ids;
4633
4975
 
4634
4976
  return result;
4635
4977
  }
@@ -5453,7 +5795,7 @@ static void ggml_compute_backward(
5453
5795
  } break;
5454
5796
  case GGML_OP_MEAN: {
5455
5797
  if (src0_needs_grads) {
5456
- ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
5798
+ ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], 0.0, false));
5457
5799
  }
5458
5800
  } break;
5459
5801
  case GGML_OP_REPEAT: {
@@ -5530,7 +5872,7 @@ static void ggml_compute_backward(
5530
5872
  if (src0_needs_grads) {
5531
5873
  float s;
5532
5874
  memcpy(&s, tensor->op_params, sizeof(float));
5533
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false));
5875
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, 0.0, false));
5534
5876
  }
5535
5877
  } break;
5536
5878
  case GGML_OP_SET: {
@@ -5770,13 +6112,28 @@ static void ggml_compute_backward(
5770
6112
  }
5771
6113
  GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
5772
6114
  } break;
6115
+ case GGML_OP_GLU: {
6116
+ switch (ggml_get_glu_op(tensor)) {
6117
+ case GGML_GLU_OP_SWIGLU: {
6118
+ if (src0_needs_grads) {
6119
+ GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
6120
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
6121
+ }
6122
+ if (src1_needs_grads) {
6123
+ ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
6124
+ }
6125
+ } break;
6126
+ default: {
6127
+ GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
6128
+ } //break;
6129
+ }
6130
+ } break;
5773
6131
  case GGML_OP_NONE: {
5774
6132
  // noop
5775
6133
  } break;
5776
6134
  case GGML_OP_COUNT:
5777
6135
  default: {
5778
- fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
5779
- GGML_ABORT("fatal error");
6136
+ GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
5780
6137
  } //break;
5781
6138
  }
5782
6139
 
@@ -5785,19 +6142,32 @@ static void ggml_compute_backward(
5785
6142
  GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
5786
6143
  }
5787
6144
 
5788
- static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
6145
+ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
5789
6146
  // check if already visited
5790
- if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) {
5791
- return;
6147
+ size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
6148
+ GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL);
6149
+ if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
6150
+ // This is the first time we see this node in the current graph.
6151
+ cgraph->visited_hash_set.keys[node_hash_pos] = node;
6152
+ ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
6153
+ cgraph->use_counts[node_hash_pos] = 0;
6154
+ } else {
6155
+ // already visited
6156
+ return node_hash_pos;
5792
6157
  }
5793
6158
 
5794
6159
  for (int i = 0; i < GGML_MAX_SRC; ++i) {
5795
6160
  const int k =
5796
6161
  (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
5797
6162
  (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
5798
- /* unknown order, just fall back to using i*/ i;
5799
- if (node->src[k]) {
5800
- ggml_visit_parents(cgraph, node->src[k]);
6163
+ /* unknown order, just fall back to using i */ i;
6164
+
6165
+ struct ggml_tensor * src = node->src[k];
6166
+ if (src) {
6167
+ size_t src_hash_pos = ggml_visit_parents(cgraph, src);
6168
+
6169
+ // Update the use count for this operand.
6170
+ cgraph->use_counts[src_hash_pos]++;
5801
6171
  }
5802
6172
  }
5803
6173
 
@@ -5821,6 +6191,8 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
5821
6191
  cgraph->nodes[cgraph->n_nodes] = node;
5822
6192
  cgraph->n_nodes++;
5823
6193
  }
6194
+
6195
+ return node_hash_pos;
5824
6196
  }
5825
6197
 
5826
6198
  static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
@@ -5958,6 +6330,7 @@ static size_t ggml_graph_nbytes(size_t size, bool grads) {
5958
6330
  incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1);
5959
6331
  incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes
5960
6332
  incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
6333
+ incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // use_counts
5961
6334
  incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
5962
6335
  if (grads) {
5963
6336
  incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
@@ -5987,11 +6360,12 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
5987
6360
 
5988
6361
  void * p = cgraph + 1;
5989
6362
 
5990
- struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
5991
- struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
5992
- struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
5993
- struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
5994
- struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
6363
+ struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
6364
+ struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
6365
+ int32_t * use_counts_ptr = incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t));
6366
+ struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
6367
+ struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
6368
+ struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
5995
6369
 
5996
6370
  ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
5997
6371
 
@@ -6006,6 +6380,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
6006
6380
  /*.grads =*/ grads_ptr,
6007
6381
  /*.grad_accs =*/ grad_accs_ptr,
6008
6382
  /*.leafs =*/ leafs_ptr,
6383
+ /*.use_counts =*/ use_counts_ptr,
6009
6384
  /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr },
6010
6385
  /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
6011
6386
  };
@@ -6032,7 +6407,8 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1)
6032
6407
  /*.grads =*/ NULL, // gradients would need visited_hash_set
6033
6408
  /*.grad_accs =*/ NULL,
6034
6409
  /*.leafs =*/ NULL,
6035
- /*.visited_hash_set =*/ { 0, NULL, NULL },
6410
+ /*.use_counts =*/ cgraph0->use_counts,
6411
+ /*.visited_hash_set =*/ cgraph0->visited_hash_set,
6036
6412
  /*.order =*/ cgraph0->order,
6037
6413
  };
6038
6414
 
@@ -6059,7 +6435,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
6059
6435
  for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
6060
6436
  // copy all hashset keys (tensors) that are in use
6061
6437
  if (ggml_bitset_get(src->visited_hash_set.used, i)) {
6062
- ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
6438
+ size_t new_hash_pos = ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
6439
+ dst->use_counts[new_hash_pos] = src->use_counts[i];
6063
6440
  }
6064
6441
  }
6065
6442