@novastera-oss/llamarn 0.2.1 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (266) hide show
  1. package/README.md +80 -14
  2. package/RNLlamaCpp.podspec +10 -3
  3. package/android/CMakeLists.txt +8 -0
  4. package/android/src/main/cpp/include/llama.h +62 -125
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  13. package/cpp/build-info.cpp +2 -2
  14. package/cpp/llama.cpp/README.md +11 -3
  15. package/cpp/llama.cpp/build-xcframework.sh +1 -0
  16. package/cpp/llama.cpp/common/CMakeLists.txt +8 -2
  17. package/cpp/llama.cpp/common/arg.cpp +153 -113
  18. package/cpp/llama.cpp/common/chat-parser.cpp +379 -0
  19. package/cpp/llama.cpp/common/chat-parser.h +117 -0
  20. package/cpp/llama.cpp/common/chat.cpp +847 -699
  21. package/cpp/llama.cpp/common/chat.h +73 -6
  22. package/cpp/llama.cpp/common/common.cpp +50 -82
  23. package/cpp/llama.cpp/common/common.h +21 -17
  24. package/cpp/llama.cpp/common/json-partial.cpp +255 -0
  25. package/cpp/llama.cpp/common/json-partial.h +37 -0
  26. package/cpp/llama.cpp/common/minja/chat-template.hpp +9 -5
  27. package/cpp/llama.cpp/common/minja/minja.hpp +69 -36
  28. package/cpp/llama.cpp/common/regex-partial.cpp +204 -0
  29. package/cpp/llama.cpp/common/regex-partial.h +56 -0
  30. package/cpp/llama.cpp/common/sampling.cpp +7 -8
  31. package/cpp/llama.cpp/convert_hf_to_gguf.py +453 -118
  32. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +120 -68
  33. package/cpp/llama.cpp/ggml/CMakeLists.txt +2 -1
  34. package/cpp/llama.cpp/ggml/cmake/common.cmake +25 -0
  35. package/cpp/llama.cpp/ggml/include/ggml-opt.h +49 -28
  36. package/cpp/llama.cpp/ggml/include/ggml.h +26 -7
  37. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +16 -10
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +4 -1
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +2 -0
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +604 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +42 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +54 -2
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +50 -51
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -2
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -9
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +779 -19
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +22 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +322 -100
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +117 -1
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +220 -49
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/acc.cu +40 -26
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +11 -1
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +15 -7
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +266 -64
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +49 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +48 -4
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +2 -1
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +5 -1
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +2 -0
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +7 -6
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +1 -1
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +10 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-impl.h +1 -1
  71. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +99 -17
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +200 -2
  74. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +8 -2
  75. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +6 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +972 -178
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +373 -190
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -10
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +101 -5
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +31 -33
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +1 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +29 -2
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +4 -5
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +9 -1
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +84 -72
  96. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  98. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -3
  99. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +324 -129
  100. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +1 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +31 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +95 -68
  103. package/cpp/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +1 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +22 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -4
  107. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +2 -3
  108. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +69 -43
  109. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +2 -14
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -91
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -181
  112. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +17 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  114. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +6 -152
  115. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  117. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +2 -118
  118. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +1 -1
  119. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +12 -1
  120. package/cpp/llama.cpp/ggml/src/ggml.c +107 -36
  121. package/cpp/llama.cpp/ggml/src/gguf.cpp +33 -33
  122. package/cpp/llama.cpp/gguf-py/gguf/constants.py +100 -15
  123. package/cpp/llama.cpp/gguf-py/gguf/gguf_reader.py +1 -1
  124. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +44 -12
  125. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py +21 -10
  126. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +5 -2
  127. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +128 -31
  128. package/cpp/llama.cpp/gguf-py/gguf/utility.py +1 -1
  129. package/cpp/llama.cpp/gguf-py/pyproject.toml +1 -1
  130. package/cpp/llama.cpp/include/llama.h +62 -125
  131. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +1 -1
  132. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +1 -1
  133. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +1 -1
  134. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +1 -1
  135. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +1 -1
  136. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +1 -1
  137. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +1 -1
  138. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +1 -1
  139. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +1 -1
  140. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +1 -1
  141. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +1 -1
  142. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +1 -1
  143. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +1 -1
  144. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +1 -1
  145. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +1 -1
  146. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +1 -1
  147. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +1 -1
  148. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +1 -1
  149. package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
  150. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +1 -1
  151. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +1 -1
  152. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +1 -1
  153. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +1 -1
  154. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +1 -1
  155. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +1 -1
  156. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +1 -1
  157. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +1 -1
  158. package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +62 -0
  159. package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +85 -0
  160. package/cpp/llama.cpp/models/templates/README.md +2 -0
  161. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +5 -1
  162. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +5 -1
  163. package/cpp/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
  164. package/cpp/llama.cpp/requirements/requirements-gguf_editor_gui.txt +1 -1
  165. package/cpp/llama.cpp/src/CMakeLists.txt +2 -0
  166. package/cpp/llama.cpp/src/llama-arch.cpp +6 -0
  167. package/cpp/llama.cpp/src/llama-arch.h +2 -0
  168. package/cpp/llama.cpp/src/llama-batch.cpp +3 -1
  169. package/cpp/llama.cpp/src/llama-context.cpp +340 -123
  170. package/cpp/llama.cpp/src/llama-context.h +30 -0
  171. package/cpp/llama.cpp/src/llama-cparams.cpp +4 -0
  172. package/cpp/llama.cpp/src/llama-cparams.h +2 -0
  173. package/cpp/llama.cpp/src/llama-grammar.cpp +12 -2
  174. package/cpp/llama.cpp/src/llama-graph.cpp +157 -247
  175. package/cpp/llama.cpp/src/llama-graph.h +52 -7
  176. package/cpp/llama.cpp/src/llama-hparams.cpp +17 -1
  177. package/cpp/llama.cpp/src/llama-hparams.h +37 -5
  178. package/cpp/llama.cpp/src/llama-kv-cache.cpp +742 -481
  179. package/cpp/llama.cpp/src/llama-kv-cache.h +196 -99
  180. package/cpp/llama.cpp/src/llama-kv-cells.h +379 -0
  181. package/cpp/llama.cpp/src/llama-memory.h +4 -3
  182. package/cpp/llama.cpp/src/llama-model-loader.cpp +22 -17
  183. package/cpp/llama.cpp/src/llama-model-saver.cpp +281 -0
  184. package/cpp/llama.cpp/src/llama-model-saver.h +37 -0
  185. package/cpp/llama.cpp/src/llama-model.cpp +529 -172
  186. package/cpp/llama.cpp/src/llama-model.h +6 -1
  187. package/cpp/llama.cpp/src/llama-quant.cpp +15 -13
  188. package/cpp/llama.cpp/src/llama-sampling.cpp +2 -2
  189. package/cpp/llama.cpp/src/llama-vocab.cpp +35 -8
  190. package/cpp/llama.cpp/src/llama-vocab.h +6 -0
  191. package/cpp/llama.cpp/src/llama.cpp +14 -0
  192. package/cpp/rn-completion.cpp +4 -2
  193. package/ios/include/chat.h +73 -6
  194. package/ios/include/common/minja/chat-template.hpp +9 -5
  195. package/ios/include/common/minja/minja.hpp +69 -36
  196. package/ios/include/common.h +21 -17
  197. package/ios/include/llama.h +62 -125
  198. package/ios/libs/llama.xcframework/Info.plist +19 -19
  199. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  200. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4617 -4487
  201. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  202. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +26 -7
  203. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +62 -125
  204. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  205. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  206. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  207. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3557 -3435
  208. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  209. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  210. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  211. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  212. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  213. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  214. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3559 -3437
  215. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +237 -0
  216. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +26 -7
  217. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +62 -125
  218. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +237 -0
  219. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +26 -7
  220. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +62 -125
  221. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  222. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +237 -0
  223. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +26 -7
  224. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +62 -125
  225. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  226. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  227. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  228. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4616 -4487
  229. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  230. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +26 -7
  231. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +62 -125
  232. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4637 -4508
  235. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3556 -3435
  236. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  237. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  238. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  239. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  240. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  241. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4653 -4523
  242. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  243. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +26 -7
  244. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +62 -125
  245. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  246. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  247. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4674 -4544
  248. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3587 -3465
  249. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  250. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  251. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  252. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  253. package/package.json +1 -1
  254. package/cpp/llama.cpp/common/stb_image.h +0 -7988
  255. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +0 -112
  256. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.out +0 -46
  257. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +0 -112
  258. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +0 -46
  259. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +0 -112
  260. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +0 -46
  261. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.inp +0 -112
  262. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.out +0 -46
  263. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +0 -112
  264. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.out +0 -46
  265. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +0 -112
  266. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +0 -46
