cui-llama.rn 1.6.1 → 1.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. package/android/src/main/CMakeLists.txt +6 -0
  2. package/android/src/main/java/com/rnllama/LlamaContext.java +51 -14
  3. package/android/src/main/java/com/rnllama/RNLlama.java +158 -6
  4. package/android/src/main/jni.cpp +153 -14
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  13. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
  14. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
  15. package/cpp/chat.cpp +128 -106
  16. package/cpp/chat.h +2 -0
  17. package/cpp/common.cpp +38 -76
  18. package/cpp/common.h +23 -19
  19. package/cpp/ggml-backend.cpp +9 -5
  20. package/cpp/ggml-backend.h +4 -4
  21. package/cpp/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  22. package/cpp/ggml-cpu/ggml-cpu-quants.c +306 -6
  23. package/cpp/ggml-cpu/ggml-cpu.c +5 -13
  24. package/cpp/ggml-cpu/ggml-cpu.cpp +29 -16
  25. package/cpp/ggml-cpu/ops.cpp +107 -13
  26. package/cpp/ggml-cpu/vec.cpp +0 -6
  27. package/cpp/ggml-cpu/vec.h +16 -0
  28. package/cpp/ggml-llama-sim.metallib +0 -0
  29. package/cpp/ggml-llama.metallib +0 -0
  30. package/cpp/ggml-metal-impl.h +36 -11
  31. package/cpp/ggml-metal.m +321 -132
  32. package/cpp/ggml-opt.cpp +373 -190
  33. package/cpp/ggml-opt.h +49 -28
  34. package/cpp/ggml-quants.c +0 -6
  35. package/cpp/ggml.c +93 -38
  36. package/cpp/ggml.h +21 -7
  37. package/cpp/gguf.cpp +33 -33
  38. package/cpp/llama-adapter.cpp +6 -0
  39. package/cpp/llama-arch.cpp +3 -0
  40. package/cpp/llama-batch.cpp +3 -1
  41. package/cpp/llama-chat.cpp +8 -6
  42. package/cpp/llama-chat.h +1 -0
  43. package/cpp/llama-context.cpp +349 -135
  44. package/cpp/llama-context.h +30 -3
  45. package/cpp/llama-cparams.h +1 -0
  46. package/cpp/llama-graph.cpp +150 -234
  47. package/cpp/llama-graph.h +52 -7
  48. package/cpp/llama-hparams.cpp +17 -1
  49. package/cpp/llama-hparams.h +34 -5
  50. package/cpp/llama-kv-cache.cpp +662 -321
  51. package/cpp/llama-kv-cache.h +203 -93
  52. package/cpp/llama-memory.h +3 -2
  53. package/cpp/llama-model-loader.cpp +24 -15
  54. package/cpp/llama-model-saver.cpp +281 -0
  55. package/cpp/llama-model-saver.h +37 -0
  56. package/cpp/llama-model.cpp +536 -132
  57. package/cpp/llama-model.h +7 -1
  58. package/cpp/llama-sampling.cpp +18 -6
  59. package/cpp/llama-vocab.cpp +46 -8
  60. package/cpp/llama-vocab.h +6 -0
  61. package/cpp/llama.cpp +14 -0
  62. package/cpp/llama.h +72 -131
  63. package/cpp/minja/chat-template.hpp +9 -5
  64. package/cpp/minja/minja.hpp +69 -36
  65. package/cpp/rn-llama.cpp +611 -47
  66. package/cpp/rn-llama.h +33 -3
  67. package/cpp/sampling.cpp +57 -50
  68. package/cpp/tools/mtmd/clip-impl.h +462 -0
  69. package/cpp/tools/mtmd/clip.cpp +4024 -0
  70. package/cpp/tools/mtmd/clip.h +101 -0
  71. package/cpp/tools/mtmd/miniaudio.h +93468 -0
  72. package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
  73. package/cpp/tools/mtmd/mtmd-audio.h +62 -0
  74. package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
  75. package/cpp/tools/mtmd/mtmd.cpp +942 -0
  76. package/cpp/tools/mtmd/mtmd.h +362 -0
  77. package/cpp/tools/mtmd/stb_image.h +7988 -0
  78. package/ios/CMakeLists.txt +7 -0
  79. package/ios/RNLlama.mm +77 -3
  80. package/ios/RNLlamaContext.h +5 -1
  81. package/ios/RNLlamaContext.mm +105 -10
  82. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
  83. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +23 -19
  84. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  85. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  86. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  87. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +21 -7
  88. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  89. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +30 -3
  90. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  91. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
  92. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
  93. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  94. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
  95. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  96. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +7 -1
  97. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  98. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +72 -131
  99. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  100. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  101. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
  102. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
  103. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  104. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  105. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  106. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
  107. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  108. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  109. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  110. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
  111. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  112. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
  113. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  114. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
  115. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
  116. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  117. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
  118. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  119. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
  120. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  121. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
  122. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  123. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  124. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
  125. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  126. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  127. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  128. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  129. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
  130. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +23 -19
  131. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  132. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  133. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  134. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +21 -7
  135. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  136. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +30 -3
  137. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  138. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
  139. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
  140. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  141. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
  142. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  143. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +7 -1
  144. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  145. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +72 -131
  146. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  147. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  148. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
  149. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
  150. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  151. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  152. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  153. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
  154. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  155. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  156. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  157. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
  158. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  159. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
  160. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  161. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
  162. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
  163. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  164. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
  165. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  166. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
  167. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  168. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
  169. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  170. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  171. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
  172. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  173. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  174. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  175. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  176. package/jest/mock.js +33 -7
  177. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  178. package/lib/commonjs/index.js +153 -21
  179. package/lib/commonjs/index.js.map +1 -1
  180. package/lib/module/NativeRNLlama.js.map +1 -1
  181. package/lib/module/index.js +152 -20
  182. package/lib/module/index.js.map +1 -1
  183. package/lib/typescript/NativeRNLlama.d.ts +50 -4
  184. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  185. package/lib/typescript/index.d.ts +72 -6
  186. package/lib/typescript/index.d.ts.map +1 -1
  187. package/package.json +1 -1
  188. package/src/NativeRNLlama.ts +67 -4
  189. package/src/index.ts +212 -38
  190. package/lib/commonjs/chat.js +0 -37
  191. package/lib/commonjs/chat.js.map +0 -1
  192. package/lib/module/chat.js +0 -33
  193. package/lib/module/chat.js.map +0 -1
  194. package/lib/typescript/chat.d.ts +0 -10
  195. package/lib/typescript/chat.d.ts.map +0 -1
  196. package/src/chat.ts +0 -44
