cui-llama.rn 1.6.1 → 1.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. package/android/src/main/CMakeLists.txt +6 -0
  2. package/android/src/main/java/com/rnllama/LlamaContext.java +51 -14
  3. package/android/src/main/java/com/rnllama/RNLlama.java +158 -6
  4. package/android/src/main/jni.cpp +153 -14
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  13. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
  14. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
  15. package/cpp/chat.cpp +128 -106
  16. package/cpp/chat.h +2 -0
  17. package/cpp/common.cpp +38 -76
  18. package/cpp/common.h +23 -19
  19. package/cpp/ggml-backend.cpp +9 -5
  20. package/cpp/ggml-backend.h +4 -4
  21. package/cpp/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  22. package/cpp/ggml-cpu/ggml-cpu-quants.c +306 -6
  23. package/cpp/ggml-cpu/ggml-cpu.c +5 -13
  24. package/cpp/ggml-cpu/ggml-cpu.cpp +29 -16
  25. package/cpp/ggml-cpu/ops.cpp +107 -13
  26. package/cpp/ggml-cpu/vec.cpp +0 -6
  27. package/cpp/ggml-cpu/vec.h +16 -0
  28. package/cpp/ggml-llama-sim.metallib +0 -0
  29. package/cpp/ggml-llama.metallib +0 -0
  30. package/cpp/ggml-metal-impl.h +36 -11
  31. package/cpp/ggml-metal.m +321 -132
  32. package/cpp/ggml-opt.cpp +373 -190
  33. package/cpp/ggml-opt.h +49 -28
  34. package/cpp/ggml-quants.c +0 -6
  35. package/cpp/ggml.c +93 -38
  36. package/cpp/ggml.h +21 -7
  37. package/cpp/gguf.cpp +33 -33
  38. package/cpp/llama-adapter.cpp +6 -0
  39. package/cpp/llama-arch.cpp +3 -0
  40. package/cpp/llama-batch.cpp +3 -1
  41. package/cpp/llama-chat.cpp +8 -6
  42. package/cpp/llama-chat.h +1 -0
  43. package/cpp/llama-context.cpp +349 -135
  44. package/cpp/llama-context.h +30 -3
  45. package/cpp/llama-cparams.h +1 -0
  46. package/cpp/llama-graph.cpp +150 -234
  47. package/cpp/llama-graph.h +52 -7
  48. package/cpp/llama-hparams.cpp +17 -1
  49. package/cpp/llama-hparams.h +34 -5
  50. package/cpp/llama-kv-cache.cpp +662 -321
  51. package/cpp/llama-kv-cache.h +203 -93
  52. package/cpp/llama-memory.h +3 -2
  53. package/cpp/llama-model-loader.cpp +24 -15
  54. package/cpp/llama-model-saver.cpp +281 -0
  55. package/cpp/llama-model-saver.h +37 -0
  56. package/cpp/llama-model.cpp +536 -132
  57. package/cpp/llama-model.h +7 -1
  58. package/cpp/llama-sampling.cpp +18 -6
  59. package/cpp/llama-vocab.cpp +46 -8
  60. package/cpp/llama-vocab.h +6 -0
  61. package/cpp/llama.cpp +14 -0
  62. package/cpp/llama.h +72 -131
  63. package/cpp/minja/chat-template.hpp +9 -5
  64. package/cpp/minja/minja.hpp +69 -36
  65. package/cpp/rn-llama.cpp +611 -47
  66. package/cpp/rn-llama.h +33 -3
  67. package/cpp/sampling.cpp +57 -50
  68. package/cpp/tools/mtmd/clip-impl.h +462 -0
  69. package/cpp/tools/mtmd/clip.cpp +4024 -0
  70. package/cpp/tools/mtmd/clip.h +101 -0
  71. package/cpp/tools/mtmd/miniaudio.h +93468 -0
  72. package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
  73. package/cpp/tools/mtmd/mtmd-audio.h +62 -0
  74. package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
  75. package/cpp/tools/mtmd/mtmd.cpp +942 -0
  76. package/cpp/tools/mtmd/mtmd.h +362 -0
  77. package/cpp/tools/mtmd/stb_image.h +7988 -0
  78. package/ios/CMakeLists.txt +7 -0
  79. package/ios/RNLlama.mm +77 -3
  80. package/ios/RNLlamaContext.h +5 -1
  81. package/ios/RNLlamaContext.mm +105 -10
  82. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
  83. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +23 -19
  84. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  85. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  86. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  87. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +21 -7
  88. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  89. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +30 -3
  90. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  91. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
  92. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
  93. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  94. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
  95. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  96. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +7 -1
  97. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  98. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +72 -131
  99. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  100. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  101. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
  102. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
  103. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  104. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  105. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  106. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
  107. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  108. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  109. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  110. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
  111. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  112. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
  113. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  114. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
  115. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
  116. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  117. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
  118. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  119. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
  120. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  121. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
  122. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  123. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  124. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
  125. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  126. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  127. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  128. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  129. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
  130. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +23 -19
  131. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  132. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  133. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  134. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +21 -7
  135. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  136. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +30 -3
  137. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  138. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
  139. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
  140. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  141. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
  142. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  143. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +7 -1
  144. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  145. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +72 -131
  146. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  147. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  148. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
  149. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
  150. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  151. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  152. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  153. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
  154. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  155. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  156. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  157. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
  158. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  159. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
  160. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  161. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
  162. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
  163. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  164. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
  165. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  166. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
  167. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  168. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
  169. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  170. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  171. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
  172. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  173. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  174. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  175. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  176. package/jest/mock.js +33 -7
  177. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  178. package/lib/commonjs/index.js +153 -21
  179. package/lib/commonjs/index.js.map +1 -1
  180. package/lib/module/NativeRNLlama.js.map +1 -1
  181. package/lib/module/index.js +152 -20
  182. package/lib/module/index.js.map +1 -1
  183. package/lib/typescript/NativeRNLlama.d.ts +50 -4
  184. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  185. package/lib/typescript/index.d.ts +72 -6
  186. package/lib/typescript/index.d.ts.map +1 -1
  187. package/package.json +1 -1
  188. package/src/NativeRNLlama.ts +67 -4
  189. package/src/index.ts +212 -38
  190. package/lib/commonjs/chat.js +0 -37
  191. package/lib/commonjs/chat.js.map +0 -1
  192. package/lib/module/chat.js +0 -33
  193. package/lib/module/chat.js.map +0 -1
  194. package/lib/typescript/chat.d.ts +0 -10
  195. package/lib/typescript/chat.d.ts.map +0 -1
  196. package/src/chat.ts +0 -44
