cui-llama.rn 1.4.1 → 1.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +4 -23
- package/android/build.gradle +12 -3
- package/android/src/main/CMakeLists.txt +13 -7
- package/android/src/main/java/com/rnllama/LlamaContext.java +27 -20
- package/android/src/main/java/com/rnllama/RNLlama.java +5 -1
- package/android/src/main/jni.cpp +8 -5
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/cpp/README.md +1 -1
- package/cpp/common.cpp +0 -212
- package/cpp/common.h +3 -0
- package/cpp/rn-llama.cpp +822 -0
- package/cpp/rn-llama.h +123 -0
- package/ios/CMakeLists.txt +99 -0
- package/ios/RNLlama.h +5 -1
- package/ios/RNLlama.mm +2 -2
- package/ios/RNLlamaContext.h +8 -1
- package/ios/RNLlamaContext.mm +15 -11
- package/ios/rnllama.xcframework/Info.plist +74 -0
- package/jest/mock.js +3 -2
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/index.js +4 -2
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/index.js +4 -2
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +5 -1
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts.map +1 -1
- package/llama-rn.podspec +8 -2
- package/package.json +5 -2
- package/src/NativeRNLlama.ts +5 -1
- package/src/index.ts +9 -2
package/README.md
CHANGED
@@ -36,6 +36,8 @@ npm install llama.rn
 
 Please re-run `npx pod-install` again.
 
+By default, `llama.rn` will use pre-built `rnllama.xcframework` for iOS. If you want to build from source, please set `RNLLAMA_BUILD_FROM_SOURCE` to `1` in your Podfile.
+
 #### Android
 
 Add proguard rule if it's enabled in project (android/app/proguard-rules.pro):
@@ -45,6 +47,8 @@ Add proguard rule if it's enabled in project (android/app/proguard-rules.pro):
 -keep class com.rnllama.** { *; }
 ```
 
+By default, `llama.rn` will use pre-built libraries for Android. If you want to build from source, please set `rnllamaBuildFromSource` to `true` in `android/gradle.properties`.
+
 ## Obtain the model
 
 You can search HuggingFace for available models (Keyword: [`GGUF`](https://huggingface.co/search/full-text?q=GGUF&type=model)).
@@ -137,29 +141,6 @@ Please visit the [Documentation](docs/API) for more details.
 
 You can also visit the [example](example) to see how to use it.
 
-Run the example:
-
-```bash
-yarn && yarn bootstrap
-
-# iOS
-yarn example ios
-# Use device
-yarn example ios --device "<device name>"
-# With release mode
-yarn example ios --mode Release
-
-# Android
-yarn example android
-# With release mode
-yarn example android --mode release
-```
-
-This example used [react-native-document-picker](https://github.com/rnmods/react-native-document-picker) for select model.
-
-- iOS: You can move the model to iOS Simulator, or iCloud for real device.
-- Android: Selected file will be copied or downloaded to cache directory so it may be slow.
-
 ## Grammar Sampling
 
 GBNF (GGML BNF) is a format for defining [formal grammars](https://en.wikipedia.org/wiki/Formal_grammar) to constrain model outputs in `llama.cpp`. For example, you can use it to force the model to generate valid JSON, or speak only in emojis.
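For context on the grammar feature this README section describes, here is a minimal TypeScript sketch of grammar-constrained completion. It assumes the `initLlama`/`completion` API referenced in the package docs; the exact parameter names (`model`, `grammar`, `n_predict`) are taken from the README and typings rather than verified against this release.

```ts
import { initLlama } from 'llama.rn'

// A tiny GBNF grammar that only allows a yes/no answer.
const grammar = `root ::= "yes" | "no"`

async function ask(modelPath: string, question: string): Promise<string> {
  // Load the model; by default this uses the pre-built native libraries.
  const context = await initLlama({ model: modelPath })
  // Constrain sampling so the model can only emit tokens matching the grammar.
  const { text } = await context.completion({
    prompt: `Q: ${question}\nA: `,
    grammar,
    n_predict: 8,
  })
  return text.trim()
}
```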
package/android/build.gradle
CHANGED
@@ -54,9 +54,18 @@ android {
       }
     }
   }
-
-
-
+  def rnllamaBuildFromSource = project.properties["rnllamaBuildFromSource"]
+  if (rnllamaBuildFromSource == "true") {
+    externalNativeBuild {
+      cmake {
+        path = file('src/main/CMakeLists.txt')
+      }
+    }
+    // Exclude jniLibs
+    sourceSets {
+      main {
+        jniLibs.srcDirs = []
+      }
     }
   }
   buildTypes {
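The new Gradle logic only wires up `externalNativeBuild` (and drops the bundled `jniLibs`) when the consumer opts in, matching the README note above. Opting in is a one-line property in the app's `android/gradle.properties`:

```properties
# Build librnllama from source instead of using the bundled prebuilt .so files
rnllamaBuildFromSource=true
```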
package/android/src/main/CMakeLists.txt
CHANGED
@@ -2,6 +2,12 @@ cmake_minimum_required(VERSION 3.10)
 
 project(llama.rn)
 
+find_program(CCACHE_FOUND ccache)
+if(CCACHE_FOUND)
+    set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache)
+    set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK ccache)
+endif(CCACHE_FOUND)
+
 set(CMAKE_CXX_STANDARD 17)
 set(RNLLAMA_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../cpp)
 
@@ -45,7 +51,7 @@ set(
     ${RNLLAMA_LIB_DIR}/unicode.cpp
     ${RNLLAMA_LIB_DIR}/sgemm.cpp
     ${RNLLAMA_LIB_DIR}/common.cpp
-    ${RNLLAMA_LIB_DIR}/rn-llama.hpp
+    ${RNLLAMA_LIB_DIR}/rn-llama.cpp
     ${CMAKE_SOURCE_DIR}/jni-utils.h
     ${CMAKE_SOURCE_DIR}/jni.cpp
 )
@@ -86,13 +92,13 @@ build_library("rnllama" "")
 
 if (${ANDROID_ABI} STREQUAL "arm64-v8a")
     # ARM64 targets
-
-
-    build_library("rnllama_v8_4_fp16_dotprod_i8mm" "-march=armv8.4-a+fp16+dotprod+i8mm")
-    build_library("rnllama_v8_4_fp16_dotprod" "-march=armv8.4-a+fp16+dotprod")
-    build_library("rnllama_v8_2_fp16_dotprod" "-march=armv8.2-a+fp16+dotprod")
-    build_library("rnllama_v8_2_fp16" "-march=armv8.2-a+fp16")
+    # Removing fp16 for now as it leads to issues with some models like deepseek r1 distills
+    # https://github.com/mybigday/llama.rn/pull/110#issuecomment-2609918310
     build_library("rnllama_v8" "-march=armv8-a")
+    build_library("rnllama_v8_2" "-march=armv8.2-a")
+    build_library("rnllama_v8_2_dotprod" "-march=armv8.2-a+dotprod")
+    build_library("rnllama_v8_2_i8mm" "-march=armv8.2-a+i8mm")
+    build_library("rnllama_v8_2_dotprod_i8mm" "-march=armv8.2-a+dotprod+i8mm")
 
 # https://github.com/ggerganov/llama.cpp/blob/master/docs/android.md#cross-compile-using-android-ndk
 # llama.cpp will deal with the cpu features
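The hunk above calls a `build_library(name flags)` helper that is defined earlier in this CMakeLists.txt and not shown in the diff. As a rough sketch of the pattern only (the body below is a hypothetical illustration, not the file's actual definition), each variant compiles the same sources with a different `-march` feature set:

```cmake
# Hypothetical sketch of the per-variant helper; the real definition lives
# earlier in CMakeLists.txt and is not part of this diff.
function(build_library target_name cpu_flags)
    add_library(${target_name} SHARED ${SOURCE_FILES})
    target_link_libraries(${target_name} PRIVATE log android)
    if (NOT cpu_flags STREQUAL "")
        # e.g. "-march=armv8.2-a+dotprod+i8mm" selects the CPU feature set
        target_compile_options(${target_name} PRIVATE ${cpu_flags})
    endif()
endfunction()
```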
package/android/src/main/java/com/rnllama/LlamaContext.java
CHANGED
@@ -26,6 +26,8 @@ import java.io.FileInputStream;
 public class LlamaContext {
   public static final String NAME = "RNLlamaContext";
 
+  private static String loadedLibrary = "";
+
   private int id;
   private ReactApplicationContext reactContext;
   private long context;
@@ -160,6 +162,10 @@ public class LlamaContext {
     return modelDetails;
   }
 
+  public String getLoadedLibrary() {
+    return loadedLibrary;
+  }
+
   public String getFormattedChat(ReadableArray messages, String chatTemplate) {
     ReadableMap[] msgs = new ReadableMap[messages.size()];
     for (int i = 0; i < messages.size(); i++) {
@@ -401,36 +407,37 @@ public class LlamaContext {
 
     // TODO: Add runtime check for cpu features
     if (LlamaContext.isArm64V8a()) {
-      if (
-        Log.d(NAME, "Loading
-        System.loadLibrary("
-
-
-
-
-
-
-
-
-
-      } else if (
-        Log.d(NAME, "Loading
-        System.loadLibrary("
-
-        Log.d(NAME, "Loading librnllama_v8_2_fp16.so");
-        System.loadLibrary("rnllama_v8_2_fp16");
+      if (hasDotProd && hasI8mm) {
+        Log.d(NAME, "Loading librnllama_v8_2_dotprod_i8mm.so");
+        System.loadLibrary("rnllama_v8_2_dotprod_i8mm");
+        loadedLibrary = "rnllama_v8_2_dotprod_i8mm";
+      } else if (hasDotProd) {
+        Log.d(NAME, "Loading librnllama_v8_2_dotprod.so");
+        System.loadLibrary("rnllama_v8_2_dotprod");
+        loadedLibrary = "rnllama_v8_2_dotprod";
+      } else if (hasI8mm) {
+        Log.d(NAME, "Loading librnllama_v8_2_i8mm.so");
+        System.loadLibrary("rnllama_v8_2_i8mm");
+        loadedLibrary = "rnllama_v8_2_i8mm";
+      } else if (hasFp16) {
+        Log.d(NAME, "Loading librnllama_v8_2.so");
+        System.loadLibrary("rnllama_v8_2");
+        loadedLibrary = "rnllama_v8_2";
       } else {
-        Log.d(NAME, "Loading librnllama_v8.so");
+        Log.d(NAME, "Loading default librnllama_v8.so");
         System.loadLibrary("rnllama_v8");
+        loadedLibrary = "rnllama_v8";
       }
       // Log.d(NAME, "Loading librnllama_v8_7.so with runtime feature detection");
       // System.loadLibrary("rnllama_v8_7");
     } else if (LlamaContext.isX86_64()) {
       Log.d(NAME, "Loading librnllama_x86_64.so");
       System.loadLibrary("rnllama_x86_64");
+      loadedLibrary = "rnllama_x86_64";
     } else {
       Log.d(NAME, "Loading default librnllama.so");
       System.loadLibrary("rnllama");
+      loadedLibrary = "rnllama";
     }
   }
 
@@ -465,7 +472,7 @@ public class LlamaContext {
   public void emitModelProgressUpdate(int progress) {
     WritableMap event = Arguments.createMap();
     event.putInt("progress", progress);
-    eventEmitter.emit("@
+    eventEmitter.emit("@RNLlama_onInitContextProgress", event);
   }
 
   protected static native WritableMap modelInfo(
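The new branches rely on `hasDotProd`, `hasI8mm`, and `hasFp16` flags whose definitions fall outside this hunk. A plausible sketch of how such flags can be derived on Android (illustrative only; the package's actual detection code is not shown in this diff) is to scan `/proc/cpuinfo` for the corresponding ARM64 feature strings:

```java
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;

final class CpuFeatures {
  // Reads /proc/cpuinfo and checks the "Features" line for a flag such as
  // "asimddp" (dot product), "i8mm", or "fphp" (half-precision arithmetic).
  static boolean hasFeature(String flag) {
    try (BufferedReader reader = new BufferedReader(new FileReader("/proc/cpuinfo"))) {
      String line;
      while ((line = reader.readLine()) != null) {
        if (line.startsWith("Features") && line.contains(flag)) {
          return true;
        }
      }
    } catch (IOException e) {
      // Fall through: assume the feature is absent if cpuinfo is unreadable.
    }
    return false;
  }
}
```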
package/android/src/main/java/com/rnllama/RNLlama.java
CHANGED
@@ -35,7 +35,7 @@ public class RNLlama implements LifecycleEventListener {
 
   private HashMap<Integer, LlamaContext> contexts = new HashMap<>();
 
-  private int llamaContextLimit = 1;
+  private int llamaContextLimit = -1;
 
   public void setContextLimit(double limit, Promise promise) {
     llamaContextLimit = (int) limit;
@@ -83,6 +83,9 @@ public class RNLlama implements LifecycleEventListener {
       if (context != null) {
         throw new Exception("Context already exists");
       }
+      if (llamaContextLimit > -1 && contexts.size() >= llamaContextLimit) {
+        throw new Exception("Context limit reached");
+      }
       LlamaContext llamaContext = new LlamaContext(contextId, reactContext, params);
       if (llamaContext.getContext() == 0) {
         throw new Exception("Failed to initialize context");
@@ -92,6 +95,7 @@ public class RNLlama implements LifecycleEventListener {
       result.putBoolean("gpu", false);
       result.putString("reasonNoGPU", "Currently not supported");
       result.putMap("model", llamaContext.getModelDetails());
+      result.putString("androidLib", llamaContext.getLoadedLibrary());
       return result;
     } catch (Exception e) {
       exception = e;
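With the default limit now `-1` (unlimited) and an explicit guard in context creation, applications opt into a cap via `setContextLimit`. A minimal sketch, assuming the method is re-exported from the JS entry point under the same name:

```ts
import { setContextLimit } from 'llama.rn'

// Allow at most two concurrently initialized contexts; a third
// init call will then reject with "Context limit reached".
await setContextLimit(2)
```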
package/android/src/main/jni.cpp
CHANGED
@@ -11,9 +11,8 @@
 #include <unordered_map>
 #include "llama.h"
 #include "llama-impl.h"
-#include "
-#include "
-#include "rn-llama.hpp"
+#include "ggml.h"
+#include "rn-llama.h"
 #include "jni-utils.h"
 
 #define UNUSED(x) (void)(x)
@@ -421,6 +420,7 @@ Java_com_rnllama_LlamaContext_loadModelDetails(
     llama_model_desc(llama->model, desc, sizeof(desc));
     putString(env, result, "desc", desc);
     putDouble(env, result, "size", llama_model_size(llama->model));
+    putDouble(env, result, "nEmbd", llama_model_n_embd(llama->model));
     putDouble(env, result, "nParams", llama_model_n_params(llama->model));
     putBoolean(env, result, "isChatTemplateSupported", llama->validateModelChatTemplate());
     putMap(env, result, "metadata", meta);
@@ -621,9 +621,12 @@ Java_com_rnllama_LlamaContext_doCompletion(
     sparams.dry_allowed_length = dry_allowed_length;
     sparams.dry_penalty_last_n = dry_penalty_last_n;
 
+    const llama_model * model = llama_get_model(llama->ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
     sparams.logit_bias.clear();
     if (ignore_eos) {
-        sparams.logit_bias[llama_vocab_eos(
+        sparams.logit_bias[llama_vocab_eos(vocab)].bias = -INFINITY;
     }
 
     // dry break seq
@@ -642,7 +645,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
     sparams.dry_sequence_breakers = dry_sequence_breakers_vector;
 
     // logit bias
-    const int n_vocab = llama_vocab_n_tokens(
+    const int n_vocab = llama_vocab_n_tokens(vocab);
     jsize logit_bias_len = env->GetArrayLength(logit_bias);
 
     for (jsize i = 0; i < logit_bias_len; i++) {
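The jni.cpp changes track a llama.cpp API migration: vocabulary queries now go through a `llama_vocab` handle obtained from the model, rather than taking the model or context directly. A condensed sketch of the pattern, using llama.cpp functions as of this vintage (exact availability depends on the llama.cpp revision synced into `package/cpp`):

```cpp
#include "llama.h"

// Given a llama_context, resolve the vocab handle once and reuse it
// for token queries such as EOS lookup and vocab size.
static void query_vocab(llama_context * ctx) {
    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);

    const llama_token eos     = llama_vocab_eos(vocab);      // end-of-sequence token id
    const int32_t     n_vocab = llama_vocab_n_tokens(vocab); // total token count
    (void) eos; (void) n_vocab;
}
```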
package/android/src/main/jniLibs/** (prebuilt libraries)
CHANGED
Binary files changed: the prebuilt arm64-v8a and x86_64 `librnllama*.so` libraries listed in the file summary above were rebuilt.
package/cpp/README.md
CHANGED
@@ -1,4 +1,4 @@
 # Note
 
-- Only `rn-llama.hpp` is the specific file for this folder, others are sync from [llama.cpp](https://github.com/ggerganov/llama.cpp).
+- Only `rn-llama.h` and `rn-llama.cpp` are the specific files for this folder, others are sync from [llama.cpp](https://github.com/ggerganov/llama.cpp).
 - We can update the native source by using the [bootstrap](../scripts/bootstrap.sh) script.
package/cpp/common.cpp
CHANGED
@@ -1153,218 +1153,6 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma
     return false;
 }
 
-static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
-    // Initialize libcurl
-    curl_ptr       curl(curl_easy_init(), &curl_easy_cleanup);
-    curl_slist_ptr http_headers;
-    if (!curl) {
-        LOG_ERR("%s: error initializing libcurl\n", __func__);
-        return false;
-    }
-
-    bool force_download = false;
-
-    // Set the URL, allow to follow http redirection
-    curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
-    curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
-
-    // Check if hf-token or bearer-token was specified
-    if (!hf_token.empty()) {
-        std::string auth_header = "Authorization: Bearer " + hf_token;
-        http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
-        curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
-    }
-
-#if defined(_WIN32)
-    // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
-    // operating system. Currently implemented under MS-Windows.
-    curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
-#endif
-
-    // Check if the file already exists locally
-    auto file_exists = std::filesystem::exists(path);
-
-    // If the file exists, check its JSON metadata companion file.
-    std::string metadata_path = path + ".json";
-    nlohmann::json metadata;
-    std::string etag;
-    std::string last_modified;
-
-    if (file_exists) {
-        // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
-        std::ifstream metadata_in(metadata_path);
-        if (metadata_in.good()) {
-            try {
-                metadata_in >> metadata;
-                LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
-                if (metadata.contains("url") && metadata.at("url").is_string()) {
-                    auto previous_url = metadata.at("url").get<std::string>();
-                    if (previous_url != url) {
-                        LOG_ERR("%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str());
-                        return false;
-                    }
-                }
-                if (metadata.contains("etag") && metadata.at("etag").is_string()) {
-                    etag = metadata.at("etag");
-                }
-                if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
-                    last_modified = metadata.at("lastModified");
-                }
-            } catch (const nlohmann::json::exception & e) {
-                LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
-                return false;
-            }
-        }
-    } else {
-        LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
-    }
-
-    // Send a HEAD request to retrieve the etag and last-modified headers
-    struct common_load_model_from_url_headers {
-        std::string etag;
-        std::string last_modified;
-    };
-
-    common_load_model_from_url_headers headers;
-
-    {
-        typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
-        auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
-            common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
-
-            static std::regex header_regex("([^:]+): (.*)\r\n");
-            static std::regex etag_regex("ETag", std::regex_constants::icase);
-            static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
-
-            std::string header(buffer, n_items);
-            std::smatch match;
-            if (std::regex_match(header, match, header_regex)) {
-                const std::string & key = match[1];
-                const std::string & value = match[2];
-                if (std::regex_match(key, match, etag_regex)) {
-                    headers->etag = value;
-                } else if (std::regex_match(key, match, last_modified_regex)) {
-                    headers->last_modified = value;
-                }
-            }
-            return n_items;
-        };
-
-        curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
-        curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
-        curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
-        curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
-
-        bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
-        if (!was_perform_successful) {
-            return false;
-        }
-
-        long http_code = 0;
-        curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
-        if (http_code != 200) {
-            // HEAD not supported, we don't know if the file has changed
-            // force trigger downloading
-            force_download = true;
-            LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
-        }
-    }
-
-    bool should_download = !file_exists || force_download;
-    if (!should_download) {
-        if (!etag.empty() && etag != headers.etag) {
-            LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
-            should_download = true;
-        } else if (!last_modified.empty() && last_modified != headers.last_modified) {
-            LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
-            should_download = true;
-        }
-    }
-    if (should_download) {
-        std::string path_temporary = path + ".downloadInProgress";
-        if (file_exists) {
-            LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
-            if (remove(path.c_str()) != 0) {
-                LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
-                return false;
-            }
-        }
-
-        // Set the output file
-
-        struct FILE_deleter {
-            void operator()(FILE * f) const {
-                fclose(f);
-            }
-        };
-
-        std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
-        if (!outfile) {
-            LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
-            return false;
-        }
-
-        typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
-        auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
-            return fwrite(data, size, nmemb, (FILE *)fd);
-        };
-        curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
-        curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
-        curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
-
-        // display download progress
-        curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
-
-        // helper function to hide password in URL
-        auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
-            std::size_t protocol_pos = url.find("://");
-            if (protocol_pos == std::string::npos) {
-                return url; // Malformed URL
-            }
-
-            std::size_t at_pos = url.find('@', protocol_pos + 3);
-            if (at_pos == std::string::npos) {
-                return url; // No password in URL
-            }
-
-            return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
-        };
-
-        // start the download
-        LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
-            llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
-        bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
-        if (!was_perform_successful) {
-            return false;
-        }
-
-        long http_code = 0;
-        curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
-        if (http_code < 200 || http_code >= 400) {
-            LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
-            return false;
-        }
-
-        // Causes file to be closed explicitly here before we rename it.
-        outfile.reset();
-
-        // Write the updated JSON metadata file.
-        metadata.update({
-            {"url", url},
-            {"etag", headers.etag},
-            {"lastModified", headers.last_modified}
-        });
-        std::ofstream(metadata_path) << metadata.dump(4);
-        LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
-
-        if (rename(path_temporary.c_str(), path.c_str()) != 0) {
-            LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
-            return false;
-        }
-    }
-
-    return true;
-}
 
 struct llama_model * common_load_model_from_url(
     const std::string & model_url,
package/cpp/common.h
CHANGED
@@ -534,6 +534,9 @@ struct llama_model * common_load_model_from_hf(
     const std::string & local_path,
     const std::string & hf_token,
     const struct llama_model_params & params);
+std::pair<std::string, std::string> common_get_hf_file(
+    const std::string & hf_repo_with_tag,
+    const std::string & hf_token);
 
 std::pair<std::string, std::string> common_get_hf_file(
     const std::string & hf_repo_with_tag,