package/cpp/common.h CHANGED
@@ -6,6 +6,7 @@
6
6
 
7
7
  #include <set>
8
8
  #include <string>
9
+ #include <string_view>
9
10
  #include <vector>
10
11
  #include <sstream>
11
12
 
@@ -77,7 +78,6 @@ enum llama_example {
77
78
  LLAMA_EXAMPLE_COMMON,
78
79
  LLAMA_EXAMPLE_SPECULATIVE,
79
80
  LLAMA_EXAMPLE_MAIN,
80
- LLAMA_EXAMPLE_INFILL,
81
81
  LLAMA_EXAMPLE_EMBEDDING,
82
82
  LLAMA_EXAMPLE_PERPLEXITY,
83
83
  LLAMA_EXAMPLE_RETRIEVAL,
@@ -87,7 +87,7 @@ enum llama_example {
87
87
  LLAMA_EXAMPLE_SERVER,
88
88
  LLAMA_EXAMPLE_CVECTOR_GENERATOR,
89
89
  LLAMA_EXAMPLE_EXPORT_LORA,
90
- LLAMA_EXAMPLE_LLAVA,
90
+ LLAMA_EXAMPLE_MTMD,
91
91
  LLAMA_EXAMPLE_LOOKUP,
92
92
  LLAMA_EXAMPLE_PARALLEL,
93
93
  LLAMA_EXAMPLE_TTS,
@@ -107,6 +107,7 @@ enum common_sampler_type {
107
107
  COMMON_SAMPLER_TYPE_XTC = 8,
108
108
  COMMON_SAMPLER_TYPE_INFILL = 9,
109
109
  COMMON_SAMPLER_TYPE_PENALTIES = 10,
110
+ COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
110
111
  };
111
112
 
112
113
  // dimensionality reduction methods, used by cvector-generator
@@ -172,6 +173,7 @@ struct common_params_sampling {
172
173
  std::vector<enum common_sampler_type> samplers = {
173
174
  COMMON_SAMPLER_TYPE_PENALTIES,
174
175
  COMMON_SAMPLER_TYPE_DRY,
176
+ COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
175
177
  COMMON_SAMPLER_TYPE_TOP_K,
176
178
  COMMON_SAMPLER_TYPE_TYPICAL_P,
177
179
  COMMON_SAMPLER_TYPE_TOP_P,
@@ -336,17 +338,17 @@ struct common_params {
336
338
  bool flash_attn = false; // flash attention
337
339
  bool no_perf = false; // disable performance metrics
338
340
  bool ctx_shift = true; // context shift on inifinite text generation
341
+ bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
339
342
 
340
343
  bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
341
- bool logits_all = false; // return logits for all tokens in the batch
342
344
  bool use_mmap = true; // use mmap for faster loads
343
345
  bool use_mlock = false; // use mlock to keep model in memory
344
346
  bool verbose_prompt = false; // print prompt tokens before generation
345
347
  bool display_prompt = true; // print prompt before generation
346
- bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
347
348
  bool no_kv_offload = false; // disable KV offloading
348
349
  bool warmup = true; // warmup run
349
350
  bool check_tensors = false; // validate tensor data
351
+ bool no_op_offload = false; // globally disable offload host tensor operations to device
350
352
 
351
353
  bool single_turn = false; // single turn chat conversation
352
354
 
@@ -355,7 +357,7 @@ struct common_params {
355
357
 
356
358
  common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
357
359
 
358
- // multimodal models (see tools/llava)
360
+ // multimodal models (see tools/mtmd)
359
361
  struct common_params_model mmproj;
360
362
  bool mmproj_use_gpu = true; // use GPU for multimodal model
361
363
  bool no_mmproj = false; // explicitly disable multimodal model
@@ -381,6 +383,7 @@ struct common_params {
381
383
  bool use_jinja = false; // NOLINT
382
384
  bool enable_chat_template = true;
383
385
  common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
386
+ bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
384
387
 
385
388
  std::vector<std::string> api_keys;
386
389
 
@@ -424,6 +427,7 @@ struct common_params {
424
427
 
425
428
  bool process_output = false; // collect data for the output tensor
426
429
  bool compute_ppl = true; // whether to compute perplexity
430
+ bool parse_special = false; // whether to parse special tokens during imatrix tokenization
427
431
 
428
432
  // cvector-generator params
429
433
  int n_pca_batch = 100;
@@ -439,6 +443,11 @@ struct common_params {
439
443
 
440
444
  // common params
441
445
  std::string out_file; // output filename for all example programs
446
+ // optional callback for model loading progress and cancellation:
447
+ // called with a progress value between 0.0 and 1.0.
448
+ // return false from callback to abort model loading or true to continue
449
+ llama_progress_callback load_progress_callback = NULL;
450
+ void * load_progress_callback_user_data = NULL;
442
451
  };
443
452
 
444
453
  // call once at the start of a program if it uses libcommon
@@ -516,10 +525,9 @@ static bool string_starts_with(const std::string & str,
516
525
  return str.rfind(prefix, 0) == 0;
517
526
  }
518
527
 
519
- static bool string_ends_with(const std::string & str,
520
- const std::string & suffix) { // While we wait for C++20's std::string::ends_with...
521
- return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
522
- }
528
+ // While we wait for C++20's std::string::ends_with...
529
+ bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
530
+ size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
523
531
 
524
532
  bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
525
533
  void string_process_escapes(std::string & input);
@@ -628,16 +636,6 @@ std::string common_detokenize(
628
636
  const std::vector<llama_token> & tokens,
629
637
  bool special = true);
630
638
 
631
- //
632
- // KV cache utils
633
- //
634
-
635
- // Dump the KV cache view with the number of sequences per cell.
636
- void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80);
637
-
638
- // Dump the KV cache view showing individual sequences in each cell (long output).
639
- void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
640
-
641
639
  //
642
640
  // Embedding utils
643
641
  //
@@ -679,3 +677,9 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count";
679
677
  const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
680
678
 
681
679
  }
680
+
681
+ //
682
+ // training utils
683
+ //
684
+
685
+ lm_ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
@@ -56,7 +56,7 @@ size_t lm_ggml_backend_buft_get_max_size(lm_ggml_backend_buffer_type_t buft) {
56
56
  return SIZE_MAX;
57
57
  }
58
58
 
59
- size_t lm_ggml_backend_buft_get_alloc_size(lm_ggml_backend_buffer_type_t buft, struct lm_ggml_tensor * tensor) {
59
+ size_t lm_ggml_backend_buft_get_alloc_size(lm_ggml_backend_buffer_type_t buft, const struct lm_ggml_tensor * tensor) {
60
60
  // get_alloc_size is optional, defaults to lm_ggml_nbytes
61
61
  if (buft->iface.get_alloc_size) {
62
62
  size_t size = buft->iface.get_alloc_size(buft, tensor);
@@ -152,7 +152,7 @@ size_t lm_ggml_backend_buffer_get_max_size(lm_ggml_backend_buffer_t buffer) {
152
152
  return lm_ggml_backend_buft_get_max_size(lm_ggml_backend_buffer_get_type(buffer));
153
153
  }
154
154
 
155
- size_t lm_ggml_backend_buffer_get_alloc_size(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor) {
155
+ size_t lm_ggml_backend_buffer_get_alloc_size(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * tensor) {
156
156
  return lm_ggml_backend_buft_get_alloc_size(lm_ggml_backend_buffer_get_type(buffer), tensor);
157
157
  }
158
158
 
@@ -674,6 +674,8 @@ struct lm_ggml_backend_sched {
674
674
  char * context_buffer;
675
675
  size_t context_buffer_size;
676
676
 
677
+ bool op_offload;
678
+
677
679
  int debug;
678
680
  };
679
681
 
@@ -766,7 +768,7 @@ static int lm_ggml_backend_sched_backend_id_from_cur(lm_ggml_backend_sched_t sch
766
768
  if (tensor->op != LM_GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == LM_GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
767
769
  int src_backend_id = lm_ggml_backend_sched_backend_from_buffer(sched, src, tensor);
768
770
  // check if a backend with higher prio wants to offload the op
769
- if (src_backend_id == sched->n_backends - 1 && lm_ggml_backend_buffer_is_host(src->buffer)) {
771
+ if (sched->op_offload && src_backend_id == sched->n_backends - 1 && lm_ggml_backend_buffer_is_host(src->buffer)) {
770
772
  for (int b = 0; b < src_backend_id; b++) {
771
773
  if (lm_ggml_backend_supports_op(sched->backends[b], tensor) && lm_ggml_backend_offload_op(sched->backends[b], tensor)) {
772
774
  SET_CAUSE(tensor, "1.off");
@@ -1109,7 +1111,7 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str
1109
1111
 
1110
1112
  const int node_backend_id = tensor_backend_id(node);
1111
1113
 
1112
- assert(node_backend_id != -1); // all nodes should be assigned by now
1114
+ assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback
1113
1115
 
1114
1116
  // check if we should start a new split based on the sources of the current node
1115
1117
  bool need_new_split = false;
@@ -1452,7 +1454,8 @@ lm_ggml_backend_sched_t lm_ggml_backend_sched_new(
1452
1454
  lm_ggml_backend_buffer_type_t * bufts,
1453
1455
  int n_backends,
1454
1456
  size_t graph_size,
1455
- bool parallel) {
1457
+ bool parallel,
1458
+ bool op_offload) {
1456
1459
  LM_GGML_ASSERT(n_backends > 0);
1457
1460
  LM_GGML_ASSERT(n_backends <= LM_GGML_SCHED_MAX_BACKENDS);
1458
1461
  LM_GGML_ASSERT(lm_ggml_backend_dev_type(lm_ggml_backend_get_device(backends[n_backends - 1])) == LM_GGML_BACKEND_DEVICE_TYPE_CPU);
@@ -1497,6 +1500,7 @@ lm_ggml_backend_sched_t lm_ggml_backend_sched_new(
1497
1500
  }
1498
1501
 
1499
1502
  sched->galloc = lm_ggml_gallocr_new_n(sched->bufts, n_backends);
1503
+ sched->op_offload = op_offload;
1500
1504
 
1501
1505
  lm_ggml_backend_sched_reset(sched);
1502
1506
 
@@ -38,7 +38,7 @@ extern "C" {
38
38
  LM_GGML_API lm_ggml_backend_buffer_t lm_ggml_backend_buft_alloc_buffer (lm_ggml_backend_buffer_type_t buft, size_t size);
39
39
  LM_GGML_API size_t lm_ggml_backend_buft_get_alignment (lm_ggml_backend_buffer_type_t buft);
40
40
  LM_GGML_API size_t lm_ggml_backend_buft_get_max_size (lm_ggml_backend_buffer_type_t buft);
41
- LM_GGML_API size_t lm_ggml_backend_buft_get_alloc_size(lm_ggml_backend_buffer_type_t buft, struct lm_ggml_tensor * tensor);
41
+ LM_GGML_API size_t lm_ggml_backend_buft_get_alloc_size(lm_ggml_backend_buffer_type_t buft, const struct lm_ggml_tensor * tensor);
42
42
  LM_GGML_API bool lm_ggml_backend_buft_is_host (lm_ggml_backend_buffer_type_t buft);
43
43
  LM_GGML_API lm_ggml_backend_dev_t lm_ggml_backend_buft_get_device (lm_ggml_backend_buffer_type_t buft);
44
44
 
@@ -59,7 +59,7 @@ extern "C" {
59
59
  LM_GGML_API enum lm_ggml_status lm_ggml_backend_buffer_init_tensor (lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor);
60
60
  LM_GGML_API size_t lm_ggml_backend_buffer_get_alignment (lm_ggml_backend_buffer_t buffer);
61
61
  LM_GGML_API size_t lm_ggml_backend_buffer_get_max_size (lm_ggml_backend_buffer_t buffer);
62
- LM_GGML_API size_t lm_ggml_backend_buffer_get_alloc_size(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor);
62
+ LM_GGML_API size_t lm_ggml_backend_buffer_get_alloc_size(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * tensor);
63
63
  LM_GGML_API void lm_ggml_backend_buffer_clear (lm_ggml_backend_buffer_t buffer, uint8_t value);
64
64
  LM_GGML_API bool lm_ggml_backend_buffer_is_host (lm_ggml_backend_buffer_t buffer);
65
65
  LM_GGML_API void lm_ggml_backend_buffer_set_usage (lm_ggml_backend_buffer_t buffer, enum lm_ggml_backend_buffer_usage usage);
@@ -248,7 +248,7 @@ extern "C" {
248
248
  // preferrably to run on the same backend as the buffer
249
249
  lm_ggml_backend_buffer_set_usage(buf_weights, LM_GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
250
250
 
251
- sched = lm_ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, LM_GGML_DEFAULT_GRAPH_SIZE, false);
251
+ sched = lm_ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, LM_GGML_DEFAULT_GRAPH_SIZE, false, true);
252
252
 
253
253
  // initialize buffers from a max size graph (optional)
254
254
  reserve_graph = build_graph(sched, max_batch_size);
@@ -289,7 +289,7 @@ extern "C" {
289
289
  typedef bool (*lm_ggml_backend_sched_eval_callback)(struct lm_ggml_tensor * t, bool ask, void * user_data);
290
290
 
291
291
  // Initialize a backend scheduler, backends with low index are given priority over backends with high index
292
- LM_GGML_API lm_ggml_backend_sched_t lm_ggml_backend_sched_new(lm_ggml_backend_t * backends, lm_ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
292
+ LM_GGML_API lm_ggml_backend_sched_t lm_ggml_backend_sched_new(lm_ggml_backend_t * backends, lm_ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload);
293
293
  LM_GGML_API void lm_ggml_backend_sched_free(lm_ggml_backend_sched_t sched);
294
294
 
295
295
  // Initialize backend buffers from a measure graph
@@ -72,8 +72,6 @@ static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(lm_ggml_half) + QK4_NL * 2, "
72
72
 
73
73
  #if defined(__GNUC__)
74
74
  #pragma GCC diagnostic ignored "-Woverlength-strings"
75
- #elif defined(_MSC_VER)
76
- #pragma warning(disable: 4244 4267) // possible loss of data
77
75
  #endif
78
76
 
79
77
  #define UNUSED LM_GGML_UNUSED
@@ -20,12 +20,6 @@
20
20
  #define GROUP_MAX_EPS_IQ1_M 1e-7f
21
21
  #define GROUP_MAX_EPS_IQ1_S 1e-12f
22
22
 
23
- #if defined(_MSC_VER)
24
- // disable "possible loss of data" to avoid warnings for hundreds of casts
25
- // we should just be careful :)
26
- #pragma warning(disable: 4244 4267)
27
- #endif
28
-
29
23
  #define UNUSED LM_GGML_UNUSED
30
24
 
31
25
  // some compilers don't provide _mm256_set_m128i, e.g. gcc 7
@@ -6596,7 +6590,118 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con
6596
6590
  }
6597
6591
 
6598
6592
  *s = hsum_float_8(acc);
6593
+ #elif defined(__VXE__) || defined(__VXE2__)
6594
+ uint32_t aux[3];
6595
+ uint32_t utmp[4];
6596
+
6597
+ const int32x4_t v_z = vec_splat_s32(0);
6598
+ const uint8x16_t v_3m = vec_splat_u8(0x03);
6599
+
6600
+ const uint8x16_t v_0c = vec_splat_u8(1);
6601
+ const uint8x16_t v_1c = vec_sl(v_0c, 1);
6602
+ const uint8x16_t v_2c = vec_sl(v_0c, 2);
6603
+ const uint8x16_t v_3c = vec_sl(v_0c, 3);
6604
+
6605
+ uint8x16_t q3h[4];
6606
+ uint8x16_t q3b[2];
6607
+ int8x16_t q3bytes[4];
6608
+ int8x16_t q8bytes[4];
6609
+ uint8x16_t qhbits[2];
6610
+
6611
+ float sum = 0;
6612
+
6613
+ for (int i = 0; i < nb; ++i) {
6614
+ const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
6615
+
6616
+ const uint8_t * restrict x0l = x[i].qs;
6617
+ const uint8_t * restrict x0h = x[i].hmask;
6618
+ const int8_t * restrict y0 = y[i].qs;
6619
+
6620
+ qhbits[0] = vec_xl(0 , x0h);
6621
+ qhbits[1] = vec_xl(16, x0h);
6622
+
6623
+ int32_t isum = 0;
6624
+
6625
+ memcpy(aux, x[i].scales, 12);
6626
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
6627
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
6628
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
6629
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
6630
+
6631
+ int8_t * scale = (int8_t *)utmp;
6632
+ for (int j = 0; j < 16; ++j) scale[j] -= 32;
6633
+
6634
+ for (int j = 0; j < QK_K/128; ++j) {
6635
+ int32x4_t isum0, isum1, isum2, isum3;
6636
+
6637
+ q3b[0] = vec_xl(0 , x0l);
6638
+ q3b[1] = vec_xl(16, x0l);
6639
+ x0l += 32;
6640
+
6641
+ q8bytes[0] = vec_xl(0 , y0);
6642
+ q8bytes[1] = vec_xl(16 , y0);
6643
+ q8bytes[2] = vec_xl(32 , y0);
6644
+ q8bytes[3] = vec_xl(48 , y0);
6645
+ q8bytes[4] = vec_xl(64 , y0);
6646
+ q8bytes[5] = vec_xl(80 , y0);
6647
+ q8bytes[6] = vec_xl(96 , y0);
6648
+ q8bytes[7] = vec_xl(112, y0);
6649
+ y0 += 128;
6650
+
6651
+ q3h[0] = vec_sl(vec_andc(v_0c, qhbits[0]), 2);
6652
+ q3h[1] = vec_sl(vec_andc(v_0c, qhbits[1]), 2);
6653
+ q3h[2] = vec_sl(vec_andc(v_1c, qhbits[0]), 1);
6654
+ q3h[3] = vec_sl(vec_andc(v_1c, qhbits[1]), 1);
6655
+
6656
+ q3bytes[0] = vec_sub((int8x16_t)vec_and(q3b[0], v_3m), (int8x16_t)q3h[0]);
6657
+ q3bytes[1] = vec_sub((int8x16_t)vec_and(q3b[1], v_3m), (int8x16_t)q3h[1]);
6658
+ q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 2), v_3m), (int8x16_t)q3h[2]);
6659
+ q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 2), v_3m), (int8x16_t)q3h[3]);
6660
+
6661
+ isum0 = lm_ggml_vec_dot(v_z, q3bytes[0], q8bytes[0]);
6662
+ isum1 = lm_ggml_vec_dot(v_z, q3bytes[1], q8bytes[1]);
6663
+ isum2 = lm_ggml_vec_dot(v_z, q3bytes[2], q8bytes[2]);
6664
+ isum3 = lm_ggml_vec_dot(v_z, q3bytes[3], q8bytes[3]);
6665
+
6666
+ isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];
6667
+ isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];
6668
+ isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];
6669
+ isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];
6670
+
6671
+ scale += 4;
6672
+
6673
+ q3h[0] = vec_andc(v_2c, qhbits[0]);
6674
+ q3h[1] = vec_andc(v_2c, qhbits[1]);
6675
+ q3h[2] = vec_sr(vec_andc(v_3c, qhbits[0]), 1);
6676
+ q3h[3] = vec_sr(vec_andc(v_3c, qhbits[1]), 1);
6677
+
6678
+ q3bytes[0] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 4), v_3m), (int8x16_t)q3h[0]);
6679
+ q3bytes[1] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 4), v_3m), (int8x16_t)q3h[1]);
6680
+ q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 6), v_3m), (int8x16_t)q3h[2]);
6681
+ q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 6), v_3m), (int8x16_t)q3h[3]);
6682
+
6683
+ isum0 = lm_ggml_vec_dot(v_z, q3bytes[0], q8bytes[4]);
6684
+ isum1 = lm_ggml_vec_dot(v_z, q3bytes[1], q8bytes[5]);
6685
+ isum2 = lm_ggml_vec_dot(v_z, q3bytes[2], q8bytes[6]);
6686
+ isum3 = lm_ggml_vec_dot(v_z, q3bytes[3], q8bytes[7]);
6687
+
6688
+ isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];
6689
+ isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];
6690
+ isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];
6691
+ isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];
6692
+
6693
+ scale += 4;
6694
+
6695
+ if (j == 0) {
6696
+ qhbits[0] = vec_sr(qhbits[0], 4);
6697
+ qhbits[1] = vec_sr(qhbits[1], 4);
6698
+ }
6699
+ }
6700
+
6701
+ sum += d * isum;
6702
+ }
6599
6703
 
