cui-llama.rn 1.4.4 → 1.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216)
  1. package/android/src/main/CMakeLists.txt +9 -2
  2. package/android/src/main/jni.cpp +54 -34
  3. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  11. package/cpp/binary-ops.cpp +158 -0
  12. package/cpp/binary-ops.h +16 -0
  13. package/cpp/chat.cpp +1769 -1085
  14. package/cpp/chat.h +143 -0
  15. package/cpp/common.cpp +1562 -1996
  16. package/cpp/common.h +677 -744
  17. package/cpp/cpu-common.h +72 -0
  18. package/cpp/ggml-alloc.c +1039 -1030
  19. package/cpp/ggml-alloc.h +1 -1
  20. package/cpp/ggml-backend-impl.h +255 -255
  21. package/cpp/ggml-backend-reg.cpp +586 -582
  22. package/cpp/ggml-backend.cpp +2004 -2002
  23. package/cpp/ggml-backend.h +354 -354
  24. package/cpp/ggml-common.h +1857 -1851
  25. package/cpp/ggml-cpp.h +39 -39
  26. package/cpp/ggml-cpu-aarch64.cpp +5725 -4247
  27. package/cpp/ggml-cpu-aarch64.h +8 -8
  28. package/cpp/ggml-cpu-impl.h +512 -380
  29. package/cpp/ggml-cpu-quants.c +13026 -11517
  30. package/cpp/ggml-cpu-traits.cpp +36 -36
  31. package/cpp/ggml-cpu-traits.h +38 -38
  32. package/cpp/ggml-cpu.c +3438 -14485
  33. package/cpp/ggml-cpu.cpp +655 -633
  34. package/cpp/ggml-cpu.h +138 -135
  35. package/cpp/ggml-impl.h +594 -567
  36. package/cpp/ggml-metal-impl.h +312 -3
  37. package/cpp/ggml-metal.h +66 -66
  38. package/cpp/ggml-metal.m +5360 -5002
  39. package/cpp/ggml-opt.cpp +854 -854
  40. package/cpp/ggml-opt.h +216 -216
  41. package/cpp/ggml-quants.c +5238 -5238
  42. package/cpp/ggml-threading.h +14 -14
  43. package/cpp/ggml.c +6618 -6524
  44. package/cpp/ggml.h +2222 -2194
  45. package/cpp/gguf.cpp +1330 -1329
  46. package/cpp/gguf.h +202 -202
  47. package/cpp/json-schema-to-grammar.cpp +1024 -1025
  48. package/cpp/json-schema-to-grammar.h +21 -22
  49. package/cpp/json.hpp +24766 -24766
  50. package/cpp/llama-adapter.cpp +382 -347
  51. package/cpp/llama-adapter.h +76 -74
  52. package/cpp/llama-arch.cpp +1714 -1492
  53. package/cpp/llama-arch.h +428 -402
  54. package/cpp/llama-batch.cpp +368 -368
  55. package/cpp/llama-batch.h +88 -88
  56. package/cpp/llama-chat.cpp +640 -587
  57. package/cpp/llama-chat.h +56 -53
  58. package/cpp/llama-context.cpp +2831 -1775
  59. package/cpp/llama-context.h +265 -128
  60. package/cpp/llama-cparams.cpp +1 -1
  61. package/cpp/llama-cparams.h +38 -37
  62. package/cpp/llama-cpp.h +30 -30
  63. package/cpp/llama-grammar.cpp +1219 -1219
  64. package/cpp/llama-grammar.h +173 -164
  65. package/cpp/llama-graph.cpp +1695 -0
  66. package/cpp/llama-graph.h +592 -0
  67. package/cpp/llama-hparams.cpp +79 -71
  68. package/cpp/llama-hparams.h +156 -139
  69. package/cpp/llama-impl.cpp +167 -167
  70. package/cpp/llama-impl.h +61 -61
  71. package/cpp/llama-io.cpp +15 -0
  72. package/cpp/llama-io.h +35 -0
  73. package/cpp/llama-kv-cache.cpp +1380 -718
  74. package/cpp/llama-kv-cache.h +213 -218
  75. package/cpp/llama-memory.cpp +1 -0
  76. package/cpp/llama-memory.h +21 -0
  77. package/cpp/llama-mmap.cpp +600 -590
  78. package/cpp/llama-mmap.h +68 -68
  79. package/cpp/llama-model-loader.cpp +1129 -1124
  80. package/cpp/llama-model-loader.h +169 -167
  81. package/cpp/llama-model.cpp +13080 -4023
  82. package/cpp/llama-model.h +409 -370
  83. package/cpp/llama-sampling.cpp +2563 -2525
  84. package/cpp/llama-sampling.h +32 -32
  85. package/cpp/llama-vocab.cpp +3295 -3252
  86. package/cpp/llama-vocab.h +125 -125
  87. package/cpp/llama.cpp +351 -10137
  88. package/cpp/llama.h +1434 -1340
  89. package/cpp/log.cpp +427 -423
  90. package/cpp/log.h +132 -132
  91. package/cpp/{chat-template.hpp → minja/chat-template.hpp} +537 -529
  92. package/cpp/{minja.hpp → minja/minja.hpp} +2941 -2883
  93. package/cpp/ops.cpp +8723 -0
  94. package/cpp/ops.h +128 -0
  95. package/cpp/rn-llama.cpp +45 -71
  96. package/cpp/rn-llama.h +3 -3
  97. package/cpp/sampling.cpp +573 -532
  98. package/cpp/sgemm.cpp +3043 -2598
  99. package/cpp/sgemm.h +14 -14
  100. package/cpp/simd-mappings.h +888 -0
  101. package/cpp/speculative.cpp +278 -277
  102. package/cpp/speculative.h +28 -28
  103. package/cpp/unary-ops.cpp +186 -0
  104. package/cpp/unary-ops.h +28 -0
  105. package/cpp/vec.cpp +258 -0
  106. package/cpp/vec.h +802 -0
  107. package/ios/CMakeLists.txt +5 -2
  108. package/ios/RNLlama.mm +2 -2
  109. package/ios/RNLlamaContext.mm +40 -24
  110. package/package.json +1 -1
  111. package/src/NativeRNLlama.ts +6 -4
  112. package/src/index.ts +3 -1
  113. package/android/src/main/build-arm64/CMakeCache.txt +0 -429
  114. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCCompiler.cmake +0 -81
  115. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCXXCompiler.cmake +0 -101
  116. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_C.bin +0 -0
  117. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_CXX.bin +0 -0
  118. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeSystem.cmake +0 -15
  119. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.c +0 -904
  120. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.o +0 -0
  121. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.cpp +0 -919
  122. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.o +0 -0
  123. package/android/src/main/build-arm64/CMakeFiles/CMakeConfigureLog.yaml +0 -431
  124. package/android/src/main/build-arm64/CMakeFiles/CMakeDirectoryInformation.cmake +0 -16
  125. package/android/src/main/build-arm64/CMakeFiles/Makefile.cmake +0 -165
  126. package/android/src/main/build-arm64/CMakeFiles/Makefile2 +0 -297
  127. package/android/src/main/build-arm64/CMakeFiles/Progress/1 +0 -1
  128. package/android/src/main/build-arm64/CMakeFiles/Progress/2 +0 -1
  129. package/android/src/main/build-arm64/CMakeFiles/Progress/3 +0 -1
  130. package/android/src/main/build-arm64/CMakeFiles/Progress/4 +0 -1
  131. package/android/src/main/build-arm64/CMakeFiles/Progress/5 +0 -1
  132. package/android/src/main/build-arm64/CMakeFiles/Progress/6 +0 -1
  133. package/android/src/main/build-arm64/CMakeFiles/Progress/count.txt +0 -1
  134. package/android/src/main/build-arm64/CMakeFiles/TargetDirectories.txt +0 -8
  135. package/android/src/main/build-arm64/CMakeFiles/cmake.check_cache +0 -1
  136. package/android/src/main/build-arm64/CMakeFiles/progress.marks +0 -1
  137. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o +0 -0
  138. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o.d +0 -58
  139. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o +0 -0
  140. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o.d +0 -756
  141. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o +0 -0
  142. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o.d +0 -709
  143. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o +0 -0
  144. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o.d +0 -714
  145. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o +0 -0
  146. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o.d +0 -62
  147. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o +0 -0
  148. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o.d +0 -708
  149. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o +0 -0
  150. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o.d +0 -113
  151. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o +0 -0
  152. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o.d +0 -713
  153. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o +0 -0
  154. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o.d +0 -763
  155. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o +0 -0
  156. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o.d +0 -61
  157. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o +0 -0
  158. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o.d +0 -707
  159. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o +0 -0
  160. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o.d +0 -104
  161. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o +0 -0
  162. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o.d +0 -714
  163. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o +0 -0
  164. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o.d +0 -723
  165. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/DependInfo.cmake +0 -62
  166. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/build.make +0 -722
  167. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/cmake_clean.cmake +0 -89
  168. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.make +0 -2
  169. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.ts +0 -2
  170. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/depend.make +0 -2
  171. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/flags.make +0 -17
  172. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/progress.make +0 -41
  173. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/DependInfo.cmake +0 -62
  174. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/build.make +0 -722
  175. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/cmake_clean.cmake +0 -89
  176. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.make +0 -2
  177. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.ts +0 -2
  178. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/depend.make +0 -2
  179. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/flags.make +0 -17
  180. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/progress.make +0 -41
  181. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/DependInfo.cmake +0 -62
  182. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/build.make +0 -722
  183. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/cmake_clean.cmake +0 -89
  184. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.make +0 -2
  185. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.ts +0 -2
  186. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/depend.make +0 -2
  187. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/flags.make +0 -17
  188. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/progress.make +0 -41
  189. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/DependInfo.cmake +0 -62
  190. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/build.make +0 -722
  191. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/cmake_clean.cmake +0 -89
  192. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.make +0 -2
  193. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.ts +0 -2
  194. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/depend.make +0 -2
  195. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/flags.make +0 -17
  196. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/progress.make +0 -41
  197. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/DependInfo.cmake +0 -62
  198. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/build.make +0 -722
  199. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/cmake_clean.cmake +0 -89
  200. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.make +0 -2
  201. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.ts +0 -2
  202. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/depend.make +0 -2
  203. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/flags.make +0 -17
  204. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/progress.make +0 -41
  205. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/DependInfo.cmake +0 -62
  206. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/build.make +0 -722
  207. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/cmake_clean.cmake +0 -89
  208. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.make +0 -2
  209. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.ts +0 -2
  210. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/depend.make +0 -2
  211. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/flags.make +0 -17
  212. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/progress.make +0 -41
  213. package/android/src/main/build-arm64/Makefile +0 -1862
  214. package/android/src/main/build-arm64/cmake_install.cmake +0 -66
  215. package/cpp/chat.hpp +0 -55
  216. package/cpp/rn-llama.hpp +0 -913