@@ -11,24 +11,26 @@
11
11
  #include <vector>
12
12
 
13
13
  #ifdef LM_GGML_USE_CPU_HBM
14
- #include "ggml-cpu-hbm.h"
14
+ # include "ggml-cpu-hbm.h"
15
15
  #endif
16
16
 
17
17
  #ifdef LM_GGML_USE_CPU_KLEIDIAI
18
- #include "kleidiai/kleidiai.h"
19
- #endif
20
-
21
- #if defined(__APPLE__)
22
- #include <sys/types.h>
23
- #include <sys/sysctl.h>
18
+ # include "kleidiai/kleidiai.h"
24
19
  #endif
25
20
 
26
21
  #if defined(_WIN32)
27
- #define WIN32_LEAN_AND_MEAN
28
- #ifndef NOMINMAX
29
- #define NOMINMAX
22
+ # define WIN32_LEAN_AND_MEAN
23
+ # ifndef NOMINMAX
24
+ # define NOMINMAX
25
+ # endif
26
+ # include <windows.h>
27
+ #else
28
+ # include <unistd.h>
30
29
  #endif
31
- #include <windows.h>
30
+
31
+ #if defined(__APPLE__)
32
+ # include <sys/sysctl.h>
33
+ # include <sys/types.h>
32
34
  #endif
33
35
 
34
36
  // ggml-backend interface
@@ -70,8 +72,10 @@ static lm_ggml_backend_buffer_type_t * lm_ggml_backend_cpu_device_get_extra_buff
70
72
  }
