cui-llama.rn 1.3.3 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. package/android/src/main/CMakeLists.txt +5 -7
  2. package/android/src/main/java/com/rnllama/LlamaContext.java +4 -4
  3. package/android/src/main/jni.cpp +9 -9
  4. package/cpp/common.cpp +28 -44
  5. package/cpp/common.h +35 -14
  6. package/cpp/ggml-alloc.c +0 -1
  7. package/cpp/ggml-backend-impl.h +38 -20
  8. package/cpp/ggml-backend-reg.cpp +246 -92
  9. package/cpp/ggml-backend.h +1 -0
  10. package/cpp/ggml-common.h +42 -48
  11. package/cpp/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +642 -223
  12. package/cpp/ggml-cpu-aarch64.h +2 -26
  13. package/cpp/ggml-cpu-traits.cpp +36 -0
  14. package/cpp/ggml-cpu-traits.h +38 -0
  15. package/cpp/ggml-cpu.c +14122 -13971
  16. package/cpp/ggml-cpu.cpp +627 -715
  17. package/cpp/ggml-cpu.h +0 -17
  18. package/cpp/ggml-impl.h +22 -6
  19. package/cpp/ggml-metal.m +482 -24
  20. package/cpp/ggml-quants.c +0 -9
  21. package/cpp/ggml-threading.h +4 -2
  22. package/cpp/ggml.c +284 -178
  23. package/cpp/ggml.h +73 -25
  24. package/cpp/llama-grammar.cpp +15 -15
  25. package/cpp/llama-grammar.h +2 -5
  26. package/cpp/llama-sampling.cpp +35 -90
  27. package/cpp/llama-vocab.cpp +7 -2
  28. package/cpp/llama-vocab.h +1 -1
  29. package/cpp/llama.cpp +1782 -586
  30. package/cpp/llama.h +20 -19
  31. package/cpp/sampling.cpp +11 -16
  32. package/cpp/sgemm.cpp +265 -258
  33. package/cpp/sgemm.h +2 -2
  34. package/cpp/speculative.cpp +4 -0
  35. package/cpp/unicode.cpp +51 -51
  36. package/cpp/unicode.h +9 -10
  37. package/lib/commonjs/index.js +38 -1
  38. package/lib/commonjs/index.js.map +1 -1
  39. package/lib/module/index.js +36 -0
  40. package/lib/module/index.js.map +1 -1
  41. package/lib/typescript/NativeRNLlama.d.ts +2 -3
  42. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  43. package/lib/typescript/index.d.ts +36 -2
  44. package/lib/typescript/index.d.ts.map +1 -1
  45. package/package.json +1 -1
  46. package/src/NativeRNLlama.ts +3 -3
  47. package/src/index.ts +46 -2
  48. package/cpp/amx/amx.cpp +0 -196
  49. package/cpp/amx/amx.h +0 -20
  50. package/cpp/amx/common.h +0 -101
  51. package/cpp/amx/mmq.cpp +0 -2524
  52. package/cpp/amx/mmq.h +0 -16
  53. package/cpp/ggml-aarch64.c +0 -129
  54. package/cpp/ggml-aarch64.h +0 -19
package/android/src/main/CMakeLists.txt CHANGED
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.10)
 
  project(llama.rn)
 
- set(CMAKE_CXX_STANDARD 11)
+ set(CMAKE_CXX_STANDARD 17)
 
  set(RNLLAMA_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../cpp)
 
  include_directories(${RNLLAMA_LIB_DIR})
@@ -14,10 +14,9 @@ set(
      ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
      ${RNLLAMA_LIB_DIR}/log.cpp
 
-     ${RNLLAMA_LIB_DIR}/amx/amx.cpp
-     ${RNLLAMA_LIB_DIR}/amx/mmq.cpp
+     #${RNLLAMA_LIB_DIR}/amx/amx.cpp
+     #${RNLLAMA_LIB_DIR}/amx/mmq.cpp
 
-     ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
      ${RNLLAMA_LIB_DIR}/llama-grammar.cpp
      ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
      ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
@@ -25,14 +24,14 @@ set(
      ${RNLLAMA_LIB_DIR}/json.hpp
      ${RNLLAMA_LIB_DIR}/json-schema-to-grammar.cpp
 
-     ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
      ${RNLLAMA_LIB_DIR}/ggml-alloc.c
      ${RNLLAMA_LIB_DIR}/ggml-backend.cpp
      ${RNLLAMA_LIB_DIR}/ggml-backend-reg.cpp
      ${RNLLAMA_LIB_DIR}/ggml.c
      ${RNLLAMA_LIB_DIR}/ggml-cpu.c
      ${RNLLAMA_LIB_DIR}/ggml-cpu.cpp
-     ${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.c
+     ${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.cpp
+     ${RNLLAMA_LIB_DIR}/ggml-cpu-traits.cpp
      ${RNLLAMA_LIB_DIR}/ggml-cpu-quants.c
      ${RNLLAMA_LIB_DIR}/ggml-threading.cpp
      ${RNLLAMA_LIB_DIR}/ggml-quants.c
@@ -42,7 +41,6 @@ set(
      ${RNLLAMA_LIB_DIR}/unicode.cpp
      ${RNLLAMA_LIB_DIR}/llama.cpp
      ${RNLLAMA_LIB_DIR}/sgemm.cpp
-     ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
      ${RNLLAMA_LIB_DIR}/rn-llama.hpp
      ${CMAKE_SOURCE_DIR}/jni.cpp
  )
package/android/src/main/java/com/rnllama/LlamaContext.java CHANGED
@@ -115,9 +115,9 @@ public class LlamaContext {
        // boolean flash_attn,
        params.hasKey("flash_attn") ? params.getBoolean("flash_attn") : false,
        // String cache_type_k,
-       params.hasKey("cache_type_k") ? params.getString("cache_type_k") : "f16",
+       params.hasKey("cache_type_k") ? params.getInt("cache_type_k") : 1,
        // String cache_type_v,
-       params.hasKey("cache_type_v") ? params.getString("cache_type_v") : "f16",
+       params.hasKey("cache_type_v") ? params.getInt("cache_type_v") : 1,
        // boolean use_mlock,
        params.hasKey("use_mlock") ? params.getBoolean("use_mlock") : true,
        // boolean use_mmap,
@@ -463,8 +463,8 @@ public class LlamaContext {
        int n_threads,
        int n_gpu_layers, // TODO: Support this
        boolean flash_attn,
-       String cache_type_k,
-       String cache_type_v,
+       int cache_type_k,
+       int cache_type_v,
        boolean use_mlock,
        boolean use_mmap,
        boolean vocab_only,
package/android/src/main/jni.cpp CHANGED
@@ -236,8 +236,8 @@ Java_com_rnllama_LlamaContext_initContext(
      jint n_threads,
      jint n_gpu_layers, // TODO: Support this
      jboolean flash_attn,
-     jstring cache_type_k,
-     jstring cache_type_v,
+     jint cache_type_k,
+     jint cache_type_v,
      jboolean use_mlock,
      jboolean use_mmap,
      jboolean vocab_only,
@@ -284,10 +284,10 @@ Java_com_rnllama_LlamaContext_initContext(
      // defaultParams.n_gpu_layers = n_gpu_layers;
      defaultParams.flash_attn = flash_attn;
 
-     const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
-     const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
-     defaultParams.cache_type_k = cache_type_k_chars;
-     defaultParams.cache_type_v = cache_type_v_chars;
+     // const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
+     // const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
+     defaultParams.cache_type_k = (lm_ggml_type) cache_type_k;
+     defaultParams.cache_type_v = (lm_ggml_type) cache_type_v;
 
      defaultParams.use_mlock = use_mlock;
      defaultParams.use_mmap = use_mmap;
@@ -331,8 +331,8 @@ Java_com_rnllama_LlamaContext_initContext(
 
      env->ReleaseStringUTFChars(model_path_str, model_path_chars);
      env->ReleaseStringUTFChars(lora_str, lora_chars);
-     env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
-     env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
+     // env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
+     // env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
 
      LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
      if (is_model_loaded) {
@@ -577,7 +577,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
      sparams.mirostat = mirostat;
      sparams.mirostat_tau = mirostat_tau;
      sparams.mirostat_eta = mirostat_eta;
-     sparams.penalize_nl = penalize_nl;
+     // sparams.penalize_nl = penalize_nl;
      sparams.top_k = top_k;
      sparams.top_p = top_p;
      sparams.min_p = min_p;
package/cpp/common.cpp CHANGED
@@ -946,6 +946,25 @@ struct common_init_result common_init_from_params(common_params & params) {
          params.sampling.ignore_eos = false;
      }
 
+     if (params.sampling.ignore_eos) {
+         for (llama_token i = 0; i < llama_n_vocab(model); i++) {
+             if (llama_token_is_eog(model, i)) {
+                 LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
+                 params.sampling.logit_bias.push_back({i, -INFINITY});
+             }
+         }
+     }
+
+     if (params.sampling.penalty_last_n == -1) {
+         LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+         params.sampling.penalty_last_n = llama_n_ctx(lctx);
+     }
+
+     if (params.sampling.dry_penalty_last_n == -1) {
+         LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+         params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
+     }
+
      if (params.warmup) {
          LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
 
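The block added above treats a penalty window of -1 as "use the whole context window". A minimal caller-side sketch of that behaviour, assuming the usual common_params::n_ctx field (the context size shown is just an example):

```cpp
// Hedged sketch: after common_init_from_params(), a window of -1 has been
// expanded to the model context size.
common_params params;
params.n_ctx = 4096;                      // example context size
params.sampling.penalty_last_n     = -1;  // -> 4096 after init
params.sampling.dry_penalty_last_n = -1;  // -> 4096 after init (DRY window)
common_init_result llama_init = common_init_from_params(params);
```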
@@ -1025,38 +1044,6 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
      return mparams;
  }
 
- static lm_ggml_type kv_cache_type_from_str(const std::string & s) {
-     if (s == "f32") {
-         return LM_GGML_TYPE_F32;
-     }
-     if (s == "f16") {
-         return LM_GGML_TYPE_F16;
-     }
-     if (s == "bf16") {
-         return LM_GGML_TYPE_BF16;
-     }
-     if (s == "q8_0") {
-         return LM_GGML_TYPE_Q8_0;
-     }
-     if (s == "q4_0") {
-         return LM_GGML_TYPE_Q4_0;
-     }
-     if (s == "q4_1") {
-         return LM_GGML_TYPE_Q4_1;
-     }
-     if (s == "iq4_nl") {
-         return LM_GGML_TYPE_IQ4_NL;
-     }
-     if (s == "q5_0") {
-         return LM_GGML_TYPE_Q5_0;
-     }
-     if (s == "q5_1") {
-         return LM_GGML_TYPE_Q5_1;
-     }
-
-     throw std::runtime_error("Unsupported cache type: " + s);
- }
-
  struct llama_context_params common_context_params_to_llama(const common_params & params) {
      auto cparams = llama_context_default_params();
 
@@ -1091,8 +1078,8 @@ struct llama_context_params common_context_params_to_llama(const common_params &
          cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
      }
 
-     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
-     cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
+     cparams.type_k = params.cache_type_k;
+     cparams.type_v = params.cache_type_v;
 
      return cparams;
  }
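With kv_cache_type_from_str() removed and common_params now carrying lm_ggml_type values directly (see the common.h hunks below), callers that still accept string settings have to map them on their side. A minimal sketch based on the removed function; the helper name is hypothetical and only a few of the supported types are shown:

```cpp
#include <stdexcept>
#include <string>

// Hypothetical caller-side replacement for the removed kv_cache_type_from_str().
static lm_ggml_type cache_type_from_str(const std::string & s) {
    if (s == "f32")  { return LM_GGML_TYPE_F32;  }
    if (s == "f16")  { return LM_GGML_TYPE_F16;  }
    if (s == "q8_0") { return LM_GGML_TYPE_Q8_0; }
    if (s == "q4_0") { return LM_GGML_TYPE_Q4_0; }
    throw std::runtime_error("Unsupported cache type: " + s);
}

// Usage: params.cache_type_k = cache_type_from_str("f16");
//        params.cache_type_v = cache_type_from_str("q8_0");
```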
@@ -1118,13 +1105,7 @@ struct lm_ggml_threadpool_params lm_ggml_threadpool_params_from_cpu_params(const
  #define CURL_MAX_RETRY 3
  #define CURL_RETRY_DELAY_SECONDS 2
 
-
- static bool starts_with(const std::string & str, const std::string & prefix) {
-     // While we wait for C++20's std::string::starts_with...
-     return str.rfind(prefix, 0) == 0;
- }
-
- static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_attempts, int retry_delay_seconds) {
+ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) {
      int remaining_attempts = max_attempts;
 
      while (remaining_attempts > 0) {
@@ -1148,7 +1129,6 @@ static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_
  }
 
  static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
-
      // Initialize libcurl
      std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
      if (!curl) {
@@ -1221,11 +1201,13 @@ static bool common_download_file(const std::string & url, const std::string & pa
          std::string etag;
          std::string last_modified;
      };
+
      common_load_model_from_url_headers headers;
+
      {
          typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
          auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
-             common_load_model_from_url_headers *headers = (common_load_model_from_url_headers *) userdata;
+             common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
 
              static std::regex header_regex("([^:]+): (.*)\r\n");
              static std::regex etag_regex("ETag", std::regex_constants::icase);
@@ -1809,7 +1791,9 @@ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm)
              break;
          case 0: // max absolute
              for (int i = 0; i < n; i++) {
-                 if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
+                 if (sum < std::abs(inp[i])) {
+                     sum = std::abs(inp[i]);
+                 }
              }
              sum /= 32760.0; // make an int16 range
              break;
package/cpp/common.h CHANGED
@@ -37,9 +37,9 @@ using llama_tokens = std::vector<llama_token>;
 
  // build info
  extern int LLAMA_BUILD_NUMBER;
- extern char const * LLAMA_COMMIT;
- extern char const * LLAMA_COMPILER;
- extern char const * LLAMA_BUILD_TARGET;
+ extern const char * LLAMA_COMMIT;
+ extern const char * LLAMA_COMPILER;
+ extern const char * LLAMA_BUILD_TARGET;
 
  struct common_control_vector_load_info;
 
@@ -91,6 +91,7 @@ enum llama_example {
      LLAMA_EXAMPLE_LLAVA,
      LLAMA_EXAMPLE_LOOKUP,
      LLAMA_EXAMPLE_PARALLEL,
+     LLAMA_EXAMPLE_TTS,
 
      LLAMA_EXAMPLE_COUNT,
  };
@@ -106,6 +107,7 @@ enum common_sampler_type {
      COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
      COMMON_SAMPLER_TYPE_XTC         = 8,
      COMMON_SAMPLER_TYPE_INFILL      = 9,
+     COMMON_SAMPLER_TYPE_PENALTIES   = 10,
  };
 
  // dimensionality reduction methods, used by cvector-generator
@@ -141,14 +143,15 @@ struct common_params_sampling {
      int32_t mirostat     = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
      float   mirostat_tau = 5.00f; // target entropy
      float   mirostat_eta = 0.10f; // learning rate
-     bool    penalize_nl  = false; // consider newlines as a repeatable token
      bool    ignore_eos   = false;
      bool    no_perf      = false; // disable performance metrics
+     bool    timing_per_token = false;
 
      std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
 
 
      std::vector<enum common_sampler_type> samplers = {
+         COMMON_SAMPLER_TYPE_PENALTIES,
          COMMON_SAMPLER_TYPE_DRY,
          COMMON_SAMPLER_TYPE_TOP_K,
          COMMON_SAMPLER_TYPE_TYPICAL_P,
@@ -168,6 +171,7 @@ struct common_params_sampling {
 
  struct common_params_speculative {
      std::vector<lm_ggml_backend_dev_t> devices; // devices to use for offloading
+
      int32_t n_ctx = 0;  // draft context size
      int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
      int32_t n_min = 5;  // minimum number of draft tokens to use for speculative decoding
@@ -181,6 +185,14 @@ struct common_params_speculative {
      std::string model = ""; // draft model for speculative decoding // NOLINT
  };
 
+ struct common_params_vocoder {
+     std::string hf_repo = ""; // HF repo // NOLINT
+     std::string hf_file = ""; // HF file // NOLINT
+
+     std::string model     = ""; // model path // NOLINT
+     std::string model_url = ""; // model url to download // NOLINT
+ };
+
  struct common_params {
 
      void * progress_callback_user_data = nullptr;
@@ -207,11 +219,13 @@ struct common_params {
      float defrag_thold = 0.1f; // KV cache defragmentation threshold
 
      // offload params
-     std::vector<lm_ggml_backend_dev_t> devices; // devices to use for offloading
-     int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
-     int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
-     float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
-     enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
+     std::vector<lm_ggml_backend_dev_t> devices; // devices to use for offloading
+
+     int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
+     int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
+     float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
+
+     enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
 
      struct cpu_params cpuparams;
      struct cpu_params cpuparams_batch;
@@ -225,11 +239,12 @@ struct common_params {
      enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
      enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
 
-     struct common_params_sampling sampling;
+     struct common_params_sampling    sampling;
      struct common_params_speculative speculative;
+     struct common_params_vocoder     vocoder;
 
      std::string model = ""; // model path // NOLINT
-     std::string model_alias = "unknown"; // model alias // NOLINT
+     std::string model_alias = ""; // model alias // NOLINT
      std::string model_url = ""; // model url to download // NOLINT
      std::string hf_token = ""; // HF token // NOLINT
      std::string hf_repo = ""; // HF repo // NOLINT
@@ -300,8 +315,8 @@ struct common_params {
      bool warmup = true; // warmup run
      bool check_tensors = false; // validate tensor data
 
-     std::string cache_type_k = "f16"; // KV cache data type for the K
-     std::string cache_type_v = "f16"; // KV cache data type for the V
+     lm_ggml_type cache_type_k = LM_GGML_TYPE_F16; // KV cache data type for the K
+     lm_ggml_type cache_type_v = LM_GGML_TYPE_F16; // KV cache data type for the V
 
      // multimodal models (see examples/llava)
      std::string mmproj = ""; // path to multimodal projector // NOLINT
@@ -451,6 +466,11 @@ std::vector<std::string> string_split<std::string>(const std::string & input, ch
      return parts;
  }
 
+ static bool string_starts_with(const std::string & str,
+                                const std::string & prefix) { // While we wait for C++20's std::string::starts_with...
+     return str.rfind(prefix, 0) == 0;
+ }
+
  bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
  void string_process_escapes(std::string & input);
 
@@ -602,7 +622,8 @@ void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_si
  // Embedding utils
  //
 
- void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
+ // TODO: repace embd_norm with an enum
+ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);
 
  float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);
 
package/cpp/ggml-alloc.c CHANGED
@@ -534,7 +534,6 @@ static void lm_ggml_gallocr_allocate_node(lm_ggml_gallocr_t galloc, struct lm_gg
          size_t offset = lm_ggml_dyn_tallocr_alloc(alloc, size, node);
          hn->buffer_id = buffer_id;
          hn->offset = offset;
-         return;
      }
  }
 
@@ -211,27 +211,45 @@ extern "C" {
211
211
  LM_GGML_API void lm_ggml_backend_device_register(lm_ggml_backend_dev_t device);
212
212
 
213
213
  // Add backend dynamic loading support to the backend
214
- typedef lm_ggml_backend_reg_t (*lm_ggml_backend_init_t)(void);
215
214
 
216
- #ifdef LM_GGML_BACKEND_DL
217
- #ifdef __cplusplus
218
- # define LM_GGML_BACKEND_DL_IMPL(reg_fn) \
219
- extern "C" { \
220
- LM_GGML_BACKEND_API lm_ggml_backend_reg_t lm_ggml_backend_init(void); \
221
- } \
222
- lm_ggml_backend_reg_t lm_ggml_backend_init(void) { \
223
- return reg_fn(); \
224
- }
225
- #else
226
- # define LM_GGML_BACKEND_DL_IMPL(reg_fn) \
227
- LM_GGML_BACKEND_API lm_ggml_backend_reg_t lm_ggml_backend_init(void); \
228
- lm_ggml_backend_reg_t lm_ggml_backend_init(void) { \
229
- return reg_fn(); \
230
- }
231
- #endif
232
- #else
233
- # define LM_GGML_BACKEND_DL_IMPL(reg_fn)
234
- #endif
215
+ // Initialize the backend
216
+ typedef lm_ggml_backend_reg_t (*lm_ggml_backend_init_t)(void);
217
+ // Optional: obtain a score for the backend based on the system configuration
218
+ // Higher scores are preferred, 0 means the backend is not supported in the current system
219
+ typedef int (*lm_ggml_backend_score_t)(void);
220
+
221
+ #ifdef LM_GGML_BACKEND_DL
222
+ # ifdef __cplusplus
223
+ # define LM_GGML_BACKEND_DL_IMPL(reg_fn) \
224
+ extern "C" { \
225
+ LM_GGML_BACKEND_API lm_ggml_backend_reg_t lm_ggml_backend_init(void); \
226
+ } \
227
+ lm_ggml_backend_reg_t lm_ggml_backend_init(void) { \
228
+ return reg_fn(); \
229
+ }
230
+ # define LM_GGML_BACKEND_DL_SCORE_IMPL(score_fn) \
231
+ extern "C" { \
232
+ LM_GGML_BACKEND_API int lm_ggml_backend_score(void); \
233
+ } \
234
+ int lm_ggml_backend_score(void) { \
235
+ return score_fn(); \
236
+ }
237
+ # else
238
+ # define LM_GGML_BACKEND_DL_IMPL(reg_fn) \
239
+ LM_GGML_BACKEND_API lm_ggml_backend_reg_t lm_ggml_backend_init(void); \
240
+ lm_ggml_backend_reg_t lm_ggml_backend_init(void) { \
241
+ return reg_fn(); \
242
+ }
243
+ # define LM_GGML_BACKEND_DL_SCORE_IMPL(score_fn) \
244
+ LM_GGML_BACKEND_API int lm_ggml_backend_score(void); \
245
+ int lm_ggml_backend_score(void) { \
246
+ return score_fn(); \
247
+ }
248
+ # endif
249
+ #else
250
+ # define LM_GGML_BACKEND_DL_IMPL(reg_fn)
251
+ # define LM_GGML_BACKEND_DL_SCORE_IMPL(score_fn)
252
+ #endif
235
253
 
236
254
  #ifdef __cplusplus
237
255
  }
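The new LM_GGML_BACKEND_DL_SCORE_IMPL hook sits alongside LM_GGML_BACKEND_DL_IMPL for dynamically loaded backends. A hedged sketch of how a backend might export both entry points; the backend functions below are hypothetical placeholders:

```cpp
// Hypothetical registration function; a real backend returns its lm_ggml_backend_reg_t.
static lm_ggml_backend_reg_t my_backend_reg(void) {
    return nullptr; // placeholder
}

// Optional score: 0 = not supported on this system, higher = preferred.
static int my_backend_score(void) {
    return 10; // e.g. derived from detected CPU features
}

// When LM_GGML_BACKEND_DL is defined, these expand to the exported
// lm_ggml_backend_init() / lm_ggml_backend_score() symbols shown above;
// otherwise they expand to nothing.
LM_GGML_BACKEND_DL_IMPL(my_backend_reg)
LM_GGML_BACKEND_DL_SCORE_IMPL(my_backend_score)
```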