@@ -1,6 +1,70 @@
1
1
  #ifndef GGML_METAL_IMPL
2
2
  #define GGML_METAL_IMPL
3
3
 
4
+ // kernel parameters for mat-vec threadgroups
5
+ //
6
+ // N_R0: number of src0 rows to process per simdgroup
7
+ // N_SG: number of simdgroups per threadgroup
8
+ //
9
+ // TODO: for optimal performance, become function of the device and work size
10
+
11
+ #define N_R0_Q4_0 4
12
+ #define N_SG_Q4_0 2
13
+
14
+ #define N_R0_Q4_1 4
15
+ #define N_SG_Q4_1 2
16
+
17
+ #define N_R0_Q5_0 4
18
+ #define N_SG_Q5_0 2
19
+
20
+ #define N_R0_Q5_1 4
21
+ #define N_SG_Q5_1 2
22
+
23
+ #define N_R0_Q8_0 4
24
+ #define N_SG_Q8_0 2
25
+
26
+ #define N_R0_Q2_K 4
27
+ #define N_SG_Q2_K 2
28
+
29
+ #define N_R0_Q3_K 2
30
+ #define N_SG_Q3_K 2
31
+
32
+ #define N_R0_Q4_K 4
33
+ #define N_SG_Q4_K 2
34
+
35
+ #define N_R0_Q5_K 2
36
+ #define N_SG_Q5_K 2
37
+
38
+ #define N_R0_Q6_K 1
39
+ #define N_SG_Q6_K 2
40
+
41
+ #define N_R0_IQ1_S 4
42
+ #define N_SG_IQ1_S 2
43
+
44
+ #define N_R0_IQ1_M 4
45
+ #define N_SG_IQ1_M 2
46
+
47
+ #define N_R0_IQ2_XXS 4
48
+ #define N_SG_IQ2_XXS 2
49
+
50
+ #define N_R0_IQ2_XS 4
51
+ #define N_SG_IQ2_XS 2
52
+
53
+ #define N_R0_IQ2_S 4
54
+ #define N_SG_IQ2_S 2
55
+
56
+ #define N_R0_IQ3_XXS 4
57
+ #define N_SG_IQ3_XXS 2
58
+
59
+ #define N_R0_IQ3_S 4
60
+ #define N_SG_IQ3_S 2
61
+
62
+ #define N_R0_IQ4_NL 2
63
+ #define N_SG_IQ4_NL 2
64
+
65
+ #define N_R0_IQ4_XS 2
66
+ #define N_SG_IQ4_XS 2
67
+
4
68
  // kernel argument structs