@@ -856,6 +856,7 @@ kernel void kernel_tanh(
856
856
  constant float GELU_COEF_A = 0.044715f;
857
857
  constant float GELU_QUICK_COEF = -1.702f;
858
858
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
859
+ constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
859
860
 
860
861
  kernel void kernel_gelu(
861
862
  device const float * src0,
@@ -897,6 +898,42 @@ kernel void kernel_gelu_quick_4(
897
898
  dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
898
899
  }
899
900
 
901
+ // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
902
+ // ref: https://www.johndcook.com/blog/python_erf/
903
+ constant float p_erf = 0.3275911f;
904
+ constant float a1_erf = 0.254829592f;
905
+ constant float a2_erf = -0.284496736f;
906
+ constant float a3_erf = 1.421413741f;
907
+ constant float a4_erf = -1.453152027f;
908
+ constant float a5_erf = 1.061405429f;
909
+
910
+ template<typename T>
911
+ T erf_approx(T x) {
912
+ T sign_x = sign(x);
913
+ x = fabs(x);
914
+ T t = 1.0f / (1.0f + p_erf * x);
915
+ T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
916
+ return sign_x * y;
917
+ }
918
+
919
+ kernel void kernel_gelu_erf(
920
+ device const float * src0,
921
+ device float * dst,
922
+ uint tpig[[thread_position_in_grid]]) {
923
+ device const float & x = src0[tpig];
924
+
925
+ dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
926
+ }
927
+
928
+ kernel void kernel_gelu_erf_4(
929
+ device const float4 * src0,
930
+ device float4 * dst,
931
+ uint tpig[[thread_position_in_grid]]) {
932
+ device const float4 & x = src0[tpig];
933
+
934
+ dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
935
+ }
936
+
900
937
  kernel void kernel_silu(
901
938
  device const float * src0,
902
939
  device float * dst,
@@ -2713,8 +2750,148 @@ kernel void kernel_rope_neox(
2713
2750
  }
2714
2751
  }
2715
2752
 
2753
+ template<typename T>
2754
+ kernel void kernel_rope_multi(
2755
+ constant ggml_metal_kargs_rope & args,
2756
+ device const char * src0,
2757
+ device const char * src1,
2758
+ device const char * src2,
2759
+ device char * dst,
2760
+ ushort tiitg[[thread_index_in_threadgroup]],
2761
+ ushort3 tptg [[threads_per_threadgroup]],
2762
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
2763
+ const int i3 = tgpig[2];
2764
+ const int i2 = tgpig[1];
2765
+ const int i1 = tgpig[0];
2766
+
2767
+ float corr_dims[2];
2768
+ rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
2769
+
2770
+ device const int32_t * pos = (device const int32_t *) src1;
2771
+
2772
+ const float inv_ndims = -1.f/args.n_dims;
2773
+
2774
+ float cos_theta;
2775
+ float sin_theta;
2776
+
2777
+ for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
2778
+ if (i0 < args.n_dims) {
2779
+ const int ic = i0/2;
2780
+
2781
+ // mrope theta calculations
2782
+ // note: the rest is the same as kernel_rope_neox
2783
+ const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
2784
+ const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1
2785
+ const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
2786
+ const int sector = ic % sect_dims;
2787
+
2788
+ float theta_base;
2789
+ if (sector < args.sect_0) {
2790
+ theta_base = (float) pos[i2];
2791
+ } else if (sector < sec_w01) {
2792
+ theta_base = (float) pos[i2 + args.ne02];
2793
+ } else if (sector < sec_w012) {
2794
+ theta_base = (float) pos[i2 + args.ne02 * 2];
2795
+ } else {
2796
+ theta_base = (float) pos[i2 + args.ne02 * 3];
2797
+ }
2798
+ // end of mrope
2799
+
2800
+ const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
2801
+
2802
+ const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
2803
+
2804
+ rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
2805
+
2806
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
2807
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
2808
+
2809
+ const float x0 = src[0];
2810
+ const float x1 = src[args.n_dims/2];
2811
+
2812
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
2813
+ dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
2814
+ } else {
2815
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
2816
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
2817
+
2818
+ dst_data[0] = src[0];
2819
+ dst_data[1] = src[1];
2820
+ }
2821
+ }
2822
+ }
2823
+
2824
+ template<typename T>
2825
+ kernel void kernel_rope_vision(
2826
+ constant ggml_metal_kargs_rope & args,
2827
+ device const char * src0,
2828
+ device const char * src1,
2829
+ device const char * src2,
2830
+ device char * dst,
2831
+ ushort tiitg[[thread_index_in_threadgroup]],
2832
+ ushort3 tptg [[threads_per_threadgroup]],
2833
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
2834
+ const int i3 = tgpig[2];
2835
+ const int i2 = tgpig[1];
2836
+ const int i1 = tgpig[0];
2837
+
2838
+ float corr_dims[2];
2839
+ rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
2840
+
2841
+ device const int32_t * pos = (device const int32_t *) src1;
2842
+
2843
+ const float inv_ndims = -1.f/args.n_dims;
2844
+
2845
+ float cos_theta;
2846
+ float sin_theta;
2847
+
2848
+ for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
2849
+ if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
2850
+ const int ic = i0/2;
2851
+
2852
+ // mrope theta calculations (only support 2 dimensions)
2853
+ const int sect_dims = args.sect_0 + args.sect_1;
2854
+ const int sector = ic % sect_dims;
2855
+
2856
+ float p;
2857
+ float theta_base;
2858
+ if (sector < args.sect_1) {
2859
+ p = (float) sector;
2860
+ theta_base = (float) pos[i2];
2861
+ } else {
2862
+ p = (float) sector - args.sect_0;
2863
+ theta_base = (float) pos[i2 + args.ne02];
2864
+ }
2865
+
2866
+ const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
2867
+ // end of mrope
2868
+
2869
+ const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
2870
+
2871
+ rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
2872
+
2873
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
2874
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
2875
+
2876
+ const float x0 = src[0];
2877
+ const float x1 = src[args.n_dims]; // different from kernel_rope_multi
2878
+
2879
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
2880
+ dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
2881
+ } else {
2882
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
2883
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
2884
+
2885
+ dst_data[0] = src[0];
2886
+ dst_data[1] = src[1];
2887
+ }
2888
+ }
2889
+ }
2890
+
2716
2891
  typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
