@novastera-oss/llamarn 0.3.1 → 0.4.1

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the versions exactly as they appear in the public registry.
Files changed (347)
  1. package/README.md +86 -3
  2. package/RNLlamaCpp.podspec +1 -1
  3. package/android/CMakeLists.txt +11 -3
  4. package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
  5. package/android/src/main/cpp/include/llama.h +53 -114
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  13. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  20. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  21. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  22. package/cpp/LlamaCppModel.cpp +2 -10
  23. package/cpp/PureCppImpl.cpp +71 -4
  24. package/cpp/SystemUtils.cpp +3 -7
  25. package/cpp/build-info.cpp +2 -2
  26. package/cpp/llama.cpp/CMakeLists.txt +2 -0
  27. package/cpp/llama.cpp/CODEOWNERS +1 -1
  28. package/cpp/llama.cpp/Makefile +6 -1605
  29. package/cpp/llama.cpp/README.md +5 -1
  30. package/cpp/llama.cpp/common/arg.cpp +230 -51
  31. package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
  32. package/cpp/llama.cpp/common/chat.cpp +539 -8
  33. package/cpp/llama.cpp/common/chat.h +8 -1
  34. package/cpp/llama.cpp/common/common.cpp +60 -15
  35. package/cpp/llama.cpp/common/common.h +64 -15
  36. package/cpp/llama.cpp/common/speculative.cpp +135 -54
  37. package/cpp/llama.cpp/common/speculative.h +8 -1
  38. package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
  39. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
  40. package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
  41. package/cpp/llama.cpp/flake.nix +0 -5
  42. package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
  43. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
  44. package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
  45. package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
  46. package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
  47. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
  51. package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
  52. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
  54. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
  55. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
  57. package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
  73. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
  74. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
  75. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
  76. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
  77. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
  80. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
  82. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
  83. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
  84. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
  85. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  100. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
  101. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
  102. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
  103. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
  104. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
  105. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
  106. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
  107. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
  108. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
  110. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
  111. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
  112. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
  113. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
  114. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
  117. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
  118. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  119. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  120. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
  121. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
  122. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
  123. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  124. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  125. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
  126. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  127. package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  128. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
  131. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
  132. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
  134. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
  135. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
  136. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
  137. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  138. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  141. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  142. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
  143. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
  144. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  145. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
  147. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
  148. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
  149. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  150. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
  152. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  153. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  154. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  161. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  164. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  166. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
  167. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
  168. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
  169. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
  170. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
  173. package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
  174. package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
  175. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
  176. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  177. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
  178. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
  179. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
  180. package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
  181. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
  182. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
  183. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  184. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  185. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  186. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  187. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
  188. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  189. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
  190. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  191. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  192. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
  193. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
  194. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
  195. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
  196. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
  197. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
  198. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
  199. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
  200. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
  201. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  202. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
  203. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
  204. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
  205. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  206. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  207. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
  208. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  209. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
  210. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  211. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  212. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  213. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
  214. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  215. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
  216. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
  217. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
  218. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
  219. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
  220. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
  221. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
  222. package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  223. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
  224. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
  225. package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
  226. package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
  227. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
  228. package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
  229. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
  230. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
  231. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
  232. package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
  233. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
  234. package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
  235. package/cpp/llama.cpp/include/llama.h +53 -114
  236. package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
  237. package/cpp/llama.cpp/models/templates/README.md +2 -1
  238. package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
  239. package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
  240. package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
  241. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
  242. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
  243. package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
  244. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  245. package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
  246. package/cpp/llama.cpp/src/llama-adapter.h +3 -0
  247. package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
  248. package/cpp/llama.cpp/src/llama-arch.h +18 -0
  249. package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
  250. package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
  251. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  252. package/cpp/llama.cpp/src/llama-context.cpp +61 -252
  253. package/cpp/llama.cpp/src/llama-context.h +10 -15
  254. package/cpp/llama.cpp/src/llama-cparams.h +0 -1
  255. package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
  256. package/cpp/llama.cpp/src/llama-graph.h +90 -51
  257. package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
  258. package/cpp/llama.cpp/src/llama-hparams.h +21 -6
  259. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
  260. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
  261. package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
  262. package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
  263. package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
  264. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
  265. package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
  266. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
  267. package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
  268. package/cpp/llama.cpp/src/llama-memory.h +13 -10
  269. package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
  270. package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
  271. package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
  272. package/cpp/llama.cpp/src/llama-model.h +28 -4
  273. package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
  274. package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
  275. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  276. package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
  277. package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
  278. package/cpp/rn-completion.cpp +3 -27
  279. package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
  280. package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
  281. package/ios/include/chat.h +8 -1
  282. package/ios/include/common/minja/chat-template.hpp +16 -7
  283. package/ios/include/common/minja/minja.hpp +47 -12
  284. package/ios/include/common.h +64 -15
  285. package/ios/include/llama.h +53 -114
  286. package/ios/include/speculative.h +8 -1
  287. package/ios/libs/llama.xcframework/Info.plist +18 -18
  288. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  289. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
  290. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  291. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
  292. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
  293. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  294. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  295. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
  296. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  297. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  298. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  299. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  300. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  301. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  302. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  303. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
  304. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
  305. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
  306. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
  307. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
  308. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
  309. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
  310. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  311. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
  312. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
  313. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
  314. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  315. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  316. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  317. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
  318. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  319. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
  320. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
  321. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  322. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  323. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  324. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  325. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  326. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  327. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  328. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  329. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  330. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
  331. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  332. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
  333. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
  334. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  335. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  336. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
  337. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
  338. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  339. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  340. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  341. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  342. package/lib/module/NativeRNLlamaCpp.js.map +1 -1
  343. package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
  344. package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
  345. package/package.json +1 -2
  346. package/src/NativeRNLlamaCpp.ts +7 -0
  347. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56

package/cpp/llama.cpp/ggml/src/ggml-opt.cpp
@@ -64,9 +64,11 @@ struct ggml_opt_context {
     int32_t opt_i = 0;
     bool loss_per_datapoint = false;
 
-    ggml_opt_get_optimizer_params get_opt_pars = nullptr;
-    void * get_opt_pars_ud = nullptr;
-    struct ggml_tensor * adamw_params = nullptr;
+    ggml_opt_get_optimizer_params get_opt_pars = nullptr;
+    void * get_opt_pars_ud = nullptr;
+    struct ggml_tensor * opt_step_params = nullptr; // Stores output of get_opt_pars.
+
+    enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
 };
 
 struct ggml_opt_result {
@@ -229,9 +231,13 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
     result.adamw.eps = 1e-8f;
     result.adamw.wd = 0.0f;
 
+    result.sgd.alpha = 1e-3f;
+    result.sgd.wd = 0.0f;
+
     return result;
 }
 
+
 struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
     return *((struct ggml_opt_optimizer_params *) userdata);
 }
@@ -249,6 +255,7 @@ struct ggml_opt_params ggml_opt_default_params(
         /*opt_period      =*/ 1,
         /*get_opt_pars    =*/ ggml_opt_get_default_optimizer_params,
         /*get_opt_pars_ud =*/ nullptr,
+        /*optimizer       =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW,
     };
 }
 
@@ -316,9 +323,14 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
     GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc");
     GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
 
+    const enum ggml_opt_optimizer_type optimizer = opt_ctx->optimizer;
+
     const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&
         !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
 
+    const bool need_momenta = opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT &&
+        opt_ctx->optimizer == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+
     ggml_set_input(opt_ctx->inputs);
     ggml_set_output(opt_ctx->outputs);
 
@@ -340,8 +352,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
     // - pred (if using static graphs)
     // - ncorrect (if using static graphs, 2 tensors).
     constexpr size_t n_loss = 1;
-    const size_t tensors_per_param = (accumulate ? 1 : 0) +
-        (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
+    const size_t tensors_per_param = (accumulate ? 1 : 0) + (need_momenta ? 2 : 0);
     const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
     const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();
     struct ggml_init_params params = {
@@ -458,7 +469,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
         }
     }
 
-    if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
+    if (need_momenta && opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
         opt_ctx->grad_m.resize(n_nodes);
         opt_ctx->grad_v.resize(n_nodes);
         for (int i = 0; i < n_nodes; ++i) {
@@ -492,23 +503,36 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
     // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
     opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
 
-    opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7);
-    ggml_set_input(opt_ctx->adamw_params);
-    ggml_set_name(opt_ctx->adamw_params, "adamw_params");
-
+    opt_ctx->opt_step_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, need_momenta ? 7 : 2);
+    ggml_tensor * adamw_params = opt_ctx->opt_step_params;
+    ggml_set_input(adamw_params);
+    const char * optimizer_name = ggml_opt_optimizer_name(opt_ctx->optimizer);
+    ggml_format_name(adamw_params, "%s_params", optimizer_name);
     for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
         struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
         struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);
 
         if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
-            struct ggml_tensor * m = opt_ctx->grad_m[i];
-            struct ggml_tensor * v = opt_ctx->grad_v[i];
-            struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
-
-            ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str());
-            ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str());
-            ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
-
+            struct ggml_tensor * m = nullptr;
+            struct ggml_tensor * v = nullptr;
+            if (need_momenta) {
+                m = opt_ctx->grad_m[i];
+                v = opt_ctx->grad_v[i];
+                ggml_format_name(m, "AdamW m for %s", node->name);
+                ggml_format_name(v, "AdamW v for %s", node->name);
+            }
+            struct ggml_tensor * opt_step;
+            switch (optimizer) {
+                case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
+                    opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, adamw_params);
+                    break;
+                case GGML_OPT_OPTIMIZER_TYPE_SGD:
+                    opt_step = ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, adamw_params);
+                    break;
+                default:
+                    GGML_ABORT("fatal error");
+            }
+            ggml_format_name(opt_step, "%s step for %s", optimizer_name, node->name);
             ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
         }
     }
@@ -534,6 +558,7 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
     result->opt_period      = params.opt_period;
     result->get_opt_pars    = params.get_opt_pars;
    result->get_opt_pars_ud = params.get_opt_pars_ud;
+    result->optimizer       = params.optimizer;
 
     GGML_ASSERT(result->opt_period >= 1);
 
@@ -756,29 +781,43 @@ void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
 void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
     GGML_ASSERT(opt_ctx->eval_ready);
     if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
-        struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
-
-        GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
-        GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
-        GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
-        GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
-
-        // beta1, beta2 after applying warmup
-        const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
-        const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
-
-        float * adamw_par_data = ggml_get_data_f32(opt_ctx->adamw_params);
-        adamw_par_data[0] = opt_pars.adamw.alpha;
-        adamw_par_data[1] = opt_pars.adamw.beta1;
-        adamw_par_data[2] = opt_pars.adamw.beta2;
-        adamw_par_data[3] = opt_pars.adamw.eps;
-        adamw_par_data[4] = opt_pars.adamw.wd;
-        adamw_par_data[5] = beta1h;
-        adamw_par_data[6] = beta2h;
+        const ggml_opt_optimizer_params & opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
+
+        switch (opt_ctx->optimizer) {
+            case GGML_OPT_OPTIMIZER_TYPE_ADAMW: {
+                GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
+                GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
+                GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
+                GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
+
+                // beta1, beta2 after applying warmup
+                const float beta1h = 1.0f / (1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
+                const float beta2h = 1.0f / (1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
+
+                float * adamw_par_data = ggml_get_data_f32(opt_ctx->opt_step_params);
+                adamw_par_data[0] = opt_pars.adamw.alpha;
+                adamw_par_data[1] = opt_pars.adamw.beta1;
+                adamw_par_data[2] = opt_pars.adamw.beta2;
+                adamw_par_data[3] = opt_pars.adamw.eps;
+                adamw_par_data[4] = opt_pars.adamw.wd;
+                adamw_par_data[5] = beta1h;
+                adamw_par_data[6] = beta2h;
+            } break;
+            case GGML_OPT_OPTIMIZER_TYPE_SGD: {
+                GGML_ASSERT(opt_pars.sgd.alpha > 0.0f);
+                GGML_ASSERT(opt_pars.sgd.wd >= 0.0f);
+                GGML_ASSERT(opt_pars.sgd.wd <= 1.0f);
+                float * sgd = ggml_get_data_f32(opt_ctx->opt_step_params);
+                sgd[0] = opt_pars.sgd.alpha;
+                sgd[1] = opt_pars.sgd.wd;
+            } break;
+            default:
+                GGML_ABORT("fatal error");
+        }
     }
 
     ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
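
Note on the AdamW branch above: beta1h and beta2h are the reciprocals of the standard Adam bias-correction denominators, precomputed once per optimizer step so the kernel can apply the correction as a single multiply:

    \hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}

with t = opt_ctx->iter; the code stores beta1h = 1/(1 - beta1^t) and beta2h = 1/(1 - beta2^t) in slots 5 and 6 of the parameter tensor.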
@@ -963,6 +1002,7 @@ void ggml_opt_fit(
         ggml_tensor * outputs,
         ggml_opt_dataset_t dataset,
         enum ggml_opt_loss_type loss_type,
+        enum ggml_opt_optimizer_type optimizer,
         ggml_opt_get_optimizer_params get_opt_pars,
         int64_t nepoch,
         int64_t nbatch_logical,
@@ -993,6 +1033,7 @@ void ggml_opt_fit(
     params.opt_period      = opt_period;
     params.get_opt_pars    = get_opt_pars;
     params.get_opt_pars_ud = &epoch;
+    params.optimizer       = optimizer;
     ggml_opt_context_t opt_ctx = ggml_opt_init(params);
 
     // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.
@@ -1035,3 +1076,18 @@ void ggml_opt_fit(
     ggml_opt_result_free(result_train);
     ggml_opt_result_free(result_val);
 }
+
+enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t c) {
+    return c->optimizer;
+}
+
+GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type o) {
+    switch (o) {
+        case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
+            return "adamw";
+        case GGML_OPT_OPTIMIZER_TYPE_SGD:
+            return "sgd";
+        default:
+            return "undefined";
+    };
+}
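
Note: the hunks above add an optimizer selector to ggml's training API; with SGD the per-parameter AdamW momenta are skipped and the parameter tensor shrinks to two floats (alpha, wd). A minimal caller-side sketch of selecting SGD through the extended ggml_opt_fit() signature shown above, assuming the trailing val_split/silent parameters are unchanged and that backend_sched, ctx_compute, inputs, outputs, and dataset are set up as before:

    // Sketch only, not code from this package. Everything except the new
    // `optimizer` argument follows the existing ggml_opt_fit() call shape.
    ggml_opt_fit(backend_sched, ctx_compute, inputs, outputs, dataset,
                 GGML_OPT_LOSS_TYPE_CROSS_ENTROPY,
                 GGML_OPT_OPTIMIZER_TYPE_SGD,           // new: plain SGD instead of AdamW
                 ggml_opt_get_default_optimizer_params, // now also fills sgd.alpha = 1e-3f, sgd.wd = 0.0f
                 /*nepoch         =*/ 1,
                 /*nbatch_logical =*/ 32,
                 /*val_split      =*/ 0.0f,
                 /*silent         =*/ false);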

package/cpp/llama.cpp/ggml/src/ggml-quants.c
@@ -21,6 +21,17 @@
 
 #define UNUSED GGML_UNUSED
 
+static inline int best_index_int8(int n, const int8_t * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
 // reference implementation for deterministic creation of model files
 void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) {
     static const int qk = QK4_0;
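
Note: best_index_int8() is unchanged; it is a nearest-value binary search over a sorted codebook, moved to the top of the file (see the matching removal from the 4-bit non-linear quants section later in this diff). A self-contained demo with a hypothetical sorted table:

    #include <cstdint>
    #include <cstdio>

    // Nearest-index binary search, copied from the hunk above.
    static inline int best_index_int8(int n, const int8_t * val, float x) {
        if (x <= val[0]) return 0;
        if (x >= val[n-1]) return n-1;
        int ml = 0, mu = n-1;
        while (mu-ml > 1) {
            int mav = (ml+mu)/2;
            if (x < val[mav]) mu = mav; else ml = mav;
        }
        return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
    }

    int main() {
        static const int8_t codebook[8] = {-8, -4, -2, -1, 1, 2, 4, 8}; // hypothetical table
        printf("%d\n", best_index_int8(8, codebook, 5.0f));             // prints 6: 5.0 is nearer 4 than 8
        return 0;
    }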
@@ -246,6 +257,53 @@ void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_REST
     }
 }
 
+static inline int best_index_mxfp4(float x, float e) {
+    int best_index = 0;
+    float best_err = fabsf(kvalues_mxfp4[0]*e - x);
+    for (int i = 1; i < 16; i++) {
+        float err = fabsf(kvalues_mxfp4[i]*e - x);
+        if (err < best_err) {
+            best_index = i;
+            best_err = err;
+        }
+    }
+    return best_index;
+}
+
+void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
+    static const int qk = QK_MXFP4;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+            }
+        }
+
+        const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;
+
+        const float d = GGML_E8M0_TO_FP32_HALF(e);
+
+        y[i].e = e;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t x0 = best_index_mxfp4(x[i*qk + 0    + j], d);
+            const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d);
+
+            y[i].qs[j]  = x0;
+            y[i].qs[j] |= x1 << 4;
+        }
+    }
+}
+
 void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     static const int qk = QK4_0;
 
@@ -356,6 +414,26 @@ void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRI
     }
 }
 
+void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
+    static const int qk = QK_MXFP4;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_E8M0_TO_FP32_HALF(x[i].e);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int8_t x0 = kvalues_mxfp4[x[i].qs[j] & 0x0F];
+            const int8_t x1 = kvalues_mxfp4[x[i].qs[j] >> 4];
+
+            y[i*qk + j + 0   ] = x0*d;
+            y[i*qk + j + qk/2] = x1*d;
+        }
+    }
+}
+
 //
 // 2-6 bit quantization in super-blocks
 //
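
Note on the new MXFP4 path above: each block of QK_MXFP4 values shares a single E8M0 exponent byte e plus one 4-bit code per value, looked up in kvalues_mxfp4. A worked sketch of the per-block scale math, assuming GGML_E8M0_TO_FP32_HALF(e) decodes to 2^(e-127)/2 (the _HALF suffix compensating for the kvalues_mxfp4 entries being stored doubled):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Worked example of the block scale chosen in quantize_row_mxfp4_ref above.
    int main() {
        const float   amax = 6.0f;                                     // absolute max of the block
        const uint8_t e    = (uint8_t)(floorf(log2f(amax)) - 2 + 127); // 2 - 2 + 127 = 127
        const float   d    = exp2f(e - 127) / 2.0f;                    // assumed E8M0 "half" decode -> 0.5
        // The largest table entry is 12, so the top representable magnitude
        // is 12*d = 6.0 and the block max lands exactly on a code point.
        printf("e = %u, d = %g, max = %g\n", e, d, 12.0f * d);
        return 0;
    }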
@@ -488,7 +566,7 @@ static float make_q3_quants(int n, int nmax, const float * GGML_RESTRICT x, int8
         for (int i = 0; i < n; ++i) {
             L[i] += nmax;
         }
-        return sumlx / suml2;
+        return suml2 > 0.0f ? sumlx / suml2 : 0.0f;
     }
     for (int i = 0; i < n; ++i) {
         int l = nearest_int(iscale * x[i]);
@@ -823,7 +901,7 @@ static float make_qp_quants(int n, int nmax, const float * GGML_RESTRICT x, uint
     for (int i = 0; i < n; ++i) {
         max = MAX(max, x[i]);
     }
-    if (!max) { // all zero
+    if (max < GROUP_MAX_EPS) { // all zero
         for (int i = 0; i < n; ++i) { L[i] = 0; }
         return 0.f;
     }
@@ -888,7 +966,7 @@ static float make_qp_quants(int n, int nmax, const float * GGML_RESTRICT x, uint
             break;
         }
     }
-    return sumlx/suml2;
+    return suml2 > 0.0f ? sumlx / suml2 : 0.0f;
 }
 
 static void quantize_row_q2_K_impl(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int k, const float * GGML_RESTRICT quant_weights) {
@@ -2014,6 +2092,12 @@ size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
     return nrow * row_size;
 }
 
+size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    GGML_UNUSED(quant_weights);
+    quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
+    return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
+}
+
 // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
 
 void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {
@@ -4182,7 +4266,7 @@ static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_R
             sumw[j+1] = sumw[j] + weight[i];
         }
     }
-    float best_score = -FLT_MIN, scale = max;
+    float best_score = -FLT_MAX, scale = max;
     int besti1 = -1, besti2 = -1, best_shift = 0;
     for (int i1 = 0; i1 <= block_size; ++i1) {
         for (int i2 = i1; i2 <= block_size; ++i2) {
@@ -4358,7 +4442,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
         idx[2*j] = j;
     }
     qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
-    float best_score = -FLT_MIN, scale = max;
+    float best_score = -FLT_MAX, scale = max;
     int besti1 = -1, besti2 = -1, best_k = -1;
     // 0: +, +
     // 1: +, -
@@ -4551,17 +4635,6 @@ size_t quantize_iq1_m(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
 
 // ============================ 4-bit non-linear quants
 
-static inline int best_index_int8(int n, const int8_t * val, float x) {
-    if (x <= val[0]) return 0;
-    if (x >= val[n-1]) return n-1;
-    int ml = 0, mu = n-1;
-    while (mu-ml > 1) {
-        int mav = (ml+mu)/2;
-        if (x < val[mav]) mu = mav; else ml = mav;
-    }
-    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
-}
-
 static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x,
         ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
         float * scales, float * weight, uint8_t * L,
@@ -4961,6 +5034,15 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
     return true;
 }
 
+static bool validate_e_e8m0(uint8_t e, size_t i) {
+    if (e == 0xff) {
+        fprintf(stderr, "ggml_validate_row_data: found invalid e value %d at block %zu\n", e, i);
+        return false;
+    }
+
+    return true;
+}
+
 #define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
     const type * q = (const type *) (data); \
     for (size_t i = 0; i < (nb); ++i) { \
@@ -4977,6 +5059,14 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
         } \
     }
 
+#define VALIDATE_ROW_DATA_E_E8M0_IMPL(type, data, nb) \
+    const type * q = (const type *) (data); \
+    for (size_t i = 0; i < (nb); ++i) { \
+        if (!validate_e_e8m0(q[i].e, i)) { \
+            return false; \
+        } \
+    }
+
 #define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
     const type * q = (const type *) (data); \
     for (size_t i = 0; i < (nb); ++i) { \
@@ -5130,6 +5220,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
             {
                 VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
             } break;
+        case GGML_TYPE_MXFP4:
+            {
+                VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
+            } break;
         case GGML_TYPE_Q2_K:
             {
                 VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);

package/cpp/llama.cpp/ggml/src/ggml-quants.h
@@ -21,6 +21,8 @@ GGML_API void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 *
 GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
 
+GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
+
 GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
@@ -45,6 +47,8 @@ GGML_API void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GG
 GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 //GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 
+GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+
 GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 GGML_API void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -90,6 +94,8 @@ GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTR
 GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 
+GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
 GGML_API void iq2xs_init_impl(enum ggml_type type);
 GGML_API void iq2xs_free_impl(enum ggml_type type);
 GGML_API void iq3xs_init_impl(int grid_size);

package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -29,9 +29,12 @@
 #include <cstring>
 #include <fstream>
 #include <filesystem>
+#include <algorithm>
 
 namespace fs = std::filesystem;
 
+static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB
+
 #ifdef _WIN32
 typedef SOCKET sockfd_t;
 using ssize_t = __int64;
@@ -323,11 +326,14 @@ static std::shared_ptr<socket_t> create_server_socket(const char * host, int por
 static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
     size_t bytes_sent = 0;
     while (bytes_sent < size) {
-        ssize_t n = send(sockfd, (const char *)data + bytes_sent, size - bytes_sent, 0);
+        size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE);
+        ssize_t n = send(sockfd, (const char *)data + bytes_sent, size_to_send, 0);
         if (n < 0) {
+            GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n",
+                           bytes_sent, size_to_send);
             return false;
         }
-        bytes_sent += n;
+        bytes_sent += (size_t)n;
     }
     return true;
 }
@@ -335,11 +341,18 @@ static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
 static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
     size_t bytes_recv = 0;
     while (bytes_recv < size) {
-        ssize_t n = recv(sockfd, (char *)data + bytes_recv, size - bytes_recv, 0);
-        if (n <= 0) {
+        size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE);
+        ssize_t n = recv(sockfd, (char *)data + bytes_recv, size_to_recv, 0);
+        if (n < 0) {
+            GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n",
+                           bytes_recv, size_to_recv);
             return false;
         }
-        bytes_recv += n;
+        if (n == 0) {
+            GGML_LOG_ERROR("recv returned 0 (peer closed?)\n");
+            return false;
+        }
+        bytes_recv += (size_t)n;
     }
     return true;
 }
@@ -823,10 +836,10 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
     };
 
     ggml_backend_t backend = new ggml_backend {
-        /* .guid = */ ggml_backend_rpc_guid(),
-        /* .interface = */ ggml_backend_rpc_interface,
-        /* .device = */ ggml_backend_rpc_add_device(endpoint),
-        /* .context = */ ctx
+        /* .guid    = */ ggml_backend_rpc_guid(),
+        /* .iface   = */ ggml_backend_rpc_interface,
+        /* .device  = */ ggml_backend_rpc_add_device(endpoint),
+        /* .context = */ ctx
     };
     return backend;
 }
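
Note: the RPC hunks above cap each send()/recv() syscall at MAX_CHUNK_SIZE (1 GiB), so very large tensor transfers are split across multiple calls, and an orderly peer shutdown (recv() returning 0) is now reported separately from a socket error (n < 0). A self-contained POSIX-only sketch of the same loop pattern; the real code also handles the Windows SOCKET/__int64 variants and logs through GGML_LOG_ERROR:

    #include <algorithm>
    #include <cstddef>
    #include <sys/socket.h>
    #include <sys/types.h>

    // Chunked full-buffer send, mirroring the pattern adopted in send_data() above.
    static bool send_all(int fd, const void * data, size_t size) {
        constexpr size_t max_chunk = 1024ull * 1024ull * 1024ull; // 1 GiB per syscall
        size_t sent = 0;
        while (sent < size) {
            const size_t  chunk = std::min(size - sent, max_chunk);
            const ssize_t n     = send(fd, (const char *) data + sent, chunk, 0);
            if (n < 0) {
                return false;   // hard socket error
            }
            sent += (size_t) n; // send() may transfer fewer bytes than requested
        }
        return true;
    }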

package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -28,6 +28,7 @@
 #include "mmvq.hpp"
 #include "norm.hpp"
 #include "outprod.hpp"
+#include "quantize.hpp"
 #include "quants.hpp"
 #include "rope.hpp"
 #include "set_rows.hpp"