cui-llama.rn 1.4.1 → 1.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. package/README.md +4 -23
  2. package/android/build.gradle +12 -3
  3. package/android/src/main/CMakeLists.txt +13 -7
  4. package/android/src/main/java/com/rnllama/LlamaContext.java +27 -20
  5. package/android/src/main/java/com/rnllama/RNLlama.java +5 -1
  6. package/android/src/main/jni.cpp +8 -5
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  15. package/cpp/README.md +1 -1
  16. package/cpp/common.cpp +0 -212
  17. package/cpp/common.h +3 -0
  18. package/cpp/rn-llama.cpp +822 -0
  19. package/cpp/rn-llama.h +123 -0
  20. package/ios/CMakeLists.txt +99 -0
  21. package/ios/RNLlama.h +5 -1
  22. package/ios/RNLlama.mm +2 -2
  23. package/ios/RNLlamaContext.h +8 -1
  24. package/ios/RNLlamaContext.mm +15 -11
  25. package/ios/rnllama.xcframework/Info.plist +74 -0
  26. package/jest/mock.js +3 -2
  27. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  28. package/lib/commonjs/index.js +4 -2
  29. package/lib/commonjs/index.js.map +1 -1
  30. package/lib/module/NativeRNLlama.js.map +1 -1
  31. package/lib/module/index.js +4 -2
  32. package/lib/module/index.js.map +1 -1
  33. package/lib/typescript/NativeRNLlama.d.ts +5 -1
  34. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  35. package/lib/typescript/index.d.ts.map +1 -1
  36. package/llama-rn.podspec +8 -2
  37. package/package.json +5 -2
  38. package/src/NativeRNLlama.ts +5 -1
  39. package/src/index.ts +9 -2
package/README.md CHANGED
@@ -36,6 +36,8 @@ npm install llama.rn
 
 Please re-run `npx pod-install` again.
 
+By default, `llama.rn` will use pre-built `rnllama.xcframework` for iOS. If you want to build from source, please set `RNLLAMA_BUILD_FROM_SOURCE` to `1` in your Podfile.
+
 #### Android
 
 Add proguard rule if it's enabled in project (android/app/proguard-rules.pro):
@@ -45,6 +47,8 @@ Add proguard rule if it's enabled in project (android/app/proguard-rules.pro):
 -keep class com.rnllama.** { *; }
 ```
 
+By default, `llama.rn` will use pre-built libraries for Android. If you want to build from source, please set `rnllamaBuildFromSource` to `true` in `android/gradle.properties`.
+
 ## Obtain the model
 
 You can search HuggingFace for available models (Keyword: [`GGUF`](https://huggingface.co/search/full-text?q=GGUF&type=model)).
@@ -137,29 +141,6 @@ Please visit the [Documentation](docs/API) for more details.
 
 You can also visit the [example](example) to see how to use it.
 
-Run the example:
-
-```bash
-yarn && yarn bootstrap
-
-# iOS
-yarn example ios
-# Use device
-yarn example ios --device "<device name>"
-# With release mode
-yarn example ios --mode Release
-
-# Android
-yarn example android
-# With release mode
-yarn example android --mode release
-```
-
-This example used [react-native-document-picker](https://github.com/rnmods/react-native-document-picker) for select model.
-
-- iOS: You can move the model to iOS Simulator, or iCloud for real device.
-- Android: Selected file will be copied or downloaded to cache directory so it may be slow.
-
 ## Grammar Sampling
 
 GBNF (GGML BNF) is a format for defining [formal grammars](https://en.wikipedia.org/wiki/Formal_grammar) to constrain model outputs in `llama.cpp`. For example, you can use it to force the model to generate valid JSON, or speak only in emojis.
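The README section above adds two build-from-source switches. For iOS, setting the flag from a consumer Podfile would look like the sketch below; it assumes the podspec (whose change is listed in this diff but not shown in this section) reads `RNLLAMA_BUILD_FROM_SOURCE` from the environment:

```ruby
# ios/Podfile: opt out of the pre-built rnllama.xcframework (assumed ENV-based mechanism)
ENV['RNLLAMA_BUILD_FROM_SOURCE'] = '1'
```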
package/android/build.gradle CHANGED
@@ -54,9 +54,18 @@ android {
       }
     }
   }
-  externalNativeBuild {
-    cmake {
-      path = file('src/main/CMakeLists.txt')
+  def rnllamaBuildFromSource = project.properties["rnllamaBuildFromSource"]
+  if (rnllamaBuildFromSource == "true") {
+    externalNativeBuild {
+      cmake {
+        path = file('src/main/CMakeLists.txt')
+      }
+    }
+    // Exclude jniLibs
+    sourceSets {
+      main {
+        jniLibs.srcDirs = []
+      }
     }
   }
   buildTypes {
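The Gradle change above makes the native CMake build opt-in: the bundled `jniLibs` prebuilts are used unless the property is set. Per the README note in this same release, enabling a source build is a single line:

```properties
# android/gradle.properties
rnllamaBuildFromSource=true
```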
package/android/src/main/CMakeLists.txt CHANGED
@@ -2,6 +2,12 @@ cmake_minimum_required(VERSION 3.10)
 
 project(llama.rn)
 
+find_program(CCACHE_FOUND ccache)
+if(CCACHE_FOUND)
+    set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache)
+    set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK ccache)
+endif(CCACHE_FOUND)
+
 set(CMAKE_CXX_STANDARD 17)
 set(RNLLAMA_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../cpp)
 
@@ -45,7 +51,7 @@ set(
     ${RNLLAMA_LIB_DIR}/unicode.cpp
     ${RNLLAMA_LIB_DIR}/sgemm.cpp
     ${RNLLAMA_LIB_DIR}/common.cpp
-    ${RNLLAMA_LIB_DIR}/rn-llama.hpp
+    ${RNLLAMA_LIB_DIR}/rn-llama.cpp
     ${CMAKE_SOURCE_DIR}/jni-utils.h
     ${CMAKE_SOURCE_DIR}/jni.cpp
 )
@@ -86,13 +92,13 @@ build_library("rnllama" "")
 
 if (${ANDROID_ABI} STREQUAL "arm64-v8a")
     # ARM64 targets
-    build_library("rnllama_v8_4_fp16_dotprod_sve" "-march=armv8.4-a+fp16+dotprod+sve")
-    build_library("rnllama_v8_4_fp16_dotprod_i8mm_sve" "-march=armv8.4-a+fp16+dotprod+i8mm+sve")
-    build_library("rnllama_v8_4_fp16_dotprod_i8mm" "-march=armv8.4-a+fp16+dotprod+i8mm")
-    build_library("rnllama_v8_4_fp16_dotprod" "-march=armv8.4-a+fp16+dotprod")
-    build_library("rnllama_v8_2_fp16_dotprod" "-march=armv8.2-a+fp16+dotprod")
-    build_library("rnllama_v8_2_fp16" "-march=armv8.2-a+fp16")
+    # Removing fp16 for now as it leads to issues with some models like deepseek r1 distills
+    # https://github.com/mybigday/llama.rn/pull/110#issuecomment-2609918310
     build_library("rnllama_v8" "-march=armv8-a")
+    build_library("rnllama_v8_2" "-march=armv8.2-a")
+    build_library("rnllama_v8_2_dotprod" "-march=armv8.2-a+dotprod")
+    build_library("rnllama_v8_2_i8mm" "-march=armv8.2-a+i8mm")
+    build_library("rnllama_v8_2_dotprod_i8mm" "-march=armv8.2-a+dotprod+i8mm")
 
     # https://github.com/ggerganov/llama.cpp/blob/master/docs/android.md#cross-compile-using-android-ndk
     # llama.cpp will deal with the cpu features
package/android/src/main/java/com/rnllama/LlamaContext.java CHANGED
@@ -26,6 +26,8 @@ import java.io.FileInputStream;
 public class LlamaContext {
   public static final String NAME = "RNLlamaContext";
 
+  private static String loadedLibrary = "";
+
   private int id;
   private ReactApplicationContext reactContext;
   private long context;
@@ -160,6 +162,10 @@
     return modelDetails;
   }
 
+  public String getLoadedLibrary() {
+    return loadedLibrary;
+  }
+
   public String getFormattedChat(ReadableArray messages, String chatTemplate) {
     ReadableMap[] msgs = new ReadableMap[messages.size()];
     for (int i = 0; i < messages.size(); i++) {
@@ -401,36 +407,37 @@
 
     // TODO: Add runtime check for cpu features
     if (LlamaContext.isArm64V8a()) {
-      if (isAtLeastArmV84 && hasSve && hasI8mm && hasFp16 && hasDotProd) {
-        Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_i8mm_sve.so");
-        System.loadLibrary("rnllama_v8_4_fp16_dotprod_i8mm_sve");
-      } else if (isAtLeastArmV84 && hasSve && hasFp16 && hasDotProd) {
-        Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_sve.so");
-        System.loadLibrary("rnllama_v8_4_fp16_dotprod_sve");
-      } else if (isAtLeastArmV84 && hasI8mm && hasFp16 && hasDotProd) {
-        Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_i8mm.so");
-        System.loadLibrary("rnllama_v8_4_fp16_dotprod_i8mm");
-      } else if (isAtLeastArmV84 && hasFp16 && hasDotProd) {
-        Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod.so");
-        System.loadLibrary("rnllama_v8_4_fp16_dotprod");
-      } else if (isAtLeastArmV82 && hasFp16 && hasDotProd) {
-        Log.d(NAME, "Loading librnllama_v8_2_fp16_dotprod.so");
-        System.loadLibrary("rnllama_v8_2_fp16_dotprod");
-      } else if (isAtLeastArmV82 && hasFp16) {
-        Log.d(NAME, "Loading librnllama_v8_2_fp16.so");
-        System.loadLibrary("rnllama_v8_2_fp16");
+      if (hasDotProd && hasI8mm) {
+        Log.d(NAME, "Loading librnllama_v8_2_dotprod_i8mm.so");
+        System.loadLibrary("rnllama_v8_2_dotprod_i8mm");
+        loadedLibrary = "rnllama_v8_2_dotprod_i8mm";
+      } else if (hasDotProd) {
+        Log.d(NAME, "Loading librnllama_v8_2_dotprod.so");
+        System.loadLibrary("rnllama_v8_2_dotprod");
+        loadedLibrary = "rnllama_v8_2_dotprod";
+      } else if (hasI8mm) {
+        Log.d(NAME, "Loading librnllama_v8_2_i8mm.so");
+        System.loadLibrary("rnllama_v8_2_i8mm");
+        loadedLibrary = "rnllama_v8_2_i8mm";
+      } else if (hasFp16) {
+        Log.d(NAME, "Loading librnllama_v8_2.so");
+        System.loadLibrary("rnllama_v8_2");
+        loadedLibrary = "rnllama_v8_2";
       } else {
-        Log.d(NAME, "Loading librnllama_v8.so");
+        Log.d(NAME, "Loading default librnllama_v8.so");
         System.loadLibrary("rnllama_v8");
+        loadedLibrary = "rnllama_v8";
       }
       // Log.d(NAME, "Loading librnllama_v8_7.so with runtime feature detection");
       // System.loadLibrary("rnllama_v8_7");
     } else if (LlamaContext.isX86_64()) {
       Log.d(NAME, "Loading librnllama_x86_64.so");
       System.loadLibrary("rnllama_x86_64");
+      loadedLibrary = "rnllama_x86_64";
     } else {
       Log.d(NAME, "Loading default librnllama.so");
       System.loadLibrary("rnllama");
+      loadedLibrary = "rnllama";
     }
   }
 
@@ -465,7 +472,7 @@
   public void emitModelProgressUpdate(int progress) {
     WritableMap event = Arguments.createMap();
     event.putInt("progress", progress);
-    eventEmitter.emit("@RNLlama_onModelProgress", event);
+    eventEmitter.emit("@RNLlama_onInitContextProgress", event);
   }
 
   protected static native WritableMap modelInfo(
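Two behavioral notes on the Java changes above: the `.so` selection ladder now keys only on `dotprod`/`i8mm` (the fp16 variants are gone, matching the CMake comment earlier), and the load-progress event was renamed to `@RNLlama_onInitContextProgress`. A minimal consumer-side sketch; the progress-callback signature of `initLlama` is assumed from the JS wrapper in `src/index.ts`, which this section does not show:

```typescript
import { initLlama } from 'cui-llama.rn'

// Load a model and observe init progress. The optional second argument is
// assumed to be the JS-side consumer of the renamed native event
// '@RNLlama_onInitContextProgress'.
const context = await initLlama(
  { model: '/path/to/model.gguf', n_ctx: 2048 },
  (progress: number) => console.log(`loading: ${progress}%`),
)
```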
package/android/src/main/java/com/rnllama/RNLlama.java CHANGED
@@ -35,7 +35,7 @@ public class RNLlama implements LifecycleEventListener {
 
   private HashMap<Integer, LlamaContext> contexts = new HashMap<>();
 
-  private int llamaContextLimit = 1;
+  private int llamaContextLimit = -1;
 
   public void setContextLimit(double limit, Promise promise) {
     llamaContextLimit = (int) limit;
@@ -83,6 +83,9 @@
         if (context != null) {
           throw new Exception("Context already exists");
         }
+        if (llamaContextLimit > -1 && contexts.size() >= llamaContextLimit) {
+          throw new Exception("Context limit reached");
+        }
         LlamaContext llamaContext = new LlamaContext(contextId, reactContext, params);
         if (llamaContext.getContext() == 0) {
           throw new Exception("Failed to initialize context");
@@ -92,6 +95,7 @@
         result.putBoolean("gpu", false);
         result.putString("reasonNoGPU", "Currently not supported");
         result.putMap("model", llamaContext.getModelDetails());
+        result.putString("androidLib", llamaContext.getLoadedLibrary());
         return result;
       } catch (Exception e) {
         exception = e;
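Context creation now rejects once the configured limit is hit (the new default of `-1` means unlimited until `setContextLimit` is called) and reports which Android library was loaded. A sketch of the JavaScript-side view; the `androidLib` field name comes from the Java code above, while its exposure on the returned context object is an assumption based on the `NativeRNLlama.ts` typing change in this release:

```typescript
import { initLlama, setContextLimit } from 'cui-llama.rn'

// With the default of -1, no limit applies until one is set.
await setContextLimit(2) // a third concurrent context now throws "Context limit reached"

const ctx = await initLlama({ model: '/path/to/model.gguf' })
// Assumed surfaced from the native init result per the typings change:
console.log((ctx as any).androidLib) // e.g. 'rnllama_v8_2_dotprod_i8mm'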
package/android/src/main/jni.cpp CHANGED
@@ -11,9 +11,8 @@
 #include <unordered_map>
 #include "llama.h"
 #include "llama-impl.h"
-#include "llama-context.h"
-#include "gguf.h"
-#include "rn-llama.hpp"
+#include "ggml.h"
+#include "rn-llama.h"
 #include "jni-utils.h"
 
 #define UNUSED(x) (void)(x)
@@ -421,6 +420,7 @@ Java_com_rnllama_LlamaContext_loadModelDetails(
     llama_model_desc(llama->model, desc, sizeof(desc));
     putString(env, result, "desc", desc);
     putDouble(env, result, "size", llama_model_size(llama->model));
+    putDouble(env, result, "nEmbd", llama_model_n_embd(llama->model));
     putDouble(env, result, "nParams", llama_model_n_params(llama->model));
     putBoolean(env, result, "isChatTemplateSupported", llama->validateModelChatTemplate());
     putMap(env, result, "metadata", meta);
@@ -621,9 +621,12 @@ Java_com_rnllama_LlamaContext_doCompletion(
     sparams.dry_allowed_length = dry_allowed_length;
     sparams.dry_penalty_last_n = dry_penalty_last_n;
 
+    const llama_model * model = llama_get_model(llama->ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
     sparams.logit_bias.clear();
     if (ignore_eos) {
-        sparams.logit_bias[llama_vocab_eos(llama_model_get_vocab(llama->model))].bias = -INFINITY;
+        sparams.logit_bias[llama_vocab_eos(vocab)].bias = -INFINITY;
     }
 
     // dry break seq
@@ -642,7 +645,7 @@
     sparams.dry_sequence_breakers = dry_sequence_breakers_vector;
 
     // logit bias
-    const int n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama->model));
+    const int n_vocab = llama_vocab_n_tokens(vocab);
     jsize logit_bias_len = env->GetArrayLength(logit_bias);
 
     for (jsize i = 0; i < logit_bias_len; i++) {
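Besides switching from `rn-llama.hpp` to the split `rn-llama.h`/`rn-llama.cpp`, the JNI changes cache the `llama_vocab` pointer once instead of re-deriving it per use, and `loadModelDetails` now reports the embedding size. A hedged sketch of reading the new field from JavaScript, assuming the model-details map keeps its usual exposure as `model` on the context:

```typescript
import { initLlama } from 'cui-llama.rn'

const ctx = await initLlama({ model: '/path/to/model.gguf' })
// 'nEmbd' is the new entry emitted by loadModelDetails in jni.cpp.
console.log((ctx.model as any)?.nEmbd) // embedding dimension of the loaded model
```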
package/cpp/README.md CHANGED
@@ -1,4 +1,4 @@
 # Note
 
-- Only `rn-llama.hpp` is the specific file for this project, others are sync from [llama.cpp](https://github.com/ggerganov/llama.cpp).
+- Only `rn-llama.h` and `rn-llama.cpp` are the specific files for this folder, others are sync from [llama.cpp](https://github.com/ggerganov/llama.cpp).
 - We can update the native source by using the [bootstrap](../scripts/bootstrap.sh) script.
package/cpp/common.cpp CHANGED
@@ -1153,218 +1153,6 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma
     return false;
 }
 
-static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
-    // Initialize libcurl
-    curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
-    curl_slist_ptr http_headers;
-    if (!curl) {
-        LOG_ERR("%s: error initializing libcurl\n", __func__);
-        return false;
-    }
-
-    bool force_download = false;
-
-    // Set the URL, allow to follow http redirection
-    curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
-    curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
-
-    // Check if hf-token or bearer-token was specified
-    if (!hf_token.empty()) {
-        std::string auth_header = "Authorization: Bearer " + hf_token;
-        http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
-        curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
-    }
-
-#if defined(_WIN32)
-    // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
-    // operating system. Currently implemented under MS-Windows.
-    curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
-#endif
-
-    // Check if the file already exists locally
-    auto file_exists = std::filesystem::exists(path);
-
-    // If the file exists, check its JSON metadata companion file.
-    std::string metadata_path = path + ".json";
-    nlohmann::json metadata;
-    std::string etag;
-    std::string last_modified;
-
-    if (file_exists) {
-        // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
-        std::ifstream metadata_in(metadata_path);
-        if (metadata_in.good()) {
-            try {
-                metadata_in >> metadata;
-                LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
-                if (metadata.contains("url") && metadata.at("url").is_string()) {
-                    auto previous_url = metadata.at("url").get<std::string>();
-                    if (previous_url != url) {
-                        LOG_ERR("%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str());
-                        return false;
-                    }
-                }
-                if (metadata.contains("etag") && metadata.at("etag").is_string()) {
-                    etag = metadata.at("etag");
-                }
-                if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
-                    last_modified = metadata.at("lastModified");
-                }
-            } catch (const nlohmann::json::exception & e) {
-                LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
-                return false;
-            }
-        }
-    } else {
-        LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
-    }
-
-    // Send a HEAD request to retrieve the etag and last-modified headers
-    struct common_load_model_from_url_headers {
-        std::string etag;
-        std::string last_modified;
-    };
-
-    common_load_model_from_url_headers headers;
-
-    {
-        typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
-        auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
-            common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
-
-            static std::regex header_regex("([^:]+): (.*)\r\n");
-            static std::regex etag_regex("ETag", std::regex_constants::icase);
-            static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
-
-            std::string header(buffer, n_items);
-            std::smatch match;
-            if (std::regex_match(header, match, header_regex)) {
-                const std::string & key = match[1];
-                const std::string & value = match[2];
-                if (std::regex_match(key, match, etag_regex)) {
-                    headers->etag = value;
-                } else if (std::regex_match(key, match, last_modified_regex)) {
-                    headers->last_modified = value;
-                }
-            }
-            return n_items;
-        };
-
-        curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
-        curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
-        curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
-        curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
-
-        bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
-        if (!was_perform_successful) {
-            return false;
-        }
-
-        long http_code = 0;
-        curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
-        if (http_code != 200) {
-            // HEAD not supported, we don't know if the file has changed
-            // force trigger downloading
-            force_download = true;
-            LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
-        }
-    }
-
-    bool should_download = !file_exists || force_download;
-    if (!should_download) {
-        if (!etag.empty() && etag != headers.etag) {
-            LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
-            should_download = true;
-        } else if (!last_modified.empty() && last_modified != headers.last_modified) {
-            LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
-            should_download = true;
-        }
-    }
-    if (should_download) {
-        std::string path_temporary = path + ".downloadInProgress";
-        if (file_exists) {
-            LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
-            if (remove(path.c_str()) != 0) {
-                LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
-                return false;
-            }
-        }
-
-        // Set the output file
-
-        struct FILE_deleter {
-            void operator()(FILE * f) const {
-                fclose(f);
-            }
-        };
-
-        std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
-        if (!outfile) {
-            LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
-            return false;
-        }
-
-        typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
-        auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
-            return fwrite(data, size, nmemb, (FILE *)fd);
-        };
-        curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
-        curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
-        curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
-
-        // display download progress
-        curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
-
-        // helper function to hide password in URL
-        auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
-            std::size_t protocol_pos = url.find("://");
-            if (protocol_pos == std::string::npos) {
-                return url; // Malformed URL
-            }
-
-            std::size_t at_pos = url.find('@', protocol_pos + 3);
-            if (at_pos == std::string::npos) {
-                return url; // No password in URL
-            }
-
-            return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
-        };
-
-        // start the download
-        LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
-            llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
-        bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
-        if (!was_perform_successful) {
-            return false;
-        }
-
-        long http_code = 0;
-        curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
-        if (http_code < 200 || http_code >= 400) {
-            LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
-            return false;
-        }
-
-        // Causes file to be closed explicitly here before we rename it.
-        outfile.reset();
-
-        // Write the updated JSON metadata file.
-        metadata.update({
-            {"url", url},
-            {"etag", headers.etag},
-            {"lastModified", headers.last_modified}
-        });
-        std::ofstream(metadata_path) << metadata.dump(4);
-        LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
-
-        if (rename(path_temporary.c_str(), path.c_str()) != 0) {
-            LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
-            return false;
-        }
-    }
-
-    return true;
-}
 
 struct llama_model * common_load_model_from_url(
     const std::string & model_url,
package/cpp/common.h CHANGED
@@ -534,6 +534,9 @@ struct llama_model * common_load_model_from_hf(
     const std::string & local_path,
     const std::string & hf_token,
     const struct llama_model_params & params);
+std::pair<std::string, std::string> common_get_hf_file(
+    const std::string & hf_repo_with_tag,
+    const std::string & hf_token);
 
 std::pair<std::string, std::string> common_get_hf_file(
     const std::string & hf_repo_with_tag,