5
69
  //
6
70
  // - element counters (e.g. ne00) typically use int32_t to reduce register usage
@@ -155,9 +219,12 @@ typedef struct {
155
219
  int32_t ne11;
156
220
  int32_t ne_12_2; // assume K and V are same shape
157
221
  int32_t ne_12_3;
158
- uint64_t nb_12_1;
159
- uint64_t nb_12_2;
160
- uint64_t nb_12_3;
222
+ uint64_t nb11;
223
+ uint64_t nb12;
224
+ uint64_t nb13;
225
+ uint64_t nb21;
226
+ uint64_t nb22;
227
+ uint64_t nb23;
161
228
  uint64_t nb31;
162
229
  int32_t ne1;
163
230
  int32_t ne2;
@@ -285,4 +352,246 @@ typedef struct {
285
352
  float eps;
286
353
  } ggml_metal_kargs_rms_norm;
287
354
 
355
+ typedef struct {
356
+ int32_t ne00;
357
+ int32_t ne00_4;
358
+ uint64_t nb01;
359
+ float eps;
360
+ } ggml_metal_kargs_l2_norm;
361
+
362
+ typedef struct {
363
+ int64_t ne00;
364
+ int64_t ne01;
365
+ int64_t ne02;
366
+ uint64_t nb00;
367
+ uint64_t nb01;
368
+ uint64_t nb02;
369
+ int32_t n_groups;
370
+ float eps;
371
+ } ggml_metal_kargs_group_norm;
372
+
373
+ typedef struct {
374
+ int32_t IC;
375
+ int32_t IL;
376
+ int32_t K;
377
+ int32_t s0;
378
+ uint64_t nb0;
379
+ uint64_t nb1;
380
+ } ggml_metal_kargs_conv_transpose_1d;
381
+
382
+ typedef struct {
383
+ uint64_t ofs0;
384
+ uint64_t ofs1;
385
+ int32_t IW;
386
+ int32_t IH;
387
+ int32_t CHW;
388
+ int32_t s0;
389
+ int32_t s1;
390
+ int32_t p0;
391
+ int32_t p1;
392
+ int32_t d0;
393
+ int32_t d1;
394
+ int32_t N;
395
+ int32_t KH;
396
+ int32_t KW;
397
+ int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
398
+ } ggml_metal_kargs_im2col;
399
+
400
+ typedef struct {
401
+ int64_t ne00;
402
+ int64_t ne01;
403
+ int64_t ne02;
404
+ int64_t ne03;
405
+ uint64_t nb00;
406
+ uint64_t nb01;
407
+ uint64_t nb02;
408
+ uint64_t nb03;
409
+ int64_t ne10;
410
+ int64_t ne11;
411
+ int64_t ne12;
412
+ int64_t ne13;
413
+ uint64_t nb10;
414
+ uint64_t nb11;
415
+ uint64_t nb12;
416
+ uint64_t nb13;
417
+ int64_t ne0;
418
+ int64_t ne1;
419
+ int64_t ne2;
420
+ int64_t ne3;
421
+ uint64_t nb0;
422
+ uint64_t nb1;
423
+ uint64_t nb2;
424
+ uint64_t nb3;
425
+ } ggml_metal_kargs_sum_rows;
426
+
427
+ typedef struct {
428
+ int64_t ne00;
429
+ int64_t ne01;
430
+ int64_t ne02;
431
+ float scale;
432
+ float max_bias;
433
+ float m0;
434
+ float m1;
435
+ uint32_t n_head_log2;
436
+ } ggml_metal_kargs_soft_max;
437
+
438
+ typedef struct {
439
+ int64_t ne00;
440
+ int64_t ne01;
441
+ int n_past;
442
+ } ggml_metal_kargs_diag_mask_inf;
443
+
444
+ typedef struct {
445
+ int64_t ne00;
446
+ int64_t ne01;
447
+ int64_t ne02;
448
+ uint64_t nb00;
449
+ uint64_t nb01;
450
+ uint64_t nb02;
451
+ int64_t ne10;
452
+ int64_t ne11;
453
+ uint64_t nb10;
454
+ uint64_t nb11;
455
+ int64_t ne0;
456
+ int64_t ne1;
457
+ int64_t ne2;
458
+ uint64_t nb0;
459
+ uint64_t nb1;
460
+ uint64_t nb2;
461
+ } ggml_metal_kargs_ssm_conv;
462
+
463
+ typedef struct {
464
+ int64_t d_state;
465
+ int64_t d_inner;
466
+ int64_t n_seq_tokens;
467
+ int64_t n_seqs;
468
+ uint64_t nb00;
469
+ uint64_t nb01;
470
+ uint64_t nb02;
471
+ uint64_t nb10;
472
+ uint64_t nb11;
473
+ uint64_t nb12;
474
+ uint64_t nb13;
475
+ uint64_t nb20;
476
+ uint64_t nb21;
477
+ uint64_t nb22;
478
+ uint64_t nb30;
479
+ uint64_t nb31;
480
+ uint64_t nb40;
481
+ uint64_t nb41;
482
+ uint64_t nb42;
483
+ uint64_t nb50;
484
+ uint64_t nb51;
485
+ uint64_t nb52;
486
+ } ggml_metal_kargs_ssm_scan;
487
+
488
+ typedef struct {
489
+ int64_t ne00;
490
+ uint64_t nb01;
491
+ uint64_t nb02;
492
+ int64_t ne10;
493
+ uint64_t nb10;
494
+ uint64_t nb11;
495
+ uint64_t nb1;
496
+ uint64_t nb2;
497
+ } ggml_metal_kargs_get_rows;
498
+
499
+ typedef struct {
500
+ int64_t ne00;
501
+ int64_t ne01;
502
+ int64_t ne02;
503
+ int64_t ne03;
504
+ uint64_t nb00;
505
+ uint64_t nb01;
506
+ uint64_t nb02;
507
+ uint64_t nb03;
508
+ int64_t ne0;
509
+ int64_t ne1;
510
+ int64_t ne2;
511
+ int64_t ne3;
512
+ uint64_t nb0;
513
+ uint64_t nb1;
514
+ uint64_t nb2;
515
+ uint64_t nb3;
516
+ float sf0;
517
+ float sf1;
518
+ float sf2;
519
+ float sf3;
520
+ } ggml_metal_kargs_upscale;
521
+
522
+ typedef struct {
523
+ int64_t ne00;
524
+ int64_t ne01;
525
+ int64_t ne02;
526
+ int64_t ne03;
527
+ uint64_t nb00;
528
+ uint64_t nb01;
529
+ uint64_t nb02;
530
+ uint64_t nb03;
531
+ int64_t ne0;
532
+ int64_t ne1;
533
+ int64_t ne2;
534
+ int64_t ne3;
535
+ uint64_t nb0;
536
+ uint64_t nb1;
537
+ uint64_t nb2;
538
+ uint64_t nb3;
539
+ } ggml_metal_kargs_pad;
540
+
541
+ typedef struct {
542
+ int64_t ne00;
543
+ int64_t ne01;
544
+ int64_t ne02;
545
+ int64_t ne03;
546
+ uint64_t nb00;
547
+ uint64_t nb01;
548
+ uint64_t nb02;
549
+ uint64_t nb03;
550
+ int64_t ne0;
551
+ int64_t ne1;
552
+ int64_t ne2;
553
+ int64_t ne3;
554
+ uint64_t nb0;
555
+ uint64_t nb1;
556
+ uint64_t nb2;
557
+ uint64_t nb3;
558
+ int32_t p0;
559
+ int32_t p1;
560
+ } ggml_metal_kargs_pad_reflect_1d;
561
+
562
+ typedef struct {
563
+ uint64_t nb1;
564
+ int dim;
565
+ int max_period;
566
+ } ggml_metal_kargs_timestep_embedding;
567
+
568
+ typedef struct {
569
+ float slope;
570
+ } ggml_metal_kargs_leaky_relu;
571
+
572
+ typedef struct {
573
+ int64_t ncols;
574
+ int64_t ncols_pad;
575
+ } ggml_metal_kargs_argsort;
576
+
577
+ typedef struct {
578
+ int64_t ne0;
579
+ float start;
580
+ float step;
581
+ } ggml_metal_kargs_arange;
582
+
583
+ typedef struct {
584
+ int32_t k0;
585
+ int32_t k1;
586
+ int32_t s0;
587
+ int32_t s1;
588
+ int32_t p0;
589
+ int32_t p1;
590
+ int64_t IH;
591
+ int64_t IW;
592
+ int64_t OH;
593
+ int64_t OW;
594
+ int64_t parallel_elements;
595
+ } ggml_metal_kargs_pool_2d;
596
+
288
597
  #endif // GGML_METAL_IMPL
package/cpp/ggml-metal.h CHANGED
@@ -1,66 +1,66 @@
1
- // Note: this description is outdated
2
- //
3
- // An interface allowing to compute lm_ggml_cgraph with Metal
4
- //
5
- // This is a fully functional interface that extends ggml with GPU support for Apple devices.
6
- // A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.)
7
- //
8
- // How it works?
9
- //
10
- // As long as your program can create and evaluate a lm_ggml_cgraph on the CPU, you can use this
11
- // interface to evaluate the same graph on the GPU. Instead of using lm_ggml_graph_compute(), you
12
- // use lm_ggml_metal_graph_compute() (or lm_ggml_vulkan_graph_compute(), etc.)
13
- //
14
- // You only need to make sure that all memory buffers that you used during the graph creation
15
- // are mapped to the device memory with the lm_ggml_metal_add_buffer() function. This mapping is
16
- // used during the graph evaluation to determine the arguments of the compute kernels.
17
- //
18
- // Synchronization between device and host memory (for example for input and output tensors)
19
- // is done with the lm_ggml_metal_set_tensor() and lm_ggml_metal_get_tensor() functions.
20
- //
21
-
22
- #pragma once
23
-
24
- #include "ggml.h"
25
- #include "ggml-backend.h"
26
-
27
- #include <stddef.h>
28
- #include <stdbool.h>
29
-
30
- struct lm_ggml_tensor;
31
- struct lm_ggml_cgraph;
32
-
33
- #ifdef __cplusplus
34
- extern "C" {
35
- #endif
36
-
37
- //
38
- // backend API
39
- // user-code should use only these functions
40
- //
41
-
42
- LM_GGML_BACKEND_API lm_ggml_backend_t lm_ggml_backend_metal_init(void);
43
-
44
- LM_GGML_BACKEND_API bool lm_ggml_backend_is_metal(lm_ggml_backend_t backend);
45
-
46
- LM_GGML_DEPRECATED(
47
- LM_GGML_BACKEND_API lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
48
- "obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713");
49
-
50
- LM_GGML_BACKEND_API void lm_ggml_backend_metal_set_abort_callback(lm_ggml_backend_t backend, lm_ggml_abort_callback abort_callback, void * user_data);
51
-
52
- LM_GGML_BACKEND_API lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_type(void);
53
-
54
- // helper to check if the device supports a specific family
55
- // ideally, the user code should be doing these checks
56
- // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
57
- LM_GGML_BACKEND_API bool lm_ggml_backend_metal_supports_family(lm_ggml_backend_t backend, int family);
58
-
59
- // capture all command buffers committed the next time `lm_ggml_backend_graph_compute` is called
60
- LM_GGML_BACKEND_API void lm_ggml_backend_metal_capture_next_compute(lm_ggml_backend_t backend);
61
-
62
- LM_GGML_BACKEND_API lm_ggml_backend_reg_t lm_ggml_backend_metal_reg(void);
63
-
64
- #ifdef __cplusplus
65
- }
66
- #endif
1
+ // Note: this description is outdated
2
+ //
3
+ // An interface allowing to compute lm_ggml_cgraph with Metal
4
+ //
5
+ // This is a fully functional interface that extends ggml with GPU support for Apple devices.
6
+ // A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.)
7
+ //
8
+ // How it works?
9
+ //
10
+ // As long as your program can create and evaluate a lm_ggml_cgraph on the CPU, you can use this
11
+ // interface to evaluate the same graph on the GPU. Instead of using lm_ggml_graph_compute(), you
12
+ // use lm_ggml_metal_graph_compute() (or lm_ggml_vulkan_graph_compute(), etc.)
13
+ //
14
+ // You only need to make sure that all memory buffers that you used during the graph creation
15
+ // are mapped to the device memory with the lm_ggml_metal_add_buffer() function. This mapping is
16
+ // used during the graph evaluation to determine the arguments of the compute kernels.
17
+ //
18
+ // Synchronization between device and host memory (for example for input and output tensors)
19
+ // is done with the lm_ggml_metal_set_tensor() and lm_ggml_metal_get_tensor() functions.
20
+ //
21
+
22
+ #pragma once
23
+
24
+ #include "ggml.h"
25
+ #include "ggml-backend.h"
26
+
27
+ #include <stddef.h>
28
+ #include <stdbool.h>
29
+
30
+ struct lm_ggml_tensor;
31
+ struct lm_ggml_cgraph;
32
+
33
+ #ifdef __cplusplus
34
+ extern "C" {
35
+ #endif
36
+
37
+ //
38
+ // backend API
39
+ // user-code should use only these functions
40
+ //
41
+
42
+ LM_GGML_BACKEND_API lm_ggml_backend_t lm_ggml_backend_metal_init(void);
43
+
44
+ LM_GGML_BACKEND_API bool lm_ggml_backend_is_metal(lm_ggml_backend_t backend);
45
+
46
+ LM_GGML_DEPRECATED(
47
+ LM_GGML_BACKEND_API lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
48
+ "obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713");
49
+
50
+ LM_GGML_BACKEND_API void lm_ggml_backend_metal_set_abort_callback(lm_ggml_backend_t backend, lm_ggml_abort_callback abort_callback, void * user_data);
51
+
52
+ LM_GGML_BACKEND_API lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_type(void);
53
+
54
+ // helper to check if the device supports a specific family
55
+ // ideally, the user code should be doing these checks
56
+ // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
57
+ LM_GGML_BACKEND_API bool lm_ggml_backend_metal_supports_family(lm_ggml_backend_t backend, int family);
58
+
59
+ // capture all command buffers committed the next time `lm_ggml_backend_graph_compute` is called
60
+ LM_GGML_BACKEND_API void lm_ggml_backend_metal_capture_next_compute(lm_ggml_backend_t backend);
61
+
62
+ LM_GGML_BACKEND_API lm_ggml_backend_reg_t lm_ggml_backend_metal_reg(void);
63
+
64
+ #ifdef __cplusplus
65
+ }
66
+ #endif