6704
+ *s = sum;
6600
6705
  #else
6601
6706
  // scalar version
6602
6707
  // This function is written like this so the compiler can manage to vectorize most of it
@@ -8414,7 +8519,11 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con
8414
8519
 
8415
8520
  void lm_ggml_vec_dot_q6_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
8416
8521
  assert(n % QK_K == 0);
8522
+ #ifdef __ARM_FEATURE_MATMUL_INT8
8523
+ assert((nrc == 2) || (nrc == 1));
8524
+ #else
8417
8525
  assert(nrc == 1);
8526
+ #endif
8418
8527
  UNUSED(nrc);
8419
8528
  UNUSED(bx);
8420
8529
  UNUSED(by);
@@ -8425,6 +8534,197 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con
8425
8534
 
8426
8535
  const int nb = n / QK_K;
8427
8536
 
8537
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
8538
+ if (nrc == 2) {
8539
+ const block_q6_K * LM_GGML_RESTRICT x0 = x;
8540
+ const block_q6_K * LM_GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
8541
+ const block_q8_K * LM_GGML_RESTRICT y0 = y;
8542
+ const block_q8_K * LM_GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
8543
+
8544
+ float32x4_t vfsum = vdupq_n_f32(0.0f);
8545
+
8546
+ for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
8547
+ const uint8_t * LM_GGML_RESTRICT ql0 = x0->ql;
8548
+ const uint8_t * LM_GGML_RESTRICT ql1 = x1->ql;
8549
+ const uint8_t * LM_GGML_RESTRICT qh0 = x0->qh;
8550
+ const uint8_t * LM_GGML_RESTRICT qh1 = x1->qh;
8551
+ const int8_t * LM_GGML_RESTRICT qy0 = y0->qs;
8552
+ const int8_t * LM_GGML_RESTRICT qy1 = y1->qs;
8553
+
8554
+ const uint8x16_t mone = vdupq_n_u8(0x30);
8555
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
8556
+
8557
+ int32x4_t visum = vdupq_n_s32(0);
8558
+
8559
+ // process 8 blocks per iteration, totally 16 blocks
8560
+ for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
8561
+ int8x16_t vx0[8], vx1[8];
8562
+
8563
+ // de-quantize vx0[8]
8564
+ {
8565
+ const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
8566
+ const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
8567
+
8568
+ uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
8569
+ uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
8570
+ uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
8571
+ uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
8572
+
8573
+ vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
8574
+ vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
8575
+ vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
8576
+ vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
8577
+
8578
+ q6h_0 = vandq_u8(mone, qh_bits.val[0]);
8579
+ q6h_1 = vandq_u8(mone, qh_bits.val[1]);
8580
+ q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
8581
+ q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
8582
+
8583
+ vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
8584
+ vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
8585
+ vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
8586
+ vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
8587
+ }
8588
+
8589
+ // de-quantize vx1[8]
8590
+ {
8591
+ const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
8592
+ const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
8593
+
8594
+ uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
8595
+ uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
8596
+ uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
8597
+ uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
8598
+
8599
+ vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
8600
+ vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
8601
+ vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
8602
+ vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
8603
+
8604
+ q6h_0 = vandq_u8(mone, qh_bits.val[0]);
8605
+ q6h_1 = vandq_u8(mone, qh_bits.val[1]);
8606
+ q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
8607
+ q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
8608
+
8609
+ vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
8610
+ vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
8611
+ vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
8612
+ vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
8613
+ }
8614
+
8615
+ // process 16 elements (one block with same scale) per iteration
8616
+ // - vx = concat(ql, qh) - 32
8617
+ // - r1,r2,r3,r4 = smmla(vx, vy)
8618
+ for (int k = 0; k < 8; ++k) {
8619
+ const int blk = j * 8 + k;
8620
+
8621
+ const int8x16_t vy0 = vld1q_s8(qy0);
8622
+ const int8x16_t vy1 = vld1q_s8(qy1);
8623
+ qy0 += 16;
8624
+ qy1 += 16;
8625
+
8626
+ const int32x4_t block_scale = {
8627
+ x0->scales[blk],
8628
+ x0->scales[blk],
8629
+ x1->scales[blk],
8630
+ x1->scales[blk],
8631
+ };
8632
+
8633
+ // calculate four results at once with outer product
8634
+ const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
8635
+ const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
8636
+ const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
8637
+ const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
8638
+ int32x4_t vr = vdupq_n_s32(0);
8639
+ vr = vmmlaq_s32(vr, vx_l, vy_l);
8640
+ vr = vmmlaq_s32(vr, vx_h, vy_h);
8641
+
8642
+ // apply block scale, will NOT overflow
8643
+ // block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
8644
+ visum = vmlaq_s32(visum, vr, block_scale);
8645
+ }
8646
+ }
8647
+
8648
+ // adjust bias, apply superblock scale
8649
+ {
8650
+ int32_t bias[4];
8651
+ #ifdef __ARM_FEATURE_SVE
8652
+ const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
8653
+ const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
8654
+ const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
8655
+ const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
8656
+ const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
8657
+ const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
8658
+ const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
8659
+ const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
8660
+ const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
8661
+ const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
8662
+ const svint64_t zero = svdup_n_s64(0);
8663
+ bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
8664
+ svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
8665
+ bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
8666
+ svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
8667
+ bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
8668
+ svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
8669
+ bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
8670
+ svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
8671
+ #else
8672
+ // NEON doesn't support int16 dot product, fallback to separated mul and add
8673
+ const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
8674
+ const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
8675
+
8676
+ int8x16_t scales_s8 = vld1q_s8(x0->scales);
8677
+ const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
8678
+ scales_s8 = vld1q_s8(x1->scales);
8679
+ const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
8680
+
8681
+ int32x4_t prod;
8682
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
8683
+ vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
8684
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
8685
+ vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
8686
+ bias[0] = vaddvq_s32(prod);
8687
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
8688
+ vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
8689
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
8690
+ vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
8691
+ bias[1] = vaddvq_s32(prod);
8692
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
8693
+ vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
8694
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
8695
+ vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
8696
+ bias[2] = vaddvq_s32(prod);
8697
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
8698
+ vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
8699
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
8700
+ vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
8701
+ bias[3] = vaddvq_s32(prod);
8702
+
8703
+ #endif
8704
+ const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
8705
+
8706
+ const float32x4_t superblock_scale = {
8707
+ LM_GGML_FP16_TO_FP32(x0->d) * y0->d,
8708
+ LM_GGML_FP16_TO_FP32(x0->d) * y1->d,
8709
+ LM_GGML_FP16_TO_FP32(x1->d) * y0->d,
8710
+ LM_GGML_FP16_TO_FP32(x1->d) * y1->d,
8711
+ };
8712
+
8713
+ visum = vsubq_s32(visum, vibias);
8714
+ vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
8715
+ }
8716
+ }
8717
+
8718
+ // vfsum = ABCD -> ACBD
8719
+ // AC -> s, BD -> (s+bs)
8720
+ vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
8721
+ vst1_f32(s, vget_low_f32 (vfsum));
8722
+ vst1_f32(s + bs, vget_high_f32(vfsum));
8723
+
8724
+ return;
8725
+ }
8726
+ #endif
8727
+
8428
8728
  #ifdef __ARM_FEATURE_SVE
