cui-llama.rn 1.3.0 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. package/android/src/main/CMakeLists.txt +9 -6
  2. package/android/src/main/java/com/rnllama/LlamaContext.java +4 -4
  3. package/android/src/main/jni.cpp +15 -15
  4. package/cpp/common.cpp +1962 -1682
  5. package/cpp/common.h +645 -600
  6. package/cpp/ggml-alloc.c +1038 -1040
  7. package/cpp/ggml-alloc.h +76 -76
  8. package/cpp/ggml-backend-impl.h +256 -216
  9. package/cpp/ggml-backend-reg.cpp +552 -195
  10. package/cpp/ggml-backend.cpp +1999 -1997
  11. package/cpp/ggml-backend.h +352 -328
  12. package/cpp/ggml-common.h +1853 -1853
  13. package/cpp/ggml-cpp.h +38 -38
  14. package/cpp/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +4262 -3560
  15. package/cpp/ggml-cpu-aarch64.h +8 -30
  16. package/cpp/ggml-cpu-impl.h +386 -371
  17. package/cpp/ggml-cpu-quants.c +10835 -10822
  18. package/cpp/ggml-cpu-quants.h +63 -63
  19. package/cpp/ggml-cpu-traits.cpp +36 -0
  20. package/cpp/ggml-cpu-traits.h +38 -0
  21. package/cpp/ggml-cpu.c +14122 -13975
  22. package/cpp/ggml-cpu.cpp +618 -663
  23. package/cpp/ggml-cpu.h +135 -177
  24. package/cpp/ggml-impl.h +556 -550
  25. package/cpp/ggml-metal.h +66 -66
  26. package/cpp/ggml-metal.m +4884 -4294
  27. package/cpp/ggml-quants.c +5238 -5247
  28. package/cpp/ggml-quants.h +100 -100
  29. package/cpp/ggml-threading.cpp +12 -12
  30. package/cpp/ggml-threading.h +14 -12
  31. package/cpp/ggml.c +7707 -8180
  32. package/cpp/ggml.h +2286 -2411
  33. package/cpp/json-schema-to-grammar.cpp +1045 -0
  34. package/cpp/json-schema-to-grammar.h +8 -0
  35. package/cpp/json.hpp +24766 -0
  36. package/cpp/llama-grammar.cpp +1138 -1138
  37. package/cpp/llama-grammar.h +144 -144
  38. package/cpp/llama-impl.h +181 -181
  39. package/cpp/llama-sampling.cpp +2293 -2348
  40. package/cpp/llama-sampling.h +48 -48
  41. package/cpp/llama-vocab.cpp +1985 -1984
  42. package/cpp/llama-vocab.h +170 -170
  43. package/cpp/llama.cpp +22836 -22132
  44. package/cpp/llama.h +1263 -1253
  45. package/cpp/log.cpp +401 -401
  46. package/cpp/log.h +121 -121
  47. package/cpp/rn-llama.hpp +6 -6
  48. package/cpp/sampling.cpp +500 -466
  49. package/cpp/sampling.h +22 -1
  50. package/cpp/sgemm.cpp +1884 -1884
  51. package/cpp/speculative.cpp +274 -0
  52. package/cpp/speculative.h +28 -0
  53. package/cpp/unicode.cpp +62 -51
  54. package/cpp/unicode.h +9 -10
  55. package/ios/RNLlamaContext.mm +13 -0
  56. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  57. package/lib/commonjs/grammar.js +4 -2
  58. package/lib/commonjs/grammar.js.map +1 -1
  59. package/lib/commonjs/index.js +38 -1
  60. package/lib/commonjs/index.js.map +1 -1
  61. package/lib/module/NativeRNLlama.js.map +1 -1
  62. package/lib/module/grammar.js +2 -1
  63. package/lib/module/grammar.js.map +1 -1
  64. package/lib/module/index.js +36 -0
  65. package/lib/module/index.js.map +1 -1
  66. package/lib/typescript/NativeRNLlama.d.ts +95 -6
  67. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  68. package/lib/typescript/grammar.d.ts +5 -6
  69. package/lib/typescript/grammar.d.ts.map +1 -1
  70. package/lib/typescript/index.d.ts +40 -4
  71. package/lib/typescript/index.d.ts.map +1 -1
  72. package/package.json +2 -1
  73. package/src/NativeRNLlama.ts +99 -12
  74. package/src/grammar.ts +10 -8
  75. package/src/index.ts +68 -3
  76. package/cpp/ggml-aarch64.c +0 -129
  77. package/cpp/ggml-aarch64.h +0 -19
package/android/src/main/CMakeLists.txt
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.10)
 
 project(llama.rn)
 
-set(CMAKE_CXX_STANDARD 11)
+set(CMAKE_CXX_STANDARD 17)
 set(RNLLAMA_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../cpp)
 
 include_directories(${RNLLAMA_LIB_DIR})
@@ -13,21 +13,25 @@ set(
     ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
     ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
     ${RNLLAMA_LIB_DIR}/log.cpp
+
+    #${RNLLAMA_LIB_DIR}/amx/amx.cpp
+    #${RNLLAMA_LIB_DIR}/amx/mmq.cpp
 
-    ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
     ${RNLLAMA_LIB_DIR}/llama-grammar.cpp
     ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
     ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
     ${RNLLAMA_LIB_DIR}/log.cpp
+    ${RNLLAMA_LIB_DIR}/json.hpp
+    ${RNLLAMA_LIB_DIR}/json-schema-to-grammar.cpp
 
-    ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
     ${RNLLAMA_LIB_DIR}/ggml-alloc.c
     ${RNLLAMA_LIB_DIR}/ggml-backend.cpp
     ${RNLLAMA_LIB_DIR}/ggml-backend-reg.cpp
     ${RNLLAMA_LIB_DIR}/ggml.c
     ${RNLLAMA_LIB_DIR}/ggml-cpu.c
     ${RNLLAMA_LIB_DIR}/ggml-cpu.cpp
-    ${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.c
+    ${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.cpp
+    ${RNLLAMA_LIB_DIR}/ggml-cpu-traits.cpp
     ${RNLLAMA_LIB_DIR}/ggml-cpu-quants.c
     ${RNLLAMA_LIB_DIR}/ggml-threading.cpp
     ${RNLLAMA_LIB_DIR}/ggml-quants.c
@@ -37,7 +41,6 @@ set(
     ${RNLLAMA_LIB_DIR}/unicode.cpp
     ${RNLLAMA_LIB_DIR}/llama.cpp
     ${RNLLAMA_LIB_DIR}/sgemm.cpp
-    ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
     ${RNLLAMA_LIB_DIR}/rn-llama.hpp
     ${CMAKE_SOURCE_DIR}/jni.cpp
 )
@@ -53,7 +56,7 @@ function(build_library target_name cpu_flags)
 
     target_link_libraries(${target_name} ${LOG_LIB} android)
 
-    target_compile_options(${target_name} PRIVATE -pthread ${cpu_flags})
+    target_compile_options(${target_name} PRIVATE -pthread ${cpu_flags} -DLM_GGML_USE_CPU -DLM_GGML_USE_CPU_AARCH64)
 
     if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
         target_compile_options(${target_name} PRIVATE -DRNLLAMA_ANDROID_ENABLE_LOGGING)
package/android/src/main/java/com/rnllama/LlamaContext.java
@@ -115,9 +115,9 @@ public class LlamaContext {
       // boolean flash_attn,
       params.hasKey("flash_attn") ? params.getBoolean("flash_attn") : false,
       // String cache_type_k,
-      params.hasKey("cache_type_k") ? params.getString("cache_type_k") : "f16",
+      params.hasKey("cache_type_k") ? params.getInt("cache_type_k") : 1,
       // String cache_type_v,
-      params.hasKey("cache_type_v") ? params.getString("cache_type_v") : "f16",
+      params.hasKey("cache_type_v") ? params.getInt("cache_type_v") : 1,
       // boolean use_mlock,
       params.hasKey("use_mlock") ? params.getBoolean("use_mlock") : true,
       // boolean use_mmap,
@@ -463,8 +463,8 @@ public class LlamaContext {
       int n_threads,
       int n_gpu_layers, // TODO: Support this
       boolean flash_attn,
-      String cache_type_k,
-      String cache_type_v,
+      int cache_type_k,
+      int cache_type_v,
       boolean use_mlock,
       boolean use_mmap,
       boolean vocab_only,
package/android/src/main/jni.cpp
@@ -236,8 +236,8 @@ Java_com_rnllama_LlamaContext_initContext(
     jint n_threads,
     jint n_gpu_layers, // TODO: Support this
     jboolean flash_attn,
-    jstring cache_type_k,
-    jstring cache_type_v,
+    jint cache_type_k,
+    jint cache_type_v,
     jboolean use_mlock,
     jboolean use_mmap,
     jboolean vocab_only,
@@ -259,7 +259,7 @@ Java_com_rnllama_LlamaContext_initContext(
 
     const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
     defaultParams.model = model_path_chars;
-
+
     defaultParams.n_ctx = n_ctx;
     defaultParams.n_batch = n_batch;
 
@@ -281,13 +281,13 @@ Java_com_rnllama_LlamaContext_initContext(
     int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
     defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
 
-    defaultParams.n_gpu_layers = n_gpu_layers;
+    // defaultParams.n_gpu_layers = n_gpu_layers;
     defaultParams.flash_attn = flash_attn;
 
-    const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
-    const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
-    defaultParams.cache_type_k = cache_type_k_chars;
-    defaultParams.cache_type_v = cache_type_v_chars;
+    // const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
+    // const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
+    defaultParams.cache_type_k = (lm_ggml_type) cache_type_k;
+    defaultParams.cache_type_v = (lm_ggml_type) cache_type_v;
 
     defaultParams.use_mlock = use_mlock;
     defaultParams.use_mmap = use_mmap;
@@ -331,8 +331,8 @@ Java_com_rnllama_LlamaContext_initContext(
 
     env->ReleaseStringUTFChars(model_path_str, model_path_chars);
     env->ReleaseStringUTFChars(lora_str, lora_chars);
-    env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
-    env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
+    // env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
+    // env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
 
     LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
     if (is_model_loaded) {
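
Note on the cache-type hunks above: the KV cache types now cross the JNI boundary as plain integers that are cast directly to lm_ggml_type, instead of strings such as "f16" (the new Java-side default of 1 matches ggml's F16 type code). As a rough sketch only, and not code from this package, a caller that still holds the old string names could translate them to the numeric codes before invoking initContext; the helper name and fallback behaviour below are assumptions for illustration.

    // Hypothetical helper (not part of cui-llama.rn): map legacy cache-type
    // strings to the lm_ggml_type integer codes the native layer now expects.
    // Values follow ggml's type numbering: F32 = 0, F16 = 1, Q8_0 = 8.
    #include <cstdint>
    #include <string>

    static int32_t cache_type_from_string(const std::string & name) {
        if (name == "f32")  return 0;  // LM_GGML_TYPE_F32
        if (name == "f16")  return 1;  // LM_GGML_TYPE_F16, the new default
        if (name == "q8_0") return 8;  // LM_GGML_TYPE_Q8_0
        return 1;                      // unknown name: fall back to F16
    }
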
@@ -558,7 +558,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
     //llama_reset_timings(llama->ctx);
 
     llama->params.prompt = env->GetStringUTFChars(prompt, nullptr);
-    llama->params.sparams.seed = (seed == -1) ? time(NULL) : seed;
+    llama->params.sampling.seed = (seed == -1) ? time(NULL) : seed;
 
     int max_threads = std::thread::hardware_concurrency();
     // Use 2 threads by default on 4-core devices, 4 threads on more cores
@@ -566,9 +566,9 @@ Java_com_rnllama_LlamaContext_doCompletion(
     llama->params.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
 
     llama->params.n_predict = n_predict;
-    llama->params.sparams.ignore_eos = ignore_eos;
+    llama->params.sampling.ignore_eos = ignore_eos;
 
-    auto & sparams = llama->params.sparams;
+    auto & sparams = llama->params.sampling;
     sparams.temp = temperature;
     sparams.penalty_last_n = penalty_last_n;
     sparams.penalty_repeat = penalty_repeat;
@@ -577,7 +577,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
     sparams.mirostat = mirostat;
     sparams.mirostat_tau = mirostat_tau;
     sparams.mirostat_eta = mirostat_eta;
-    sparams.penalize_nl = penalize_nl;
+    // sparams.penalize_nl = penalize_nl;
     sparams.top_k = top_k;
     sparams.top_p = top_p;
     sparams.min_p = min_p;
@@ -693,7 +693,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
       auto tokenResult = createWriteableMap(env);
       putString(env, tokenResult, "token", to_send.c_str());
 
-      if (llama->params.sparams.n_probs > 0) {
+      if (llama->params.sampling.n_probs > 0) {
         const std::vector<llama_token> to_send_toks = common_tokenize(llama->ctx, to_send, false);
         size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
         size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
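
The doCompletion hunks above track the upstream llama.cpp rename of the sampling-parameter member: fields previously reached through params.sparams are now under params.sampling, and penalize_nl is no longer forwarded (its assignment is simply commented out). A minimal sketch of the new access pattern, assuming the common_params struct from the bundled common.h; only the fields visible in this diff are guaranteed to exist.

    // Minimal sketch, assuming common.h from this package; fields beyond
    // those shown in the diff (temp, top_k, top_p, ...) are not guaranteed.
    #include "common.h"

    void apply_sampling(common_params & params, float temperature, int top_k, float top_p) {
        auto & sparams = params.sampling;  // was params.sparams before 1.3.4
        sparams.temp  = temperature;
        sparams.top_k = top_k;
        sparams.top_p = top_p;
        // sparams.penalize_nl is no longer part of the sampling params,
        // so the option is dropped rather than applied.
    }
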