2717
2892
  typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
2893
+ typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
2894
+ typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
2718
2895
 
2719
2896
  template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
2720
2897
  template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
@@ -2722,6 +2899,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
2722
2899
  template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
2723
2900
  template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
2724
2901
 
2902
+ template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
2903
+ template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
2904
+
2905
+ template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
2906
+ template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
2907
+
2725
2908
  typedef void (im2col_t)(
2726
2909
  device const float * x,
2727
2910
  device char * dst,
@@ -3109,7 +3292,7 @@ template<
3109
3292
  typename kd4x4_t, // key type in device memory
3110
3293
  short nl_k,
3111
3294
  void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
3112
- typename vd4x4_t, // key type in device memory
3295
+ typename vd4x4_t, // value type in device memory
3113
3296
  short nl_v,
3114
3297
  void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
3115
3298
  short DK, // K head size
@@ -3630,7 +3813,7 @@ template<
3630
3813
  typename kd4_t, // key type in device memory
3631
3814
  short nl_k,
3632
3815
  void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
3633
- typename vd4_t, // key type in device memory
3816
+ typename vd4_t, // value type in device memory
3634
3817
  short nl_v,
3635
3818
  void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
3636
3819
  short DK, // K head size
@@ -3741,6 +3924,11 @@ kernel void kernel_flash_attn_ext_vec(
3741
3924
  sm[tiisg] = pm[ic + tiisg];
3742
3925
  }
3743
3926
 
3927
+ // skip -INF blocks
3928
+ if (simd_max(sm[tiisg]) == -INFINITY) {
3929
+ continue;
3930
+ }
3931
+
3744
3932
  // Q*K^T
3745
3933
  {
3746
3934
  // each simdgroup processes 1 query and NE (NW/NL) head elements
@@ -3973,6 +4161,16 @@ kernel void kernel_flash_attn_ext_vec(
3973
4161
 
3974
4162
  typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
3975
4163
 
4164
+ template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 8>;
4165
+ #if defined(GGML_METAL_USE_BF16)
4166
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 8>;
4167
+ #endif
4168
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 8>;
4169
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 8>;
4170
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 8>;
4171
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 8>;
4172
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 8>;
4173
+
3976
4174
  template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
3977
4175
  #if defined(GGML_METAL_USE_BF16)
3978
4176
  template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
@@ -27,12 +27,15 @@ if (MUSAToolkit_FOUND)
27
27
 
28
28
  file(GLOB GGML_HEADERS_MUSA "../ggml-cuda/*.cuh")
29
29
  list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
30
+ list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh")
30
31
 
31
32
  file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
32
33
  file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
33
34
  list(APPEND GGML_SOURCES_MUSA ${SRCS})
34
35
  file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
35
36
  list(APPEND GGML_SOURCES_MUSA ${SRCS})
37
+ file(GLOB SRCS "../ggml-musa/*.cu")
38
+ list(APPEND GGML_SOURCES_MUSA ${SRCS})
36
39
 
37
40
  if (GGML_CUDA_FA_ALL_QUANTS)
38
41
  file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
@@ -62,7 +65,9 @@ if (MUSAToolkit_FOUND)
62
65
  )
63
66
 
64
67
  # TODO: do not use CUDA definitions for MUSA
65
- target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
68
+ if (NOT GGML_BACKEND_DL)
69
+ target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
70
+ endif()
66
71
 
67
72
  add_compile_definitions(GGML_USE_MUSA)
68
73
  add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
@@ -92,9 +97,10 @@ if (MUSAToolkit_FOUND)
92
97
  endif()
93
98
 
94
99
  if (GGML_STATIC)
100
+ # TODO: mudnn has not provided static libraries yet
95
101
  target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static)
96
102
  else()
97
- target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas)
103
+ target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas mudnn)
98
104
  endif()