8429
8729
  const int vector_length = lm_ggml_cpu_get_sve_cnt()*8;
8430
8730
  float sum = 0;
@@ -50,19 +50,6 @@
50
50
  #include "llamafile/sgemm.h"
51
51
  #endif
52
52
 
53
- #if defined(_MSC_VER)
54
- // disable "possible loss of data" to avoid hundreds of casts
55
- // we should just be careful :)
56
- #pragma warning(disable: 4244 4267)
57
-
58
- // disable POSIX deprecation warnings
59
- // these functions are never going away, anyway
60
- #pragma warning(disable: 4996)
61
-
62
- // unreachable code because of multiple instances of code after LM_GGML_ABORT
63
- #pragma warning(disable: 4702)
64
- #endif
65
-
66
53
  // Note: once we move threading into a separate C++ file
67
54
  // will use std::hardware_destructive_interference_size instead of hardcoding it here
68
55
  // and we'll use C++ attribute syntax.
@@ -295,7 +282,11 @@ static const struct lm_ggml_type_traits_cpu type_traits_cpu[LM_GGML_TYPE_COUNT]
295
282
  .from_float = quantize_row_q6_K,
296
283
  .vec_dot = lm_ggml_vec_dot_q6_K_q8_K,
297
284
  .vec_dot_type = LM_GGML_TYPE_Q8_K,
285
+ #if defined (__ARM_FEATURE_MATMUL_INT8)
286
+ .nrows = 2,
287
+ #else
298
288
  .nrows = 1,
289
+ #endif
299
290
  },
300
291
  [LM_GGML_TYPE_IQ2_XXS] = {
301
292
  .from_float = NULL,
@@ -2211,6 +2202,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
2211
2202
  } break;
2212
2203
 
2213
2204
  case LM_GGML_UNARY_OP_GELU:
2205
+ case LM_GGML_UNARY_OP_GELU_ERF:
2214
2206
  case LM_GGML_UNARY_OP_GELU_QUICK:
2215
2207
  case LM_GGML_UNARY_OP_SILU:
2216
2208
  {