71
73
 
72
74
  static bool lm_ggml_backend_cpu_is_extra_buffer_type(lm_ggml_backend_buffer_type_t buft) {
73
- for (auto extra : lm_ggml_backend_cpu_get_extra_buffers_type()) {
74
- if (extra && extra == buft) return true;
75
+ for (auto * extra : lm_ggml_backend_cpu_get_extra_buffers_type()) {
76
+ if (extra && extra == buft) {
77
+ return true;
78
+ }
75
79
  }
76
80
  return false;
77
81
  }
@@ -330,9 +334,18 @@ static const char * lm_ggml_backend_cpu_device_get_description(lm_ggml_backend_d
330
334
  }
331
335
 
332
336
  static void lm_ggml_backend_cpu_device_get_memory(lm_ggml_backend_dev_t dev, size_t * free, size_t * total) {
333
- // TODO
334
- *free = 0;
335
- *total = 0;
337
+ #ifdef _WIN32
338
+ MEMORYSTATUSEX status;
339
+ status.dwLength = sizeof(status);
340
+ GlobalMemoryStatusEx(&status);
341
+ *total = status.ullTotalPhys;
342
+ *free = status.ullAvailPhys;
343
+ #else
344
+ long pages = sysconf(_SC_PHYS_PAGES);
345
+ long page_size = sysconf(_SC_PAGE_SIZE);
346
+ *total = pages * page_size;
347
+ *free = *total;
348
+ #endif
336
349
 
337
350
  LM_GGML_UNUSED(dev);
338
351
  }
@@ -8,19 +8,6 @@
8
8
 
9
9
  #include <float.h>
10
10
 
11
- #if defined(_MSC_VER)
12
- // disable "possible loss of data" to avoid hundreds of casts
13
- // we should just be careful :)
14
- #pragma warning(disable: 4244 4267)
15
-
16
- // disable POSIX deprecation warnings
17
- // these functions are never going away, anyway
18
- #pragma warning(disable: 4996)
19
-
20
- // unreachable code because of multiple instances of code after LM_GGML_ABORT
21
- #pragma warning(disable: 4702)
22
- #endif
23
-
24
11
  // lm_ggml_compute_forward_dup
25
12
 
26
13
  static void lm_ggml_compute_forward_dup_same_cont(
@@ -2704,6 +2691,109 @@ static void lm_ggml_compute_forward_gelu(
2704
2691
  }
2705
2692
  }
2706
2693
 
2694
+ // lm_ggml_compute_forward_gelu_erf
2695
+
2696
+ static void lm_ggml_compute_forward_gelu_erf_f32(
2697
+ const lm_ggml_compute_params * params,
2698
+ lm_ggml_tensor * dst) {
2699
+
2700
+ const lm_ggml_tensor * src0 = dst->src[0];
2701
+
2702
+ assert(lm_ggml_is_contiguous_1(src0));
2703
+ assert(lm_ggml_is_contiguous_1(dst));
2704
+ assert(lm_ggml_are_same_shape(src0, dst));
2705
+
2706
+ const int ith = params->ith;
2707
+ const int nth = params->nth;
2708
+
2709
+ const int nc = src0->ne[0];
2710
+ const int nr = lm_ggml_nrows(src0);
2711
+
2712
+ // rows per thread
2713
+ const int dr = (nr + nth - 1)/nth;
2714
+
2715
+ // row range for this thread
2716
+ const int ir0 = dr*ith;
2717
+ const int ir1 = MIN(ir0 + dr, nr);
2718
+
2719
+ for (int i1 = ir0; i1 < ir1; i1++) {
2720
+ lm_ggml_vec_gelu_erf_f32(nc,
2721
+ (float *) ((char *) dst->data + i1*( dst->nb[1])),
2722
+ (float *) ((char *) src0->data + i1*(src0->nb[1])));
2723
+
2724
+ #ifndef NDEBUG
2725
+ for (int k = 0; k < nc; k++) {
2726
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2727
+ LM_GGML_UNUSED(x);
2728
+ assert(!isnan(x));
2729
+ assert(!isinf(x));
2730
+ }
2731
+ #endif
2732
+ }
2733
+ }
2734
+
2735
+ static void lm_ggml_compute_forward_gelu_erf_f16(
2736
+ const lm_ggml_compute_params * params,
2737
+ lm_ggml_tensor * dst) {
2738
+
2739
+ const lm_ggml_tensor * src0 = dst->src[0];
2740
+
2741
+ assert(lm_ggml_is_contiguous_1(src0));
2742
+ assert(lm_ggml_is_contiguous_1(dst));
2743
+ assert(lm_ggml_are_same_shape(src0, dst));
2744
+
2745
+ const int ith = params->ith;
2746
+ const int nth = params->nth;
2747
+
2748
+ const int nc = src0->ne[0];
2749
+ const int nr = lm_ggml_nrows(src0);
2750
+
2751
+ // rows per thread
2752
+ const int dr = (nr + nth - 1)/nth;
2753
+
2754
+ // row range for this thread
2755
+ const int ir0 = dr*ith;
2756
+ const int ir1 = MIN(ir0 + dr, nr);
2757
+
2758
+ for (int i1 = ir0; i1 < ir1; i1++) {
2759
+ lm_ggml_vec_gelu_erf_f16(nc,
2760
+ (lm_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2761
+ (lm_ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2762
+
2763
+ #ifndef NDEBUG
2764
+ for (int k = 0; k < nc; k++) {
2765
+ const lm_ggml_fp16_t x = ((lm_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2766
+ const float v = LM_GGML_FP16_TO_FP32(x);
2767
+ LM_GGML_UNUSED(v);
2768
+ assert(!isnan(v));
2769
+ assert(!isinf(v));
2770
+ }
2771
+ #endif
2772
+ }
2773
+ }
2774
+
2775
+ static void lm_ggml_compute_forward_gelu_erf(
2776
+ const lm_ggml_compute_params * params,
2777
+ lm_ggml_tensor * dst) {
2778
+
2779
+ const lm_ggml_tensor * src0 = dst->src[0];
2780
+
2781
+ switch (src0->type) {
2782
+ case LM_GGML_TYPE_F32:
2783
+ {
2784
+ lm_ggml_compute_forward_gelu_erf_f32(params, dst);
2785
+ } break;
2786
+ case LM_GGML_TYPE_F16:
2787
+ {
2788
+ lm_ggml_compute_forward_gelu_erf_f16(params, dst);
2789
+ } break;
2790
+ default:
2791
+ {
2792
+ LM_GGML_ABORT("fatal error");
2793
+ }
2794
+ }
2795
+ }
2796
+
2707
2797
  // lm_ggml_compute_forward_gelu_quick
2708
2798
 
2709
2799
  static void lm_ggml_compute_forward_gelu_quick_f32(
@@ -7762,6 +7852,10 @@ void lm_ggml_compute_forward_unary(
7762
7852
  {
7763
7853
  lm_ggml_compute_forward_gelu(params, dst);
7764
7854
  } break;
7855
+ case LM_GGML_UNARY_OP_GELU_ERF:
7856
+ {
7857
+ lm_ggml_compute_forward_gelu_erf(params, dst);
7858
+ } break;
7765
7859
  case LM_GGML_UNARY_OP_GELU_QUICK:
7766
7860
  {
7767
7861
  lm_ggml_compute_forward_gelu_quick(params, dst);
@@ -2,12 +2,6 @@
2
2
 
3
3
  #include <cassert>
4
4
 
5
- #if defined(_MSC_VER)
6
- // disable "possible loss of data" to avoid hundreds of casts
7
- // we should just be careful :)
8
- #pragma warning(disable: 4244 4267)
9
- #endif
10
-
11
5
  // precomputed gelu table for f16 (128 KB)
12
6
  lm_ggml_fp16_t lm_ggml_table_gelu_f16[1 << 16];
13
7
 
@@ -428,6 +428,7 @@ inline static void lm_ggml_vec_exp_f16 (const int n, lm_ggml_fp16_t * y, const l
428
428
  static const float GELU_COEF_A = 0.044715f;
429
429
  static const float GELU_QUICK_COEF = -1.702f;
430
430
  static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
431
+ static const float SQRT_2_INV = 0.70710678118654752440084436210484f;
431
432
 
432
433
  inline static float lm_ggml_gelu_f32(float x) {
433
434
  return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
@@ -440,6 +441,14 @@ inline static void lm_ggml_vec_gelu_f16(const int n, lm_ggml_fp16_t * y, const l
440
441
  }
441
442
  }
442
443
 
444
+ inline static void lm_ggml_vec_gelu_erf_f16(const int n, lm_ggml_fp16_t * y, const lm_ggml_fp16_t * x) {
445
+ for (int i = 0; i < n; ++i) {
446
+ float xi = LM_GGML_FP16_TO_FP32(x[i]);
447
+ float res = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
448
+ y[i] = LM_GGML_FP32_TO_FP16(res);
449
+ }
450
+ }
451
+
443
452
  #ifdef LM_GGML_GELU_FP16
444
453
  inline static void lm_ggml_vec_gelu_f32(const int n, float * y, const float * x) {
445
454
  uint16_t t;
@@ -463,6 +472,13 @@ inline static void lm_ggml_vec_gelu_f32(const int n, float * y, const float * x)
463
472
  }
464
473
  #endif
465
474
 
475
+ inline static void lm_ggml_vec_gelu_erf_f32(const int n, float * y, const float * x) {
476
+ for (int i = 0; i < n; ++i) {
477
+ float xi = x[i];
478
+ y[i] = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
479
+ }
480
+ }
481
+
466
482
  inline static float lm_ggml_gelu_quick_f32(float x) {
467
483
  return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
468
484
  }
Binary file
Binary file
@@ -207,6 +207,10 @@ typedef struct {
207
207
  float attn_factor;
208
208
  float beta_fast;
209
209
  float beta_slow;
210
+ int32_t sect_0;
211
+ int32_t sect_1;
212
+ int32_t sect_2;
213
+ int32_t sect_3;
210
214
  } lm_ggml_metal_kargs_rope;
211
215
 
212
216
  typedef struct {
@@ -299,21 +303,42 @@ typedef struct {
299
303
  } lm_ggml_metal_kargs_mul_mv_ext;
300
304
 
301
305
  typedef struct {
302
- int32_t nei0;
303
- int32_t nei1;
304
- uint64_t nbi1;
306
+ int32_t ne10;
307
+ int32_t ne11; // n_expert_used (bcast)
308
+ uint64_t nb11;
309
+ uint64_t nb12;
310
+ int32_t neh11; // n_tokens
311
+ uint64_t nbh11;
312
+ int32_t ne20; // n_expert_used
313
+ uint64_t nb21;
314
+ } lm_ggml_metal_kargs_mul_mm_id_map0;
315
+
316
+ typedef struct {
317
+ int32_t ne20; // n_expert_used
318
+ int32_t neh0;
319
+ int32_t neh1;
320
+ uint64_t nbh1;
321
+ uint64_t nbh2;
322
+ int32_t ne0;
323
+ uint64_t nb1;
324
+ uint64_t nb2;
325
+ } lm_ggml_metal_kargs_mul_mm_id_map1;
326
+
327
+ typedef struct {
305
328
  int32_t ne00;
306
329
  int32_t ne02;
307
330
  uint64_t nb01;
308
331
  uint64_t nb02;
309
- int32_t ne11;
310
- int32_t ne12;
311
- int32_t ne13;
312
- uint64_t nb10;
313
- uint64_t nb11;
314
- uint64_t nb12;
315
- int32_t ne0;
316
- int32_t ne1;
332
+ uint64_t nb03;
333
+ int32_t neh12;
334
+ uint64_t nbh10;
335
+ uint64_t nbh11;
336
+ uint64_t nbh12;
337
+ uint64_t nbh13;
338
+ int32_t neh0;
339
+ int32_t neh1;
340
+ int16_t r2;
341
+ int16_t r3;
317
342
  } lm_ggml_metal_kargs_mul_mm_id;
318
343
 
319
344
  typedef struct {