99
105
 
100
106
  if (GGML_CUDA_NO_VMM)
@@ -0,0 +1,112 @@
1
+ #include <mutex>
2
+ #include <mudnn.h>
3
+
4
+ #include "mudnn.cuh"
5
+
6
+ namespace mudnn = musa::dnn;
7
+
8
+ // Returns a human-readable error string for mudnn::Status
9
+ const char* mudnnGetErrorString(mudnn::Status err) {
10
+ switch (err) {
11
+ case mudnn::Status::SUCCESS:
12
+ return "Success";
13
+ case mudnn::Status::INVALID_PARAMETER:
14
+ return "Invalid parameter";
15
+ case mudnn::Status::NOT_INITIALIZED:
16
+ return "Not initialized";
17
+ case mudnn::Status::ALLOC_FAILED:
18
+ return "Allocation failed";
19
+ case mudnn::Status::NOT_SUPPORTED:
20
+ return "Not supported";
21
+ case mudnn::Status::INTERNAL_ERROR:
22
+ return "Internal error";
23
+ case mudnn::Status::ARCH_MISMATCH:
24
+ return "Architecture mismatch";
25
+ case mudnn::Status::EXECUTION_FAILED:
26
+ return "Execution failed";
27
+ default:
28
+ return "Unknown mudnn status";
29
+ }
30
+ }
31
+
32
+ // Error checking macro for MUDNN calls
33
+ #define MUDNN_CHECK(err) CUDA_CHECK_GEN(err, mudnn::Status::SUCCESS, mudnnGetErrorString)
34
+
35
+ namespace {
36
+ // Thread-safe cache for mudnn::Handle objects per device
37
+ std::unordered_map<int, std::unique_ptr<mudnn::Handle>> handle_cache;
38
+ std::mutex handle_cache_mutex;
39
+
40
+ mudnn::Handle* get_cached_handle(int device_id) {
41
+ std::lock_guard<std::mutex> lock(handle_cache_mutex);
42
+ auto it = handle_cache.find(device_id);
43
+ if (it != handle_cache.end()) {
44
+ return it->second.get();
45
+ }
46
+ auto handle = std::make_unique<mudnn::Handle>(device_id);
47
+ mudnn::Handle* handle_ptr = handle.get();
48
+ handle_cache[device_id] = std::move(handle);
49
+ return handle_ptr;
50
+ }
51
+ }
52
+
53
+ // Extracts dimensions and strides from a ggml_tensor
54
+ int get_ggml_dims_and_strides(const ggml_tensor* tensor,
55
+ std::vector<int64_t>& dims,
56
+ std::vector<int64_t>& strides) {
57
+ const int ndims = ggml_n_dims(tensor);
58
+ const size_t element_size = ggml_element_size(tensor);
59
+
60
+ dims.resize(ndims);
61
+ strides.resize(ndims);
62
+
63
+ for (int i = 0; i < ndims; ++i) {
64
+ dims[i] = tensor->ne[i];
65
+ strides[i] = tensor->nb[i] / static_cast<int64_t>(element_size);
66
+ }
67
+ return ndims;
68
+ }
69
+
70
+ // Converts ggml_type to mudnn::Tensor::Type
71
+ mudnn::Tensor::Type ggml_type_to_mudnn_type(ggml_type type) {
72
+ switch (type) {
73
+ case GGML_TYPE_F32:
74
+ return mudnn::Tensor::Type::FLOAT;
75
+ case GGML_TYPE_F16:
76
+ return mudnn::Tensor::Type::HALF;
77
+
78
+ // TODO: Add support for other types
79
+
80
+ default:
81
+ MUDNN_CHECK(mudnn::Status::NOT_SUPPORTED);
82
+ }
83
+
84
+ return mudnn::Tensor::Type::FLOAT; // Default fallback
85
+ }
86
+
87
// Copies `src` into `dst` by running a mudnn Unary op in IDENTITY mode
// (dst = identity(src)) on the stream returned by ctx.stream(), instead of a raw memcpy.
//
// Layout: both descriptors are built from src's dims/strides, and dst->ne / dst->nb are
// never consulted. This is only correct when dst has the same element layout as src
// (e.g. both contiguous with the same element count).
// NOTE(review): confirm every caller guarantees this.
//
// Types: only F32/F16 are mapped by ggml_type_to_mudnn_type; other types are reported
// through MUDNN_CHECK.
//
// Errors: every mudnn call is wrapped in MUDNN_CHECK, so failures are surfaced there
// (behavior defined by CUDA_CHECK_GEN). This function itself always returns musaSuccess.
//
// NOTE(review): the mudnn::Handle is shared per device (get_cached_handle), and
// SetStream/Run below execute outside handle_cache_mutex. Concurrent callers on the same
// device with different streams could interleave SetStream and Run.
// TODO confirm that only one thread uses a device here at a time.
musaError_t mudnnMemcpyAsync(ggml_backend_cuda_context& ctx, const ggml_tensor* dst, const ggml_tensor* src) {
    mudnn::Tensor tensor_dst, tensor_src;

    // Element types may differ between dst and src; each is mapped independently.
    MUDNN_CHECK(tensor_dst.SetType(ggml_type_to_mudnn_type(dst->type)));
    MUDNN_CHECK(tensor_src.SetType(ggml_type_to_mudnn_type(src->type)));

    // Shape/strides are taken from src only and reused for dst (see layout note above).
    std::vector<int64_t> dims, strides;
    const int ndims = get_ggml_dims_and_strides(src, dims, strides);

    MUDNN_CHECK(tensor_dst.SetNdInfo(ndims, dims.data(), strides.data()));
    MUDNN_CHECK(tensor_src.SetNdInfo(ndims, dims.data(), strides.data()));
    MUDNN_CHECK(tensor_dst.SetAddr(dst->data));
    MUDNN_CHECK(tensor_src.SetAddr(src->data));

    // Element-wise identity op with alpha and beta explicitly set to 0.
    // NOTE(review): the meaning of alpha/beta for IDENTITY is defined by mudnn — confirm.
    mudnn::Unary op;
    MUDNN_CHECK(op.SetMode(mudnn::Unary::Mode::IDENTITY));
    MUDNN_CHECK(op.SetAlpha(0.0f));
    MUDNN_CHECK(op.SetBeta(0.0f));

    // Bind the per-device cached handle to this context's stream, then run the op on it.
    mudnn::Handle* handle = get_cached_handle(ctx.device);
    MUDNN_CHECK(handle->SetStream(ctx.stream()));
    MUDNN_CHECK(op.Run(*handle, tensor_dst, tensor_src));

    return musaSuccess;
}
@@ -0,0 +1,12 @@
1
+ #pragma once
2
+
3
+ #include "../include/ggml.h"
4
+ #include "../ggml-cuda/common.cuh"
5
+
6
// Copies the contents of `src` into `dst` using mudnn, on the stream of `ctx`.
// Argument order is destination first, then source.
// Both tensors are expected to share src's element layout (the implementation describes
// dst with src's dims/strides).
// Returns a musaError_t. NOTE(review): the current implementation always returns
// musaSuccess and reports mudnn failures through its internal MUDNN_CHECK calls instead.
musaError_t mudnnMemcpyAsync(
    ggml_backend_cuda_context &ctx,
    const ggml_tensor *dst,
    const ggml_tensor *src
);
@@ -55,14 +55,17 @@ endfunction()
55
55
 
56
56
  set(GGML_OPENCL_KERNELS
57
57
  add
58
+ argsort
58
59
  clamp
59
60
  cpy
60
61
  cvt
61
62
  diag_mask_inf
63
+ div
62
64
  gelu
63
65
  gemv_noshuffle_general
64
66
  gemv_noshuffle
65
67
  get_rows
68
+ group_norm
66
69
  im2col_f32
67
70
  im2col_f16
68
71
  mul_mat_Ab_Bi_8x4
@@ -83,11 +86,14 @@ set(GGML_OPENCL_KERNELS
83
86
  rms_norm
84
87
  rope
85
88
  scale
89
+ sigmoid
86
90
  silu
87
91
  softmax_4_f32
88
92
  softmax_4_f16
89
93
  softmax_f32
90
94
  softmax_f16
95
+ sub
96
+ sum_rows
91
97
  transpose
92
98
  )
93
99