@novastera-oss/llamarn 0.3.1 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (347)
  1. package/README.md +86 -3
  2. package/RNLlamaCpp.podspec +1 -1
  3. package/android/CMakeLists.txt +11 -3
  4. package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
  5. package/android/src/main/cpp/include/llama.h +53 -114
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  13. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  20. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  21. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  22. package/cpp/LlamaCppModel.cpp +2 -10
  23. package/cpp/PureCppImpl.cpp +71 -4
  24. package/cpp/SystemUtils.cpp +3 -7
  25. package/cpp/build-info.cpp +2 -2
  26. package/cpp/llama.cpp/CMakeLists.txt +2 -0
  27. package/cpp/llama.cpp/CODEOWNERS +1 -1
  28. package/cpp/llama.cpp/Makefile +6 -1605
  29. package/cpp/llama.cpp/README.md +5 -1
  30. package/cpp/llama.cpp/common/arg.cpp +230 -51
  31. package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
  32. package/cpp/llama.cpp/common/chat.cpp +539 -8
  33. package/cpp/llama.cpp/common/chat.h +8 -1
  34. package/cpp/llama.cpp/common/common.cpp +60 -15
  35. package/cpp/llama.cpp/common/common.h +64 -15
  36. package/cpp/llama.cpp/common/speculative.cpp +135 -54
  37. package/cpp/llama.cpp/common/speculative.h +8 -1
  38. package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
  39. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
  40. package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
  41. package/cpp/llama.cpp/flake.nix +0 -5
  42. package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
  43. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
  44. package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
  45. package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
  46. package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
  47. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
  51. package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
  52. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
  54. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
  55. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
  57. package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
  73. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
  74. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
  75. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
  76. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
  77. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
  80. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
  82. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
  83. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
  84. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
  85. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  100. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
  101. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
  102. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
  103. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
  104. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
  105. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
  106. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
  107. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
  108. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
  110. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
  111. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
  112. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
  113. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
  114. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
  117. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
  118. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  119. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  120. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
  121. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
  122. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
  123. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  124. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  125. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
  126. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  127. package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  128. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
  131. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
  132. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
  134. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
  135. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
  136. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
  137. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  138. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  141. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  142. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
  143. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
  144. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  145. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
  147. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
  148. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
  149. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  150. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
  152. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  153. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  154. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  161. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  164. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  166. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
  167. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
  168. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
  169. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
  170. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
  173. package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
  174. package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
  175. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
  176. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  177. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
  178. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
  179. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
  180. package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
  181. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
  182. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
  183. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  184. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  185. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  186. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  187. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
  188. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  189. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
  190. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  191. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  192. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
  193. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
  194. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
  195. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
  196. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
  197. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
  198. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
  199. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
  200. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
  201. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  202. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
  203. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
  204. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
  205. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  206. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  207. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
  208. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  209. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
  210. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  211. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  212. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  213. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
  214. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  215. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
  216. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
  217. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
  218. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
  219. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
  220. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
  221. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
  222. package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  223. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
  224. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
  225. package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
  226. package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
  227. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
  228. package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
  229. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
  230. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
  231. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
  232. package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
  233. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
  234. package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
  235. package/cpp/llama.cpp/include/llama.h +53 -114
  236. package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
  237. package/cpp/llama.cpp/models/templates/README.md +2 -1
  238. package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
  239. package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
  240. package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
  241. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
  242. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
  243. package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
  244. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  245. package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
  246. package/cpp/llama.cpp/src/llama-adapter.h +3 -0
  247. package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
  248. package/cpp/llama.cpp/src/llama-arch.h +18 -0
  249. package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
  250. package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
  251. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  252. package/cpp/llama.cpp/src/llama-context.cpp +61 -252
  253. package/cpp/llama.cpp/src/llama-context.h +10 -15
  254. package/cpp/llama.cpp/src/llama-cparams.h +0 -1
  255. package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
  256. package/cpp/llama.cpp/src/llama-graph.h +90 -51
  257. package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
  258. package/cpp/llama.cpp/src/llama-hparams.h +21 -6
  259. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
  260. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
  261. package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
  262. package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
  263. package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
  264. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
  265. package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
  266. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
  267. package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
  268. package/cpp/llama.cpp/src/llama-memory.h +13 -10
  269. package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
  270. package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
  271. package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
  272. package/cpp/llama.cpp/src/llama-model.h +28 -4
  273. package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
  274. package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
  275. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  276. package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
  277. package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
  278. package/cpp/rn-completion.cpp +3 -27
  279. package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
  280. package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
  281. package/ios/include/chat.h +8 -1
  282. package/ios/include/common/minja/chat-template.hpp +16 -7
  283. package/ios/include/common/minja/minja.hpp +47 -12
  284. package/ios/include/common.h +64 -15
  285. package/ios/include/llama.h +53 -114
  286. package/ios/include/speculative.h +8 -1
  287. package/ios/libs/llama.xcframework/Info.plist +18 -18
  288. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  289. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
  290. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  291. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
  292. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
  293. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  294. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  295. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
  296. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  297. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  298. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  299. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  300. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  301. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  302. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  303. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
  304. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
  305. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
  306. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
  307. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
  308. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
  309. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
  310. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  311. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
  312. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
  313. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
  314. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  315. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  316. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  317. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
  318. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  319. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
  320. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
  321. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  322. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  323. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  324. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  325. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  326. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  327. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  328. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  329. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  330. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
  331. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  332. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
  333. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
  334. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  335. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  336. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
  337. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
  338. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  339. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  340. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  341. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  342. package/lib/module/NativeRNLlamaCpp.js.map +1 -1
  343. package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
  344. package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
  345. package/package.json +1 -2
  346. package/src/NativeRNLlamaCpp.ts +7 -0
  347. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56
@@ -8,6 +8,7 @@
 #include "vec.h"
 
 #include <float.h>
+#include <algorithm>
 
 // ggml_compute_forward_dup
 
@@ -1283,6 +1284,7 @@ void ggml_compute_forward_add(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -1309,6 +1311,77 @@ void ggml_compute_forward_add(
     }
 }
 
+// ggml_compute_forward_add_id
+
+static void ggml_compute_forward_add_id_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        // src1 indices
+        const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
+
+        GGML_ASSERT(i11 >= 0 && i11 < ne11);
+
+        ggml_vec_add_f32(ne0,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+                (float *) ((char *) src1->data + i11*nb11));
+    }
+}
+
+void ggml_compute_forward_add_id(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_add_id_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
+            }
+    }
+}
+
 // ggml_compute_forward_add1
 
 static void ggml_compute_forward_add1_f32(
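
Note on the new ADD_ID path above: for each row i1 of src0, the op adds the row of src1 whose index is read from the int32 tensor src2, which matches MoE-style per-expert bias addition. A minimal sketch of the same semantics on flat arrays (names and layout here are illustrative, not the ggml API):

    #include <cstdint>
    #include <vector>

    // dst[r] = src0[r] + src1[ids[r]] for every row r (reference semantics only)
    static void add_id_reference(std::vector<float> & dst,
                                 const std::vector<float> & src0,  // [n_rows * ne0]
                                 const std::vector<float> & src1,  // [n_choices * ne0]
                                 const std::vector<int32_t> & ids, // [n_rows]
                                 int64_t n_rows, int64_t ne0) {
        dst.resize(n_rows * ne0);
        for (int64_t r = 0; r < n_rows; ++r) {
            const float * bias = &src1[ids[r] * ne0]; // row of src1 selected per output row
            for (int64_t c = 0; c < ne0; ++c) {
                dst[r * ne0 + c] = src0[r * ne0 + c] + bias[c];
            }
        }
    }
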
@@ -1660,6 +1733,7 @@ void ggml_compute_forward_add1(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -1787,6 +1861,7 @@ void ggml_compute_forward_acc(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -3614,6 +3689,93 @@ static void ggml_compute_forward_swiglu(
     }
 }
 
+// ggml_compute_forward_swiglu_oai
+
+static void ggml_compute_forward_swiglu_oai_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+    const float alpha = ggml_get_op_params_f32(dst, 2);
+    const float limit = ggml_get_op_params_f32(dst, 3);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+        float * dst_p  = (float *) ((char *) dst->data + i1*(dst->nb[1]));
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        for (int k = 0; k < nc; k++) {
+            const float x = std::min(src0_p[k], limit);
+            const float y = std::clamp(src1_p[k], -limit, limit);
+            const float out_glu = x / (1.f + expf(alpha * (-x)));
+            dst_p[k] = out_glu * (y + 1.f);
+        }
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = dst_p[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_swiglu_oai(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_swiglu_oai_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_geglu_erf
 
 static void ggml_compute_forward_geglu_erf_f32(
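
The swiglu_oai kernel added above reduces, per element, to a clamped SwiGLU with a +1 offset on the linear half. A scalar reference of that formula (a sketch, not the ggml entry point):

    #include <algorithm>
    #include <cmath>

    // out = min(x, limit) * sigmoid(alpha * min(x, limit)) * (clamp(y, -limit, limit) + 1)
    static float swiglu_oai_ref(float x, float y, float alpha, float limit) {
        x = std::min(x, limit);           // gate input is clamped from above only
        y = std::clamp(y, -limit, limit); // linear input is clamped on both sides
        const float glu = x / (1.0f + std::exp(-alpha * x)); // x * sigmoid(alpha * x)
        return glu * (y + 1.0f);
    }
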
@@ -4599,6 +4761,7 @@ void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -4873,6 +5036,7 @@ void ggml_compute_forward_set(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -5134,6 +5298,7 @@ void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -5523,6 +5688,7 @@ static void ggml_compute_forward_soft_max_f32(
 
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
 
     assert(ggml_is_contiguous(dst));
     assert(ggml_are_same_shape(src0, dst));
@@ -5557,6 +5723,9 @@ static void ggml_compute_forward_soft_max_f32(
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
+    // sinks
+    const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
+
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
@@ -5599,9 +5768,18 @@ static void ggml_compute_forward_soft_max_f32(
                 float max = -INFINITY;
                 ggml_vec_max_f32(ne00, &max, wp);
 
+                // if we have sinks, make a correction as if they were included in the softmax
+                if (sk) {
+                    max = MAX(max, sk[i02]);
+                }
+
                 ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
                 assert(sum > 0.0);
 
+                if (sk) {
+                    sum += (ggml_float) expf(sk[i02] - max);
+                }
+
                 sum = 1.0/sum;
                 ggml_vec_scale_f32(ne00, dp, sum);
 
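
The sink handling above treats src2 as one extra virtual logit per head: it joins the row max and the normalizer, but no probability is emitted for it, so each output row now sums to slightly less than one. The same arithmetic in isolation (a sketch with illustrative names):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    static void softmax_with_sink(std::vector<float> & x, float sink) {
        float mx = sink;                          // sink participates in the max ...
        for (float v : x) mx = std::max(mx, v);

        double sum = std::exp(double(sink) - mx); // ... and in the denominator
        for (float & v : x) {
            v = std::exp(v - mx);
            sum += v;
        }
        const float inv = 1.0f / (float) sum;
        for (float & v : x) v *= inv;             // row total is 1 minus the sink's weight
    }
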
@@ -5836,6 +6014,7 @@ void ggml_compute_forward_clamp(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -7028,6 +7207,148 @@ void ggml_compute_forward_conv_2d(
     ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
 }
 
+// ggml_compute_forward_conv_3d
+
+static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
+                                              const ggml_tensor *         kernel,
+                                              const ggml_tensor *         src,
+                                              ggml_tensor *               dst,
+                                              ggml_type                   kernel_type) {
+
+    GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == kernel_type);
+
+    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
+
+    const int32_t s0 = dst->op_params[0];
+    const int32_t s1 = dst->op_params[1];
+    const int32_t s2 = dst->op_params[2];
+    const int32_t p0 = dst->op_params[3];
+    const int32_t p1 = dst->op_params[4];
+    const int32_t p2 = dst->op_params[5];
+    const int32_t d0 = dst->op_params[6];
+    const int32_t d1 = dst->op_params[7];
+    const int32_t d2 = dst->op_params[8];
+    const int32_t c  = dst->op_params[9];
+    const int32_t n  = dst->op_params[10];
+    const int32_t oc = dst->op_params[11];
+
+    const int64_t src_w = src->ne[0];
+    const int64_t src_h = src->ne[1];
+    const int64_t src_d = src->ne[2];
+    const int64_t knl_w = kernel->ne[0];
+    const int64_t knl_h = kernel->ne[1];
+    const int64_t knl_d = kernel->ne[2];
+    const int64_t dst_w = dst->ne[0];
+    const int64_t dst_h = dst->ne[1];
+    const int64_t dst_d = dst->ne[2];
+
+    const float * src_data = (float *) src->data;
+    void  * knl_data = kernel->data;
+    float * dst_data = (float *) dst->data;
+
+    const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
+    const int64_t knl_n_total       = knl_n_per_channel * c;
+    const int64_t patch_total       = n * dst_w * dst_h * dst_d;
+
+    const int64_t space_per_patch   = knl_n_total * traits->type_size + oc * sizeof(float);
+    const int64_t batch_size        = params->wsize / space_per_patch;
+    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
+    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
+
+    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
+
+    void * tmp = params->wdata;
+
+    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
+        const int64_t patch_start_batch = batch_i * patches_per_batch;
+        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch, patch_total);
+        const int64_t patch_n_in_batch  = patch_end_batch - patch_start_batch;
+
+        const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+        const int64_t patch_start      = patch_start_batch + params->ith * patch_per_thread;
+        const int64_t patch_end        = std::min(patch_start + patch_per_thread, patch_end_batch);
+
+        for (int64_t p = patch_start; p < patch_end; ++p) {
+            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
+            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
+            const int64_t dst_y      = p_in_depth / dst_w;
+            const int64_t dst_x      = p_in_depth % dst_w;
+
+            char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
+
+            for (int64_t ic = 0; ic < c; ++ic) {
+                for (int64_t kz = 0; kz < knl_d; ++kz) {
+                    for (int64_t ky = 0; ky < knl_h; ++ky) {
+                        for (int64_t kx = 0; kx < knl_w; ++kx) {
+                            const int64_t sz = dst_z * s2 + kz * d2 - p2;
+                            const int64_t sy = dst_y * s1 + ky * d1 - p1;
+                            const int64_t sx = dst_x * s0 + kx * d0 - p0;
+
+                            int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
+
+                            float src_val;
+                            if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
+                                src_val = 0.0f;
+                            } else {
+                                const int64_t cn_idx = batch_idx * c + ic;
+                                const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
+                                src_val = *src_ptr;
+                            }
+
+                            char * element_ptr = dst_row + dst_idx * traits->type_size;
+                            if (kernel_type == GGML_TYPE_F32) {
+                                *(float *)element_ptr = src_val;
+                            } else if (kernel_type == GGML_TYPE_F16) {
+                                *(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        ggml_barrier(params->threadpool);
+
+        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
+        ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
+
+        ggml_barrier(params->threadpool);
+
+        const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+        const int64_t permute_start      = params->ith * permute_per_thread;
+        const int64_t permute_end        = std::min(permute_start + permute_per_thread, patch_n_in_batch);
+
+        for (int64_t i = permute_start; i < permute_end; ++i) {
+            const int64_t p          = patch_start_batch + i;
+            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
+            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
+            const int64_t dst_y      = p_in_depth / dst_w;
+            const int64_t dst_x      = p_in_depth % dst_w;
+
+            for (int64_t ioc = 0; ioc < oc; ++ioc) {
+                const float   value   = gemm_output[i * oc + ioc];
+                const int64_t ocn_idx = batch_idx * oc + ioc;
+                float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
+                *dst_ptr = value;
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_conv_3d(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
+}
+
 // ggml_compute_forward_conv_transpose_2d
 
 void ggml_compute_forward_conv_transpose_2d(
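
The conv_3d implementation above is im2col plus GEMM, processed in batches sized to the work buffer: each output voxel becomes one patch row of length c*kw*kh*kd, and a single matrix multiply maps all patch rows in a batch to oc output channels. A sketch of the shapes, assuming the standard convolution output-size formula (the struct and names are illustrative):

    #include <cstdint>

    // standard output size along one axis: (in + 2p - d*(k - 1) - 1)/s + 1
    static int64_t conv_out_size(int64_t in, int64_t k, int64_t s, int64_t p, int64_t d) {
        return (in + 2*p - d*(k - 1) - 1)/s + 1;
    }

    struct Conv3dShape {
        int64_t patch_len; // im2col columns: c * kw * kh * kd (knl_n_total above)
        int64_t n_patches; // im2col rows:    n * ow * oh * od (patch_total above)
    };

    static Conv3dShape conv3d_shape(int64_t w, int64_t h, int64_t d, int64_t c, int64_t n,
                                    int64_t kw, int64_t kh, int64_t kd,
                                    int64_t s0, int64_t s1, int64_t s2,
                                    int64_t p0, int64_t p1, int64_t p2,
                                    int64_t d0, int64_t d1, int64_t d2) {
        const int64_t ow = conv_out_size(w, kw, s0, p0, d0);
        const int64_t oh = conv_out_size(h, kh, s1, p1, d1);
        const int64_t od = conv_out_size(d, kd, s2, p2, d2);
        return { c*kw*kh*kd, n*ow*oh*od };
    }
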
@@ -7989,12 +8310,14 @@ void ggml_compute_forward_argsort(
 
 static void ggml_compute_forward_flash_attn_ext_f16(
         const ggml_compute_params * params,
-        const ggml_tensor * q,
-        const ggml_tensor * k,
-        const ggml_tensor * v,
-        const ggml_tensor * mask,
         ggml_tensor * dst) {
 
+    const ggml_tensor * q     = dst->src[0];
+    const ggml_tensor * k     = dst->src[1];
+    const ggml_tensor * v     = dst->src[2];
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
+
     GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
     GGML_TENSOR_LOCALS(size_t,  nbq, q, nb)
     GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
@@ -8189,6 +8512,23 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             }
         }
 
+        // sinks
+        if (sinks) {
+            const float s = ((float *)((char *) sinks->data))[h];
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                ms = expf(M - s);
+                ggml_vec_scale_f32(DV, VKQ32, ms);
+            } else {
+                vs = expf(s - M);
+            }
+
+            S = S*ms + vs;
+        }
+
         // V /= S
         const float S_inv = 1.0f/S;
         ggml_vec_scale_f32(DV, VKQ32, S_inv);
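
The sink block above is the usual online-softmax rescaling applied once, after all KV blocks: if the sink logit s exceeds the running max M, the accumulated value vector and the denominator S are rescaled by exp(M - s); otherwise the sink only contributes exp(s - M) to S. The same update in isolation (a sketch with illustrative names; the diff itself can skip updating M because the sink is folded in last):

    #include <cmath>

    // M: running max logit, S: running softmax denominator, acc: running sum of
    // probability-weighted V rows of length n
    static void fold_in_sink(float s, float & M, float & S, float * acc, int n) {
        float ms = 1.0f; // rescale factor for everything accumulated so far
        float vs = 1.0f; // softmax weight of the sink itself
        if (s > M) {
            ms = std::exp(M - s);
            for (int i = 0; i < n; ++i) acc[i] *= ms;
            M = s;
        } else {
            vs = std::exp(s - M);
        }
        S = S*ms + vs; // the sink adds probability mass but no value vector
    }
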
@@ -8208,17 +8548,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
 void ggml_compute_forward_flash_attn_ext(
         const ggml_compute_params * params,
-        const ggml_tensor * q,
-        const ggml_tensor * k,
-        const ggml_tensor * v,
-        const ggml_tensor * mask,
         ggml_tensor * dst) {
     switch (dst->op_params[3]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {
                 // uses F32 accumulators
-                ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+                ggml_compute_forward_flash_attn_ext_f16(params, dst);
             } break;
         default:
             {
@@ -8667,8 +9003,7 @@ static void ggml_compute_forward_ssm_scan_f32(
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
     GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
-    // allows optimizing the modulo since n_group should be a power of 2
-    GGML_ASSERT((ng & -ng) == ng);
+    GGML_ASSERT(nh % ng == 0);
 
     // heads per thread
     const int dh = (nh + nth - 1)/nth;
@@ -8699,6 +9034,7 @@ static void ggml_compute_forward_ssm_scan_f32(
             // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
             const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
             const float dA = expf(dt_soft_plus * A[h]);
+            const int g = h / (nh / ng); // repeat_interleave
 
             // dim
             for (int i1 = 0; i1 < nr; ++i1) {
@@ -8721,8 +9057,8 @@ static void ggml_compute_forward_ssm_scan_f32(
                     // TODO: maybe unroll more?
                     for (int j = 0; j < 1; j++) {
                         GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
-                        GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
-                        GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+                        GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
+                        GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
 
                         t0 = GGML_F32_VEC_MUL(t0, adA);
                         t1 = GGML_F32_VEC_MUL(t1, axdt);
@@ -8736,6 +9072,9 @@ static void ggml_compute_forward_ssm_scan_f32(
                     }
 
                     sumf = GGML_F32xt_REDUCE_ONE(sum);
+#elif defined(__riscv_v_intrinsic)
+                    // todo: RVV implementation
+                    const int np = 0;
 #else
                     const int np = (nc & ~(GGML_F32_STEP - 1));
 
@@ -8751,8 +9090,8 @@ static void ggml_compute_forward_ssm_scan_f32(
                     for (int i = 0; i < np; i += GGML_F32_STEP) {
                         for (int j = 0; j < GGML_F32_ARR; j++) {
                             ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
-                            ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
-                            az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+                            ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
+                            az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);
 
                             ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
                             ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
@@ -8774,7 +9113,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                     // d_state
                     for (int i0 = np; i0 < nc; ++i0) {
                         const int i = i0 + ii*nc;
-                        const int ig = i0 + (h & (ng - 1))*nc;
+                        const int ig = i0 + g*nc;
                         // state = prev_state * dA + dB * x
                         const float state = (s0[i] * dA) + (B[ig] * x_dt);
                         // y = rowwise_dotprod(state, C)
@@ -8791,6 +9130,7 @@ static void ggml_compute_forward_ssm_scan_f32(
         for (int h = ih0; h < ih1; ++h) {
            // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
            const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+           const int g = h / (nh / ng); // repeat_interleave
 
            // dim
            for (int i1 = 0; i1 < nr; ++i1) {
@@ -8805,8 +9145,8 @@ static void ggml_compute_forward_ssm_scan_f32(
                 // TODO: what happens when (d_state % svcntw()) != 0?
                 for (int64_t k = 0; k < nc; k += svcntw()) {
                     svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
-                    svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
-                    svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+                    svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
+                    svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
                     svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
 
                     svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
@@ -8826,7 +9166,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                 // d_state
                 for (int i0 = 0; i0 < nc; ++i0) {
                     const int i = i0 + ii*nc;
-                    const int ig = i0 + (h & (ng - 1))*nc;
+                    const int ig = i0 + g*nc;
                     // state = prev_state * dA + dB * x
                     const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
                     // y = rowwise_dotprod(state, C)
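
All of the (h & (ng - 1)) to g*nc rewrites in this function follow from the weakened assertion at the top of this hunk series: heads now map to B/C groups in contiguous blocks (repeat_interleave) rather than by bitmask, so n_group only has to divide n_head instead of being a power of two. The mapping in one line (illustrative, not ggml code):

    #include <cassert>

    static int head_to_group(int h, int nh, int ng) {
        assert(nh % ng == 0);  // the new, weaker precondition
        return h / (nh / ng);  // heads [0, nh/ng) -> group 0, the next block -> group 1, ...
    }
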
@@ -9080,6 +9420,10 @@ void ggml_compute_forward_glu(
             {
                 ggml_compute_forward_swiglu(params, dst);
             } break;
+        case GGML_GLU_OP_SWIGLU_OAI:
+            {
+                ggml_compute_forward_swiglu_oai(params, dst);
+            } break;
         case GGML_GLU_OP_GEGLU_ERF:
             {
                 ggml_compute_forward_geglu_erf(params, dst);
@@ -9683,8 +10027,8 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
     int64_t h_stride_2d = head_size * head_size;
 
     #if defined(GGML_SIMD)
-    #if defined(__ARM_FEATURE_SVE)
-        // scalar Route to scalar implementation //TODO: Write SVE code
+    #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
+        // scalar Route to scalar implementation //TODO: Write SVE code and RVV code
         for (int64_t t = 0; t < T; t++) {
             int64_t t_offset = t * t_stride;
             int64_t state_offset = head_size * C * (t / (T / n_seqs));
@@ -10132,6 +10476,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+
     const float alpha = adamw_params_ptr[0];
     const float beta1 = adamw_params_ptr[1];
     const float beta2 = adamw_params_ptr[2];
@@ -10139,7 +10484,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const float wd     = adamw_params_ptr[4];
     const float beta1h = adamw_params_ptr[5];
     const float beta2h = adamw_params_ptr[6];
-
+    const float keep   = 1.f - alpha * wd;
     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -10162,7 +10507,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
             // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+            w[i00] = w[i00] * keep - alpha * mh / vh;
         }
     }
 }
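
The keep factor introduced above is purely a hoist: the decoupled-weight-decay update is unchanged, but 1 - alpha*wd is now computed once per tensor instead of once per element. Per-element equivalent (a sketch; mh and vh are the bias-corrected momenta as in the surrounding code):

    // w <- w*(1 - alpha*wd) - alpha*mh/vh, with the decay factor precomputed
    static inline float adamw_apply(float w, float mh, float vh, float alpha, float wd) {
        const float keep = 1.0f - alpha*wd;
        return w*keep - alpha*mh/vh;
    }
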
@@ -10184,3 +10529,63 @@ void ggml_compute_forward_opt_step_adamw(
             }
     }
 }
+
+static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0       = dst->src[0];
+    const ggml_tensor * src0_grad  = dst->src[1];
+    const ggml_tensor * sgd_params = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(sgd_params) == 2);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using adamw param subset we care about - alpha, wd - could have a separate struct
+    const float * sgd_params_ptr = ggml_get_data_f32(sgd_params);
+    const float   alpha          = sgd_params_ptr[0];
+    const float   keep           = 1.f - alpha * sgd_params_ptr[1];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir / (ne02 * ne01);
+        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+        float       * w = (float       *) ((char       *) src0->data      + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00] * keep - alpha * g[i00];
+        }
+    }
+}
+
+void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}
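
The new SGD step mirrors the AdamW weight-decay handling: with keep = 1 - alpha*wd precomputed, each element performs one fused scale-and-subtract. A minimal sketch over flat arrays (illustrative names, not the ggml entry point):

    static void sgd_step(float * w, const float * g, int n, float alpha, float wd) {
        const float keep = 1.0f - alpha*wd; // decoupled weight decay, as in the AdamW path
        for (int i = 0; i < n; ++i) {
            w[i] = w[i]*keep - alpha*g[i];
        }
    }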