cui-llama.rn 1.2.6 → 1.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -2
- package/android/src/main/CMakeLists.txt +26 -6
- package/android/src/main/java/com/rnllama/LlamaContext.java +115 -27
- package/android/src/main/java/com/rnllama/RNLlama.java +40 -7
- package/android/src/main/jni.cpp +228 -40
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +9 -4
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +9 -4
- package/cpp/amx/amx.cpp +196 -0
- package/cpp/amx/amx.h +20 -0
- package/cpp/amx/common.h +101 -0
- package/cpp/amx/mmq.cpp +2524 -0
- package/cpp/amx/mmq.h +16 -0
- package/cpp/common.cpp +118 -251
- package/cpp/common.h +53 -30
- package/cpp/ggml-aarch64.c +46 -3395
- package/cpp/ggml-aarch64.h +0 -20
- package/cpp/ggml-alloc.c +6 -8
- package/cpp/ggml-backend-impl.h +33 -11
- package/cpp/ggml-backend-reg.cpp +423 -0
- package/cpp/ggml-backend.cpp +14 -676
- package/cpp/ggml-backend.h +46 -9
- package/cpp/ggml-common.h +6 -0
- package/cpp/ggml-cpu-aarch64.c +3823 -0
- package/cpp/ggml-cpu-aarch64.h +32 -0
- package/cpp/ggml-cpu-impl.h +14 -242
- package/cpp/ggml-cpu-quants.c +10835 -0
- package/cpp/ggml-cpu-quants.h +63 -0
- package/cpp/ggml-cpu.c +13971 -13720
- package/cpp/ggml-cpu.cpp +715 -0
- package/cpp/ggml-cpu.h +65 -63
- package/cpp/ggml-impl.h +285 -25
- package/cpp/ggml-metal.h +8 -8
- package/cpp/ggml-metal.m +1221 -728
- package/cpp/ggml-quants.c +189 -10681
- package/cpp/ggml-quants.h +78 -125
- package/cpp/ggml-threading.cpp +12 -0
- package/cpp/ggml-threading.h +12 -0
- package/cpp/ggml.c +688 -1460
- package/cpp/ggml.h +58 -244
- package/cpp/json-schema-to-grammar.cpp +1045 -1045
- package/cpp/json.hpp +24766 -24766
- package/cpp/llama-sampling.cpp +5 -2
- package/cpp/llama.cpp +409 -123
- package/cpp/llama.h +8 -4
- package/cpp/rn-llama.hpp +89 -25
- package/cpp/sampling.cpp +42 -3
- package/cpp/sampling.h +22 -1
- package/cpp/sgemm.cpp +608 -0
- package/cpp/speculative.cpp +270 -0
- package/cpp/speculative.h +28 -0
- package/cpp/unicode.cpp +11 -0
- package/ios/RNLlama.mm +43 -20
- package/ios/RNLlamaContext.h +9 -3
- package/ios/RNLlamaContext.mm +146 -33
- package/jest/mock.js +0 -1
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/grammar.js +4 -2
- package/lib/commonjs/grammar.js.map +1 -1
- package/lib/commonjs/index.js +52 -15
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/grammar.js +2 -1
- package/lib/module/grammar.js.map +1 -1
- package/lib/module/index.js +51 -15
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +122 -8
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/grammar.d.ts +5 -6
- package/lib/typescript/grammar.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +15 -6
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +2 -1
- package/src/NativeRNLlama.ts +135 -13
- package/src/grammar.ts +10 -8
- package/src/index.ts +104 -28
package/android/src/main/jni.cpp
CHANGED
@@ -4,13 +4,15 @@
|
|
4
4
|
#include <android/log.h>
|
5
5
|
#include <cstdlib>
|
6
6
|
#include <ctime>
|
7
|
+
#include <ctime>
|
7
8
|
#include <sys/sysinfo.h>
|
8
9
|
#include <string>
|
9
10
|
#include <thread>
|
10
11
|
#include <unordered_map>
|
11
12
|
#include "llama.h"
|
12
|
-
#include "
|
13
|
+
#include "llama-impl.h"
|
13
14
|
#include "ggml.h"
|
15
|
+
#include "rn-llama.hpp"
|
14
16
|
|
15
17
|
#define UNUSED(x) (void)(x)
|
16
18
|
#define TAG "RNLLAMA_ANDROID_JNI"
|
@@ -22,6 +24,13 @@ static inline int min(int a, int b) {
|
|
22
24
|
return (a < b) ? a : b;
|
23
25
|
}
|
24
26
|
|
27
|
+
static void log_callback(lm_ggml_log_level level, const char * fmt, void * data) {
|
28
|
+
if (level == LM_GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
|
29
|
+
else if (level == LM_GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
|
30
|
+
else if (level == LM_GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
|
31
|
+
else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
|
32
|
+
}
|
33
|
+
|
25
34
|
extern "C" {
|
26
35
|
|
27
36
|
// Method to create WritableMap
|
@@ -107,6 +116,15 @@ static inline void pushDouble(JNIEnv *env, jobject arr, double value) {
|
|
107
116
|
env->CallVoidMethod(arr, pushDoubleMethod, value);
|
108
117
|
}
|
109
118
|
|
119
|
+
// Method to push string into WritableArray
|
120
|
+
static inline void pushString(JNIEnv *env, jobject arr, const char *value) {
|
121
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
|
122
|
+
jmethodID pushStringMethod = env->GetMethodID(mapClass, "pushString", "(Ljava/lang/String;)V");
|
123
|
+
|
124
|
+
jstring jValue = env->NewStringUTF(value);
|
125
|
+
env->CallVoidMethod(arr, pushStringMethod, jValue);
|
126
|
+
}
|
127
|
+
|
110
128
|
// Method to push WritableMap into WritableArray
|
111
129
|
static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
|
112
130
|
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
|
@@ -125,6 +143,77 @@ static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject v
|
|
125
143
|
env->CallVoidMethod(map, putArrayMethod, jKey, value);
|
126
144
|
}
|
127
145
|
|
146
|
+
JNIEXPORT jobject JNICALL
|
147
|
+
Java_com_rnllama_LlamaContext_modelInfo(
|
148
|
+
JNIEnv *env,
|
149
|
+
jobject thiz,
|
150
|
+
jstring model_path_str,
|
151
|
+
jobjectArray skip
|
152
|
+
) {
|
153
|
+
UNUSED(thiz);
|
154
|
+
|
155
|
+
const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
|
156
|
+
|
157
|
+
std::vector<std::string> skip_vec;
|
158
|
+
int skip_len = env->GetArrayLength(skip);
|
159
|
+
for (int i = 0; i < skip_len; i++) {
|
160
|
+
jstring skip_str = (jstring) env->GetObjectArrayElement(skip, i);
|
161
|
+
const char *skip_chars = env->GetStringUTFChars(skip_str, nullptr);
|
162
|
+
skip_vec.push_back(skip_chars);
|
163
|
+
env->ReleaseStringUTFChars(skip_str, skip_chars);
|
164
|
+
}
|
165
|
+
|
166
|
+
struct lm_gguf_init_params params = {
|
167
|
+
/*.no_alloc = */ false,
|
168
|
+
/*.ctx = */ NULL,
|
169
|
+
};
|
170
|
+
struct lm_gguf_context * ctx = lm_gguf_init_from_file(model_path_chars, params);
|
171
|
+
|
172
|
+
if (!ctx) {
|
173
|
+
LOGI("%s: failed to load '%s'\n", __func__, model_path_chars);
|
174
|
+
return nullptr;
|
175
|
+
}
|
176
|
+
|
177
|
+
auto info = createWriteableMap(env);
|
178
|
+
putInt(env, info, "version", lm_gguf_get_version(ctx));
|
179
|
+
putInt(env, info, "alignment", lm_gguf_get_alignment(ctx));
|
180
|
+
putInt(env, info, "data_offset", lm_gguf_get_data_offset(ctx));
|
181
|
+
{
|
182
|
+
const int n_kv = lm_gguf_get_n_kv(ctx);
|
183
|
+
|
184
|
+
for (int i = 0; i < n_kv; ++i) {
|
185
|
+
const char * key = lm_gguf_get_key(ctx, i);
|
186
|
+
|
187
|
+
bool skipped = false;
|
188
|
+
if (skip_len > 0) {
|
189
|
+
for (int j = 0; j < skip_len; j++) {
|
190
|
+
if (skip_vec[j] == key) {
|
191
|
+
skipped = true;
|
192
|
+
break;
|
193
|
+
}
|
194
|
+
}
|
195
|
+
}
|
196
|
+
|
197
|
+
if (skipped) {
|
198
|
+
continue;
|
199
|
+
}
|
200
|
+
|
201
|
+
const std::string value = rnllama::lm_gguf_kv_to_str(ctx, i);
|
202
|
+
putString(env, info, key, value.c_str());
|
203
|
+
}
|
204
|
+
}
|
205
|
+
|
206
|
+
env->ReleaseStringUTFChars(model_path_str, model_path_chars);
|
207
|
+
lm_gguf_free(ctx);
|
208
|
+
|
209
|
+
return reinterpret_cast<jobject>(info);
|
210
|
+
}
|
211
|
+
|
212
|
+
struct callback_context {
|
213
|
+
JNIEnv *env;
|
214
|
+
rnllama::llama_rn_context *llama;
|
215
|
+
jobject callback;
|
216
|
+
};
|
128
217
|
|
129
218
|
std::unordered_map<long, rnllama::llama_rn_context *> context_map;
|
130
219
|
|
@@ -141,10 +230,14 @@ Java_com_rnllama_LlamaContext_initContext(
|
|
141
230
|
jobject thiz,
|
142
231
|
jstring model_path_str,
|
143
232
|
jboolean embedding,
|
233
|
+
jint embd_normalize,
|
144
234
|
jint n_ctx,
|
145
235
|
jint n_batch,
|
146
236
|
jint n_threads,
|
147
237
|
jint n_gpu_layers, // TODO: Support this
|
238
|
+
jboolean flash_attn,
|
239
|
+
jstring cache_type_k,
|
240
|
+
jstring cache_type_v,
|
148
241
|
jboolean use_mlock,
|
149
242
|
jboolean use_mmap,
|
150
243
|
jboolean vocab_only,
|
@@ -152,7 +245,8 @@ Java_com_rnllama_LlamaContext_initContext(
|
|
152
245
|
jfloat lora_scaled,
|
153
246
|
jfloat rope_freq_base,
|
154
247
|
jfloat rope_freq_scale,
|
155
|
-
|
248
|
+
jint pooling_type,
|
249
|
+
jobject load_progress_callback
|
156
250
|
) {
|
157
251
|
UNUSED(thiz);
|
158
252
|
|
@@ -165,65 +259,110 @@ Java_com_rnllama_LlamaContext_initContext(
|
|
165
259
|
|
166
260
|
const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
|
167
261
|
defaultParams.model = model_path_chars;
|
168
|
-
|
169
|
-
defaultParams.embedding = embedding;
|
170
|
-
|
262
|
+
|
171
263
|
defaultParams.n_ctx = n_ctx;
|
172
264
|
defaultParams.n_batch = n_batch;
|
173
265
|
|
266
|
+
if (pooling_type != -1) {
|
267
|
+
defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);
|
268
|
+
}
|
269
|
+
|
270
|
+
defaultParams.embedding = embedding;
|
271
|
+
if (embd_normalize != -1) {
|
272
|
+
defaultParams.embd_normalize = embd_normalize;
|
273
|
+
}
|
274
|
+
if (embedding) {
|
275
|
+
// For non-causal models, batch size must be equal to ubatch size
|
276
|
+
defaultParams.n_ubatch = defaultParams.n_batch;
|
277
|
+
}
|
278
|
+
|
174
279
|
int max_threads = std::thread::hardware_concurrency();
|
175
280
|
// Use 2 threads by default on 4-core devices, 4 threads on more cores
|
176
281
|
int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
|
177
282
|
defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
|
178
283
|
|
179
|
-
defaultParams.n_gpu_layers = n_gpu_layers;
|
180
|
-
|
284
|
+
// defaultParams.n_gpu_layers = n_gpu_layers;
|
285
|
+
defaultParams.flash_attn = flash_attn;
|
286
|
+
|
287
|
+
const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
|
288
|
+
const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
|
289
|
+
defaultParams.cache_type_k = cache_type_k_chars;
|
290
|
+
defaultParams.cache_type_v = cache_type_v_chars;
|
291
|
+
|
181
292
|
defaultParams.use_mlock = use_mlock;
|
182
293
|
defaultParams.use_mmap = use_mmap;
|
183
294
|
|
184
295
|
const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
|
185
296
|
if (lora_chars != nullptr && lora_chars[0] != '\0') {
|
186
297
|
defaultParams.lora_adapters.push_back({lora_chars, lora_scaled});
|
187
|
-
defaultParams.use_mmap = false;
|
188
298
|
}
|
189
299
|
|
190
300
|
defaultParams.rope_freq_base = rope_freq_base;
|
191
301
|
defaultParams.rope_freq_scale = rope_freq_scale;
|
192
302
|
|
193
|
-
// progress callback when loading
|
194
|
-
jclass llamaContextClass = env->GetObjectClass(javaLlamaContext);
|
195
|
-
jmethodID sendProgressMethod = env->GetMethodID(llamaContextClass, "emitModelProgressUpdate", "(I)V");
|
196
|
-
|
197
|
-
CallbackContext callbackctx = {env, javaLlamaContext, sendProgressMethod, 0};
|
198
|
-
|
199
|
-
defaultParams.progress_callback_user_data = &callbackctx;
|
200
|
-
defaultParams.progress_callback = [](float progress, void * ctx) {
|
201
|
-
unsigned percentage = (unsigned) (100 * progress);
|
202
|
-
CallbackContext * cbctx = static_cast<CallbackContext*>(ctx);
|
203
|
-
// reduce call frequency by only calling method when value changes
|
204
|
-
if (percentage <= cbctx->current) return true;
|
205
|
-
cbctx->current = percentage;
|
206
|
-
cbctx->env->CallVoidMethod(cbctx->thiz, cbctx->sendProgressMethod, percentage);
|
207
|
-
return true;
|
208
|
-
};
|
209
|
-
|
210
|
-
|
211
303
|
auto llama = new rnllama::llama_rn_context();
|
304
|
+
llama->is_load_interrupted = false;
|
305
|
+
llama->loading_progress = 0;
|
306
|
+
|
307
|
+
if (load_progress_callback != nullptr) {
|
308
|
+
defaultParams.progress_callback = [](float progress, void * user_data) {
|
309
|
+
callback_context *cb_ctx = (callback_context *)user_data;
|
310
|
+
JNIEnv *env = cb_ctx->env;
|
311
|
+
auto llama = cb_ctx->llama;
|
312
|
+
jobject callback = cb_ctx->callback;
|
313
|
+
int percentage = (int) (100 * progress);
|
314
|
+
if (percentage > llama->loading_progress) {
|
315
|
+
llama->loading_progress = percentage;
|
316
|
+
jclass callback_class = env->GetObjectClass(callback);
|
317
|
+
jmethodID onLoadProgress = env->GetMethodID(callback_class, "onLoadProgress", "(I)V");
|
318
|
+
env->CallVoidMethod(callback, onLoadProgress, percentage);
|
319
|
+
}
|
320
|
+
return !llama->is_load_interrupted;
|
321
|
+
};
|
322
|
+
|
323
|
+
callback_context *cb_ctx = new callback_context;
|
324
|
+
cb_ctx->env = env;
|
325
|
+
cb_ctx->llama = llama;
|
326
|
+
cb_ctx->callback = env->NewGlobalRef(load_progress_callback);
|
327
|
+
defaultParams.progress_callback_user_data = cb_ctx;
|
328
|
+
}
|
329
|
+
|
212
330
|
bool is_model_loaded = llama->loadModel(defaultParams);
|
213
331
|
|
332
|
+
env->ReleaseStringUTFChars(model_path_str, model_path_chars);
|
333
|
+
env->ReleaseStringUTFChars(lora_str, lora_chars);
|
334
|
+
env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
|
335
|
+
env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
|
336
|
+
|
214
337
|
LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
|
215
338
|
if (is_model_loaded) {
|
339
|
+
if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
|
340
|
+
LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
|
341
|
+
llama_free(llama->ctx);
|
342
|
+
return -1;
|
343
|
+
}
|
216
344
|
context_map[(long) llama->ctx] = llama;
|
217
345
|
} else {
|
218
346
|
llama_free(llama->ctx);
|
219
347
|
}
|
220
348
|
|
221
|
-
env->ReleaseStringUTFChars(model_path_str, model_path_chars);
|
222
|
-
env->ReleaseStringUTFChars(lora_str, lora_chars);
|
223
|
-
|
224
349
|
return reinterpret_cast<jlong>(llama->ctx);
|
225
350
|
}
|
226
351
|
|
352
|
+
|
353
|
+
JNIEXPORT void JNICALL
|
354
|
+
Java_com_rnllama_LlamaContext_interruptLoad(
|
355
|
+
JNIEnv *env,
|
356
|
+
jobject thiz,
|
357
|
+
jlong context_ptr
|
358
|
+
) {
|
359
|
+
UNUSED(thiz);
|
360
|
+
auto llama = context_map[(long) context_ptr];
|
361
|
+
if (llama) {
|
362
|
+
llama->is_load_interrupted = true;
|
363
|
+
}
|
364
|
+
}
|
365
|
+
|
227
366
|
JNIEXPORT jobject JNICALL
|
228
367
|
Java_com_rnllama_LlamaContext_loadModelDetails(
|
229
368
|
JNIEnv *env,
|
@@ -397,13 +536,18 @@ Java_com_rnllama_LlamaContext_doCompletion(
|
|
397
536
|
jint top_k,
|
398
537
|
jfloat top_p,
|
399
538
|
jfloat min_p,
|
400
|
-
jfloat
|
401
|
-
jfloat
|
539
|
+
jfloat xtc_threshold,
|
540
|
+
jfloat xtc_probability,
|
402
541
|
jfloat typical_p,
|
403
542
|
jint seed,
|
404
543
|
jobjectArray stop,
|
405
544
|
jboolean ignore_eos,
|
406
545
|
jobjectArray logit_bias,
|
546
|
+
jfloat dry_multiplier,
|
547
|
+
jfloat dry_base,
|
548
|
+
jint dry_allowed_length,
|
549
|
+
jint dry_penalty_last_n,
|
550
|
+
jobjectArray dry_sequence_breakers,
|
407
551
|
jobject partial_completion_callback
|
408
552
|
) {
|
409
553
|
UNUSED(thiz);
|
@@ -414,7 +558,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
|
|
414
558
|
//llama_reset_timings(llama->ctx);
|
415
559
|
|
416
560
|
llama->params.prompt = env->GetStringUTFChars(prompt, nullptr);
|
417
|
-
llama->params.
|
561
|
+
llama->params.sampling.seed = (seed == -1) ? time(NULL) : seed;
|
418
562
|
|
419
563
|
int max_threads = std::thread::hardware_concurrency();
|
420
564
|
// Use 2 threads by default on 4-core devices, 4 threads on more cores
|
@@ -422,9 +566,9 @@ Java_com_rnllama_LlamaContext_doCompletion(
|
|
422
566
|
llama->params.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
|
423
567
|
|
424
568
|
llama->params.n_predict = n_predict;
|
425
|
-
llama->params.
|
569
|
+
llama->params.sampling.ignore_eos = ignore_eos;
|
426
570
|
|
427
|
-
auto & sparams = llama->params.
|
571
|
+
auto & sparams = llama->params.sampling;
|
428
572
|
sparams.temp = temperature;
|
429
573
|
sparams.penalty_last_n = penalty_last_n;
|
430
574
|
sparams.penalty_repeat = penalty_repeat;
|
@@ -440,14 +584,34 @@ Java_com_rnllama_LlamaContext_doCompletion(
|
|
440
584
|
sparams.typ_p = typical_p;
|
441
585
|
sparams.n_probs = n_probs;
|
442
586
|
sparams.grammar = env->GetStringUTFChars(grammar, nullptr);
|
443
|
-
sparams.xtc_threshold =
|
444
|
-
sparams.xtc_probability =
|
587
|
+
sparams.xtc_threshold = xtc_threshold;
|
588
|
+
sparams.xtc_probability = xtc_probability;
|
589
|
+
sparams.dry_multiplier = dry_multiplier;
|
590
|
+
sparams.dry_base = dry_base;
|
591
|
+
sparams.dry_allowed_length = dry_allowed_length;
|
592
|
+
sparams.dry_penalty_last_n = dry_penalty_last_n;
|
445
593
|
|
446
594
|
sparams.logit_bias.clear();
|
447
595
|
if (ignore_eos) {
|
448
596
|
sparams.logit_bias[llama_token_eos(llama->model)].bias = -INFINITY;
|
449
597
|
}
|
450
598
|
|
599
|
+
// dry break seq
|
600
|
+
|
601
|
+
jint size = env->GetArrayLength(dry_sequence_breakers);
|
602
|
+
std::vector<std::string> dry_sequence_breakers_vector;
|
603
|
+
|
604
|
+
for (jint i = 0; i < size; i++) {
|
605
|
+
jstring javaString = (jstring)env->GetObjectArrayElement(dry_sequence_breakers, i);
|
606
|
+
const char *nativeString = env->GetStringUTFChars(javaString, 0);
|
607
|
+
dry_sequence_breakers_vector.push_back(std::string(nativeString));
|
608
|
+
env->ReleaseStringUTFChars(javaString, nativeString);
|
609
|
+
env->DeleteLocalRef(javaString);
|
610
|
+
}
|
611
|
+
|
612
|
+
sparams.dry_sequence_breakers = dry_sequence_breakers_vector;
|
613
|
+
|
614
|
+
// logit bias
|
451
615
|
const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx));
|
452
616
|
jsize logit_bias_len = env->GetArrayLength(logit_bias);
|
453
617
|
|
@@ -529,7 +693,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
|
|
529
693
|
auto tokenResult = createWriteableMap(env);
|
530
694
|
putString(env, tokenResult, "token", to_send.c_str());
|
531
695
|
|
532
|
-
if (llama->params.
|
696
|
+
if (llama->params.sampling.n_probs > 0) {
|
533
697
|
const std::vector<llama_token> to_send_toks = common_tokenize(llama->ctx, to_send, false);
|
534
698
|
size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
|
535
699
|
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
|
@@ -651,16 +815,27 @@ Java_com_rnllama_LlamaContext_isEmbeddingEnabled(
|
|
651
815
|
|
652
816
|
JNIEXPORT jobject JNICALL
|
653
817
|
Java_com_rnllama_LlamaContext_embedding(
|
654
|
-
JNIEnv *env, jobject thiz,
|
818
|
+
JNIEnv *env, jobject thiz,
|
819
|
+
jlong context_ptr,
|
820
|
+
jstring text,
|
821
|
+
jint embd_normalize
|
822
|
+
) {
|
655
823
|
UNUSED(thiz);
|
656
824
|
auto llama = context_map[(long) context_ptr];
|
657
825
|
|
826
|
+
common_params embdParams;
|
827
|
+
embdParams.embedding = true;
|
828
|
+
embdParams.embd_normalize = llama->params.embd_normalize;
|
829
|
+
if (embd_normalize != -1) {
|
830
|
+
embdParams.embd_normalize = embd_normalize;
|
831
|
+
}
|
832
|
+
|
658
833
|
const char *text_chars = env->GetStringUTFChars(text, nullptr);
|
659
834
|
|
660
835
|
llama->rewind();
|
661
836
|
|
662
837
|
llama_perf_context_reset(llama->ctx);
|
663
|
-
|
838
|
+
|
664
839
|
llama->params.prompt = text_chars;
|
665
840
|
|
666
841
|
llama->params.n_predict = 0;
|
@@ -675,7 +850,7 @@ Java_com_rnllama_LlamaContext_embedding(
|
|
675
850
|
llama->loadPrompt();
|
676
851
|
llama->doCompletion();
|
677
852
|
|
678
|
-
std::vector<float> embedding = llama->getEmbedding();
|
853
|
+
std::vector<float> embedding = llama->getEmbedding(embdParams);
|
679
854
|
|
680
855
|
auto embeddings = createWritableArray(env);
|
681
856
|
for (const auto &val : embedding) {
|
@@ -683,6 +858,12 @@ Java_com_rnllama_LlamaContext_embedding(
|
|
683
858
|
}
|
684
859
|
putArray(env, result, "embedding", embeddings);
|
685
860
|
|
861
|
+
auto promptTokens = createWritableArray(env);
|
862
|
+
for (const auto &tok : llama->embd) {
|
863
|
+
pushString(env, promptTokens, common_token_to_piece(llama->ctx, tok).c_str());
|
864
|
+
}
|
865
|
+
putArray(env, result, "prompt_tokens", promptTokens);
|
866
|
+
|
686
867
|
env->ReleaseStringUTFChars(text, text_chars);
|
687
868
|
return result;
|
688
869
|
}
|
@@ -722,4 +903,11 @@ Java_com_rnllama_LlamaContext_freeContext(
|
|
722
903
|
context_map.erase((long) llama->ctx);
|
723
904
|
}
|
724
905
|
|
906
|
+
JNIEXPORT void JNICALL
|
907
|
+
Java_com_rnllama_LlamaContext_logToAndroid(JNIEnv *env, jobject thiz) {
|
908
|
+
UNUSED(env);
|
909
|
+
UNUSED(thiz);
|
910
|
+
llama_log_set(log_callback, NULL);
|
911
|
+
}
|
912
|
+
|
725
913
|
} // extern "C"
|
@@ -39,8 +39,13 @@ public class RNLlamaModule extends NativeRNLlamaSpec {
|
|
39
39
|
}
|
40
40
|
|
41
41
|
@ReactMethod
|
42
|
-
public void
|
43
|
-
rnllama.
|
42
|
+
public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
|
43
|
+
rnllama.modelInfo(model, skip, promise);
|
44
|
+
}
|
45
|
+
|
46
|
+
@ReactMethod
|
47
|
+
public void initContext(double id, final ReadableMap params, final Promise promise) {
|
48
|
+
rnllama.initContext(id, params, promise);
|
44
49
|
}
|
45
50
|
|
46
51
|
@ReactMethod
|
@@ -89,8 +94,8 @@ public class RNLlamaModule extends NativeRNLlamaSpec {
|
|
89
94
|
}
|
90
95
|
|
91
96
|
@ReactMethod
|
92
|
-
public void embedding(double id, final String text, final Promise promise) {
|
93
|
-
rnllama.embedding(id, text, promise);
|
97
|
+
public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
|
98
|
+
rnllama.embedding(id, text, params, promise);
|
94
99
|
}
|
95
100
|
|
96
101
|
@ReactMethod
|
@@ -40,8 +40,13 @@ public class RNLlamaModule extends ReactContextBaseJavaModule {
|
|
40
40
|
}
|
41
41
|
|
42
42
|
@ReactMethod
|
43
|
-
public void
|
44
|
-
rnllama.
|
43
|
+
public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
|
44
|
+
rnllama.modelInfo(model, skip, promise);
|
45
|
+
}
|
46
|
+
|
47
|
+
@ReactMethod
|
48
|
+
public void initContext(double id, final ReadableMap params, final Promise promise) {
|
49
|
+
rnllama.initContext(id, params, promise);
|
45
50
|
}
|
46
51
|
|
47
52
|
@ReactMethod
|
@@ -90,8 +95,8 @@ public class RNLlamaModule extends ReactContextBaseJavaModule {
|
|
90
95
|
}
|
91
96
|
|
92
97
|
@ReactMethod
|
93
|
-
public void embedding(double id, final String text, final Promise promise) {
|
94
|
-
rnllama.embedding(id, text, promise);
|
98
|
+
public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
|
99
|
+
rnllama.embedding(id, text, params, promise);
|
95
100
|
}
|
96
101
|
|
97
102
|
@ReactMethod
|
package/cpp/amx/amx.cpp
ADDED
@@ -0,0 +1,196 @@
|
|
1
|
+
#include "amx.h"
|
2
|
+
#include "common.h"
|
3
|
+
#include "mmq.h"
|
4
|
+
#include "ggml-backend-impl.h"
|
5
|
+
#include "ggml-backend.h"
|
6
|
+
#include "ggml-impl.h"
|
7
|
+
#include "ggml-cpu.h"
|
8
|
+
|
9
|
+
#if defined(__gnu_linux__)
|
10
|
+
#include <sys/syscall.h>
|
11
|
+
#include <unistd.h>
|
12
|
+
#endif
|
13
|
+
|
14
|
+
#include <cstdlib>
|
15
|
+
#include <cstring>
|
16
|
+
#include <memory>
|
17
|
+
|
18
|
+
#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
|
19
|
+
|
20
|
+
// AMX buffer interface
|
21
|
+
static void lm_ggml_backend_amx_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) {
|
22
|
+
free(buffer->context);
|
23
|
+
}
|
24
|
+
|
25
|
+
static void * lm_ggml_backend_amx_buffer_get_base(lm_ggml_backend_buffer_t buffer) {
|
26
|
+
return (void *)(buffer->context);
|
27
|
+
}
|
28
|
+
|
29
|
+
static void lm_ggml_backend_amx_buffer_memset_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
30
|
+
memset((char *)tensor->data + offset, value, size);
|
31
|
+
|
32
|
+
LM_GGML_UNUSED(buffer);
|
33
|
+
}
|
34
|
+
|
35
|
+
static void lm_ggml_backend_amx_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
36
|
+
if (qtype_has_amx_kernels(tensor->type)) {
|
37
|
+
lm_ggml_backend_amx_convert_weight(tensor, data, offset, size);
|
38
|
+
} else {
|
39
|
+
memcpy((char *)tensor->data + offset, data, size);
|
40
|
+
}
|
41
|
+
|
42
|
+
LM_GGML_UNUSED(buffer);
|
43
|
+
}
|
44
|
+
|
45
|
+
static void lm_ggml_backend_amx_buffer_get_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
46
|
+
LM_GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
|
47
|
+
memcpy(data, (const char *)tensor->data + offset, size);
|
48
|
+
|
49
|
+
LM_GGML_UNUSED(buffer);
|
50
|
+
}
|
51
|
+
|
52
|
+
static bool lm_ggml_backend_amx_buffer_cpy_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst) {
|
53
|
+
if (lm_ggml_backend_buffer_is_host(src->buffer)) {
|
54
|
+
if (qtype_has_amx_kernels(src->type)) {
|
55
|
+
lm_ggml_backend_amx_convert_weight(dst, src->data, 0, lm_ggml_nbytes(dst));
|
56
|
+
} else {
|
57
|
+
memcpy(dst->data, src->data, lm_ggml_nbytes(src));
|
58
|
+
}
|
59
|
+
return true;
|
60
|
+
}
|
61
|
+
return false;
|
62
|
+
|
63
|
+
LM_GGML_UNUSED(buffer);
|
64
|
+
}
|
65
|
+
|
66
|
+
static void lm_ggml_backend_amx_buffer_clear(lm_ggml_backend_buffer_t buffer, uint8_t value) {
|
67
|
+
memset(buffer->context, value, buffer->size);
|
68
|
+
}
|
69
|
+
|
70
|
+
static lm_ggml_backend_buffer_i lm_ggml_backend_amx_buffer_interface = {
|
71
|
+
/* .free_buffer = */ lm_ggml_backend_amx_buffer_free_buffer,
|
72
|
+
/* .get_base = */ lm_ggml_backend_amx_buffer_get_base,
|
73
|
+
/* .init_tensor = */ NULL, // no initialization required
|
74
|
+
/* .memset_tensor = */ lm_ggml_backend_amx_buffer_memset_tensor,
|
75
|
+
/* .set_tensor = */ lm_ggml_backend_amx_buffer_set_tensor,
|
76
|
+
/* .get_tensor = */ lm_ggml_backend_amx_buffer_get_tensor,
|
77
|
+
/* .cpy_tensor = */ lm_ggml_backend_amx_buffer_cpy_tensor,
|
78
|
+
/* .clear = */ lm_ggml_backend_amx_buffer_clear,
|
79
|
+
/* .reset = */ NULL,
|
80
|
+
};
|
81
|
+
|
82
|
+
static const char * lm_ggml_backend_amx_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) {
|
83
|
+
return "AMX";
|
84
|
+
|
85
|
+
LM_GGML_UNUSED(buft);
|
86
|
+
}
|
87
|
+
|
88
|
+
static lm_ggml_backend_buffer_t lm_ggml_backend_amx_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) {
|
89
|
+
void * data = aligned_alloc(TENSOR_ALIGNMENT, size);
|
90
|
+
if (data == NULL) {
|
91
|
+
fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
|
92
|
+
return NULL;
|
93
|
+
}
|
94
|
+
|
95
|
+
return lm_ggml_backend_buffer_init(buft, lm_ggml_backend_amx_buffer_interface, data, size);
|
96
|
+
}
|
97
|
+
|
98
|
+
static size_t lm_ggml_backend_amx_buffer_type_get_alignment(lm_ggml_backend_buffer_type_t buft) {
|
99
|
+
return TENSOR_ALIGNMENT;
|
100
|
+
|
101
|
+
LM_GGML_UNUSED(buft);
|
102
|
+
}
|
103
|
+
|
104
|
+
static size_t lm_ggml_backend_amx_buffer_type_get_alloc_size(lm_ggml_backend_buffer_type_t buft, const lm_ggml_tensor* tensor) {
|
105
|
+
return lm_ggml_backend_amx_get_alloc_size(tensor);
|
106
|
+
|
107
|
+
LM_GGML_UNUSED(buft);
|
108
|
+
}
|
109
|
+
|
110
|
+
static bool lm_ggml_backend_amx_buffer_type_is_host(lm_ggml_backend_buffer_type_t buft) {
|
111
|
+
return false;
|
112
|
+
|
113
|
+
LM_GGML_UNUSED(buft);
|
114
|
+
}
|
115
|
+
|
116
|
+
#define ARCH_GET_XCOMP_PERM 0x1022
|
117
|
+
#define ARCH_REQ_XCOMP_PERM 0x1023
|
118
|
+
#define XFEATURE_XTILECFG 17
|
119
|
+
#define XFEATURE_XTILEDATA 18
|
120
|
+
|
121
|
+
static bool lm_ggml_amx_init() {
|
122
|
+
#if defined(__gnu_linux__)
|
123
|
+
if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
|
124
|
+
fprintf(stderr, "AMX is not ready to be used!\n");
|
125
|
+
return false;
|
126
|
+
}
|
127
|
+
return true;
|
128
|
+
#elif defined(_WIN32)
|
129
|
+
return true;
|
130
|
+
#endif
|
131
|
+
}
|
132
|
+
lm_ggml_backend_buffer_type_t lm_ggml_backend_amx_buffer_type() {
|
133
|
+
static struct lm_ggml_backend_buffer_type lm_ggml_backend_buffer_type_amx = {
|
134
|
+
/* .iface = */ {
|
135
|
+
/* .get_name = */ lm_ggml_backend_amx_buffer_type_get_name,
|
136
|
+
/* .alloc_buffer = */ lm_ggml_backend_amx_buffer_type_alloc_buffer,
|
137
|
+
/* .get_alignment = */ lm_ggml_backend_amx_buffer_type_get_alignment,
|
138
|
+
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
|
139
|
+
/* .get_alloc_size = */ lm_ggml_backend_amx_buffer_type_get_alloc_size,
|
140
|
+
/* .is_host = */ lm_ggml_backend_amx_buffer_type_is_host,
|
141
|
+
},
|
142
|
+
/* .device = */ lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0),
|
143
|
+
/* .context = */ NULL,
|
144
|
+
};
|
145
|
+
|
146
|
+
if (!lm_ggml_amx_init()) {
|
147
|
+
return NULL;
|
148
|
+
}
|
149
|
+
|
150
|
+
return &lm_ggml_backend_buffer_type_amx;
|
151
|
+
}
|
152
|
+
|
153
|
+
bool lm_ggml_backend_amx_buft_is_amx(lm_ggml_backend_buffer_type_t buft) {
|
154
|
+
return buft->iface.get_name == lm_ggml_backend_amx_buffer_type_get_name;
|
155
|
+
}
|
156
|
+
|
157
|
+
bool lm_ggml_backend_amx_device_supports_op(const struct lm_ggml_tensor * op) {
|
158
|
+
// handle only 2d gemm for now
|
159
|
+
auto is_contiguous_2d = [](const struct lm_ggml_tensor * t) {
|
160
|
+
return lm_ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
|
161
|
+
};
|
162
|
+
|
163
|
+
switch (op->op) {
|
164
|
+
case LM_GGML_OP_NONE:
|
165
|
+
case LM_GGML_OP_RESHAPE:
|
166
|
+
case LM_GGML_OP_VIEW:
|
167
|
+
case LM_GGML_OP_PERMUTE:
|
168
|
+
case LM_GGML_OP_TRANSPOSE:
|
169
|
+
return true;
|
170
|
+
|
171
|
+
case LM_GGML_OP_MUL_MAT: {
|
172
|
+
const struct lm_ggml_tensor * src0 = op->src[0];
|
173
|
+
const struct lm_ggml_tensor * src1 = op->src[1];
|
174
|
+
|
175
|
+
const enum lm_ggml_type type = src0->type;
|
176
|
+
const int64_t ne0 = op->ne[0];
|
177
|
+
|
178
|
+
// amx kernels enables for Q4_0, Q4_1, Q8_0, F16
|
179
|
+
// Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
|
180
|
+
bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == LM_GGML_TYPE_F16);
|
181
|
+
|
182
|
+
bool can_use_amx =
|
183
|
+
is_contiguous_2d(src0) && // src0 must be contiguous
|
184
|
+
is_contiguous_2d(src1) && // src1 must be contiguous
|
185
|
+
src1->type == LM_GGML_TYPE_F32 && // src1 must be float32
|
186
|
+
has_amx_kernels && // with amx kernel impls
|
187
|
+
ne0 % (TILE_N * 2) == 0; // out_features is 32x
|
188
|
+
|
189
|
+
return can_use_amx;
|
190
|
+
}
|
191
|
+
default:
|
192
|
+
return false;
|
193
|
+
}
|
194
|
+
}
|
195
|
+
|
196
|
+
#endif // defined(__AMX_INT8__) && defined(__AVX512VNNI__)
|