llama-cpp-capacitor 0.0.13 → 0.0.21
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- package/LlamaCpp.podspec +17 -17
- package/Package.swift +27 -27
- package/README.md +717 -574
- package/android/build.gradle +88 -69
- package/android/src/main/AndroidManifest.xml +2 -2
- package/android/src/main/CMakeLists-arm64.txt +131 -0
- package/android/src/main/CMakeLists-x86_64.txt +135 -0
- package/android/src/main/CMakeLists.txt +35 -52
- package/android/src/main/java/ai/annadata/plugin/capacitor/LlamaCpp.java +956 -717
- package/android/src/main/java/ai/annadata/plugin/capacitor/LlamaCppPlugin.java +710 -590
- package/android/src/main/jni-utils.h +7 -7
- package/android/src/main/jni.cpp +868 -127
- package/cpp/{rn-completion.cpp → cap-completion.cpp} +202 -24
- package/cpp/{rn-completion.h → cap-completion.h} +22 -11
- package/cpp/{rn-llama.cpp → cap-llama.cpp} +81 -27
- package/cpp/{rn-llama.h → cap-llama.h} +32 -20
- package/cpp/{rn-mtmd.hpp → cap-mtmd.hpp} +15 -15
- package/cpp/{rn-tts.cpp → cap-tts.cpp} +12 -12
- package/cpp/{rn-tts.h → cap-tts.h} +14 -14
- package/cpp/ggml-cpu/ggml-cpu-impl.h +30 -0
- package/dist/docs.json +100 -3
- package/dist/esm/definitions.d.ts +45 -2
- package/dist/esm/definitions.js.map +1 -1
- package/dist/esm/index.d.ts +22 -0
- package/dist/esm/index.js +66 -3
- package/dist/esm/index.js.map +1 -1
- package/dist/plugin.cjs.js +71 -3
- package/dist/plugin.cjs.js.map +1 -1
- package/dist/plugin.js +71 -3
- package/dist/plugin.js.map +1 -1
- package/ios/Sources/LlamaCppPlugin/LlamaCpp.swift +596 -596
- package/ios/Sources/LlamaCppPlugin/LlamaCppPlugin.swift +591 -514
- package/ios/Tests/LlamaCppPluginTests/LlamaCppPluginTests.swift +15 -15
- package/package.json +111 -110
package/android/src/main/jni.cpp
CHANGED
|
@@ -1,16 +1,21 @@
|
|
|
1
1
|
#include "jni-utils.h"
|
|
2
|
-
#include "
|
|
2
|
+
#include "cap-llama.h"
|
|
3
|
+
#include "cap-completion.h"
|
|
3
4
|
#include <android/log.h>
|
|
4
5
|
#include <cstring>
|
|
5
6
|
#include <memory>
|
|
6
7
|
#include <fstream> // Added for file existence and size checks
|
|
7
8
|
#include <signal.h> // Added for signal handling
|
|
8
9
|
#include <sys/signal.h> // Added for sigaction
|
|
10
|
+
#include <thread> // For background downloads
|
|
11
|
+
#include <atomic> // For thread-safe progress tracking
|
|
12
|
+
#include <filesystem> // For file operations
|
|
13
|
+
#include <mutex> // For thread synchronization
|
|
9
14
|
|
|
10
15
|
// Add missing symbol
|
|
11
|
-
namespace rnllama {
|
|
12
|
-
|
|
13
|
-
}
|
|
16
|
+
// namespace rnllama {
|
|
17
|
+
// bool rnllama_verbose = false;
|
|
18
|
+
// }
|
|
14
19
|
|
|
15
20
|
#define LOG_TAG "LlamaCpp"
|
|
16
21
|
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
|
|
@@ -130,75 +135,67 @@ jclass find_class(JNIEnv* env, const char* name) {
|
|
|
130
135
|
return clazz;
|
|
131
136
|
}
|
|
132
137
|
|
|
133
|
-
//
|
|
134
|
-
|
|
138
|
+
// Convert llama_cap_context to jobject
|
|
139
|
+
jobject llama_context_to_jobject(JNIEnv* env, const capllama::llama_cap_context* context);
|
|
140
|
+
|
|
141
|
+
// Convert jobject to llama_cap_context
|
|
142
|
+
capllama::llama_cap_context* jobject_to_llama_context(JNIEnv* env, jobject obj);
|
|
143
|
+
|
|
144
|
+
// Convert completion result to jobject
|
|
145
|
+
jobject completion_result_to_jobject(JNIEnv* env, const capllama::completion_token_output& result);
|
|
146
|
+
|
|
147
|
+
// Convert tokenize result to jobject
|
|
148
|
+
jobject tokenize_result_to_jobject(JNIEnv* env, const capllama::llama_cap_tokenize_result& result);
|
|
149
|
+
|
|
150
|
+
// Global context storage - fix namespace
|
|
151
|
+
static std::map<jlong, std::unique_ptr<capllama::llama_cap_context>> contexts;
|
|
135
152
|
static jlong next_context_id = 1;
|
|
136
153
|
|
|
154
|
+
// Download progress tracking (simplified for now)
|
|
155
|
+
// This can be enhanced later to track actual download progress
|
|
156
|
+
|
|
137
157
|
extern "C" {
|
|
138
158
|
|
|
139
159
|
JNIEXPORT jlong JNICALL
|
|
140
160
|
Java_ai_annadata_plugin_capacitor_LlamaCpp_initContextNative(
|
|
141
|
-
JNIEnv*
|
|
161
|
+
JNIEnv *env, jobject thiz, jstring modelPath, jobjectArray searchPaths, jobject params) {
|
|
142
162
|
|
|
143
163
|
try {
|
|
144
|
-
std::string model_path_str = jstring_to_string(env,
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
std::vector<std::string> paths_to_check
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
164
|
+
std::string model_path_str = jstring_to_string(env, modelPath);
|
|
165
|
+
|
|
166
|
+
// Get search paths from Java
|
|
167
|
+
jsize pathCount = env->GetArrayLength(searchPaths);
|
|
168
|
+
std::vector<std::string> paths_to_check;
|
|
169
|
+
|
|
170
|
+
// Add the original path first
|
|
171
|
+
paths_to_check.push_back(model_path_str);
|
|
172
|
+
|
|
173
|
+
// Add all search paths from Java
|
|
174
|
+
for (jsize i = 0; i < pathCount; i++) {
|
|
175
|
+
jstring pathJString = (jstring)env->GetObjectArrayElement(searchPaths, i);
|
|
176
|
+
std::string path = jstring_to_string(env, pathJString);
|
|
177
|
+
paths_to_check.push_back(path);
|
|
178
|
+
env->DeleteLocalRef(pathJString);
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
// Rest of the existing logic remains the same...
|
|
158
182
|
std::string full_model_path;
|
|
159
183
|
bool file_found = false;
|
|
160
184
|
|
|
161
185
|
for (const auto& path : paths_to_check) {
|
|
162
186
|
LOGI("Checking path: %s", path.c_str());
|
|
163
|
-
std::
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
LOGI("Found file at: %s, size: %ld bytes", path.c_str(), file_size);
|
|
169
|
-
|
|
170
|
-
// Validate file size
|
|
171
|
-
if (file_size < 1024 * 1024) { // Less than 1MB
|
|
172
|
-
LOGE("Model file is too small, likely corrupted: %s", path.c_str());
|
|
173
|
-
continue; // Try next path
|
|
174
|
-
}
|
|
175
|
-
|
|
176
|
-
// Check if it's a valid GGUF file by reading the magic number
|
|
177
|
-
std::ifstream magic_file(path, std::ios::binary);
|
|
178
|
-
if (magic_file.good()) {
|
|
179
|
-
char magic[4];
|
|
180
|
-
if (magic_file.read(magic, 4)) {
|
|
181
|
-
if (magic[0] == 'G' && magic[1] == 'G' && magic[2] == 'U' && magic[3] == 'F') {
|
|
182
|
-
LOGI("Valid GGUF file detected at: %s", path.c_str());
|
|
183
|
-
full_model_path = path;
|
|
184
|
-
file_found = true;
|
|
185
|
-
break;
|
|
186
|
-
} else {
|
|
187
|
-
LOGI("File does not appear to be a GGUF file (magic: %c%c%c%c) at: %s",
|
|
188
|
-
magic[0], magic[1], magic[2], magic[3], path.c_str());
|
|
189
|
-
}
|
|
190
|
-
}
|
|
191
|
-
magic_file.close();
|
|
192
|
-
}
|
|
187
|
+
if (std::filesystem::exists(path)) {
|
|
188
|
+
full_model_path = path;
|
|
189
|
+
file_found = true;
|
|
190
|
+
LOGI("Found model file at: %s", path.c_str());
|
|
191
|
+
break;
|
|
193
192
|
} else {
|
|
194
|
-
|
|
193
|
+
LOGE("Path not found: %s", path.c_str());
|
|
195
194
|
}
|
|
196
|
-
file_check.close();
|
|
197
195
|
}
|
|
198
|
-
|
|
196
|
+
|
|
199
197
|
if (!file_found) {
|
|
200
|
-
LOGE("Model file not found in any of the
|
|
201
|
-
throw_java_exception(env, "java/lang/RuntimeException", "Model file not found in any expected location");
|
|
198
|
+
LOGE("Model file not found in any of the search paths");
|
|
202
199
|
return -1;
|
|
203
200
|
}
|
|
204
201
|
|
|
@@ -221,9 +218,9 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_initContextNative(
|
|
|
221
218
|
validation_file.close();
|
|
222
219
|
}
|
|
223
220
|
|
|
224
|
-
// Create new context
|
|
225
|
-
auto context = std::make_unique<
|
|
226
|
-
LOGI("Created
|
|
221
|
+
// Create new context - fix namespace
|
|
222
|
+
auto context = std::make_unique<capllama::llama_cap_context>();
|
|
223
|
+
LOGI("Created llama_cap_context");
|
|
227
224
|
|
|
228
225
|
// Initialize common parameters
|
|
229
226
|
common_params cparams;
|
|
@@ -240,47 +237,21 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_initContextNative(
|
|
|
240
237
|
cparams.chat_template = "";
|
|
241
238
|
cparams.embedding = false;
|
|
242
239
|
cparams.cont_batching = false;
|
|
243
|
-
cparams.
|
|
244
|
-
cparams.grammar = "";
|
|
245
|
-
cparams.grammar_penalty.clear();
|
|
240
|
+
cparams.n_parallel = 1;
|
|
246
241
|
cparams.antiprompt.clear();
|
|
247
|
-
cparams.lora_adapter.clear();
|
|
248
|
-
cparams.lora_base = "";
|
|
249
|
-
cparams.mul_mat_q = true;
|
|
250
|
-
cparams.f16_kv = true;
|
|
251
|
-
cparams.logits_all = false;
|
|
252
242
|
cparams.vocab_only = false;
|
|
253
243
|
cparams.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
|
|
254
|
-
cparams.rope_scaling_factor = 0.0f;
|
|
255
|
-
cparams.rope_scaling_orig_ctx_len = 0;
|
|
256
244
|
cparams.yarn_ext_factor = -1.0f;
|
|
257
245
|
cparams.yarn_attn_factor = 1.0f;
|
|
258
246
|
cparams.yarn_beta_fast = 32.0f;
|
|
259
247
|
cparams.yarn_beta_slow = 1.0f;
|
|
260
248
|
cparams.yarn_orig_ctx = 0;
|
|
261
|
-
cparams.offload_kqv = true;
|
|
262
249
|
cparams.flash_attn = false;
|
|
263
|
-
cparams.flash_attn_kernel = false;
|
|
264
|
-
cparams.flash_attn_causal = true;
|
|
265
|
-
cparams.mmproj = "";
|
|
266
|
-
cparams.image = "";
|
|
267
|
-
cparams.export = "";
|
|
268
|
-
cparams.export_path = "";
|
|
269
|
-
cparams.seed = -1;
|
|
270
250
|
cparams.n_keep = 0;
|
|
271
|
-
cparams.n_discard = -1;
|
|
272
|
-
cparams.n_draft = 0;
|
|
273
251
|
cparams.n_chunks = -1;
|
|
274
|
-
cparams.n_parallel = 1;
|
|
275
252
|
cparams.n_sequences = 1;
|
|
276
|
-
cparams.p_accept = 0.5f;
|
|
277
|
-
cparams.p_split = 0.1f;
|
|
278
|
-
cparams.n_gqa = 8;
|
|
279
|
-
cparams.rms_norm_eps = 5e-6f;
|
|
280
253
|
cparams.model_alias = "unknown";
|
|
281
|
-
|
|
282
|
-
cparams.ubatch_seq_len_max = 1;
|
|
283
|
-
|
|
254
|
+
|
|
284
255
|
LOGI("Initialized common parameters, attempting to load model from: %s", full_model_path.c_str());
|
|
285
256
|
LOGI("Model parameters: n_ctx=%d, n_batch=%d, n_gpu_layers=%d",
|
|
286
257
|
cparams.n_ctx, cparams.n_batch, cparams.n_gpu_layers);
|
|
@@ -335,47 +306,21 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_initContextNative(
|
|
|
335
306
|
ultra_minimal_params.chat_template = "";
|
|
336
307
|
ultra_minimal_params.embedding = false;
|
|
337
308
|
ultra_minimal_params.cont_batching = false;
|
|
338
|
-
ultra_minimal_params.
|
|
339
|
-
ultra_minimal_params.grammar = "";
|
|
340
|
-
ultra_minimal_params.grammar_penalty.clear();
|
|
309
|
+
ultra_minimal_params.n_parallel = 1;
|
|
341
310
|
ultra_minimal_params.antiprompt.clear();
|
|
342
|
-
ultra_minimal_params.lora_adapter.clear();
|
|
343
|
-
ultra_minimal_params.lora_base = "";
|
|
344
|
-
ultra_minimal_params.mul_mat_q = false; // Disable quantized matrix multiplication
|
|
345
|
-
ultra_minimal_params.f16_kv = false; // Disable f16 key-value cache
|
|
346
|
-
ultra_minimal_params.logits_all = false;
|
|
347
311
|
ultra_minimal_params.vocab_only = false;
|
|
348
312
|
ultra_minimal_params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
|
|
349
|
-
ultra_minimal_params.rope_scaling_factor = 0.0f;
|
|
350
|
-
ultra_minimal_params.rope_scaling_orig_ctx_len = 0;
|
|
351
313
|
ultra_minimal_params.yarn_ext_factor = -1.0f;
|
|
352
314
|
ultra_minimal_params.yarn_attn_factor = 1.0f;
|
|
353
315
|
ultra_minimal_params.yarn_beta_fast = 32.0f;
|
|
354
316
|
ultra_minimal_params.yarn_beta_slow = 1.0f;
|
|
355
317
|
ultra_minimal_params.yarn_orig_ctx = 0;
|
|
356
|
-
ultra_minimal_params.offload_kqv = false; // Disable offloading
|
|
357
318
|
ultra_minimal_params.flash_attn = false;
|
|
358
|
-
ultra_minimal_params.flash_attn_kernel = false;
|
|
359
|
-
ultra_minimal_params.flash_attn_causal = true;
|
|
360
|
-
ultra_minimal_params.mmproj = "";
|
|
361
|
-
ultra_minimal_params.image = "";
|
|
362
|
-
ultra_minimal_params.export = "";
|
|
363
|
-
ultra_minimal_params.export_path = "";
|
|
364
|
-
ultra_minimal_params.seed = -1;
|
|
365
319
|
ultra_minimal_params.n_keep = 0;
|
|
366
|
-
ultra_minimal_params.n_discard = -1;
|
|
367
|
-
ultra_minimal_params.n_draft = 0;
|
|
368
320
|
ultra_minimal_params.n_chunks = -1;
|
|
369
|
-
ultra_minimal_params.n_parallel = 1;
|
|
370
321
|
ultra_minimal_params.n_sequences = 1;
|
|
371
|
-
ultra_minimal_params.p_accept = 0.5f;
|
|
372
|
-
ultra_minimal_params.p_split = 0.1f;
|
|
373
|
-
ultra_minimal_params.n_gqa = 8;
|
|
374
|
-
ultra_minimal_params.rms_norm_eps = 5e-6f;
|
|
375
322
|
ultra_minimal_params.model_alias = "unknown";
|
|
376
|
-
|
|
377
|
-
ultra_minimal_params.ubatch_seq_len_max = 1;
|
|
378
|
-
|
|
323
|
+
|
|
379
324
|
// Set up signal handler again for ultra-minimal attempt
|
|
380
325
|
if (sigaction(SIGSEGV, &new_action, &old_action) == 0) {
|
|
381
326
|
LOGI("Signal handler reinstalled for ultra-minimal attempt");
|
|
@@ -435,28 +380,400 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_releaseContextNative(
|
|
|
435
380
|
}
|
|
436
381
|
}
|
|
437
382
|
|
|
438
|
-
JNIEXPORT
|
|
383
|
+
JNIEXPORT jobject JNICALL
|
|
439
384
|
Java_ai_annadata_plugin_capacitor_LlamaCpp_completionNative(
|
|
440
|
-
JNIEnv* env, jobject thiz, jlong context_id,
|
|
385
|
+
JNIEnv* env, jobject thiz, jlong context_id, jobject params) {
|
|
441
386
|
|
|
442
387
|
try {
|
|
388
|
+
LOGI("Starting completion for context: %ld", context_id);
|
|
389
|
+
|
|
443
390
|
auto it = contexts.find(context_id);
|
|
444
391
|
if (it == contexts.end()) {
|
|
392
|
+
LOGE("Context not found: %ld", context_id);
|
|
445
393
|
throw_java_exception(env, "java/lang/IllegalArgumentException", "Invalid context ID");
|
|
446
394
|
return nullptr;
|
|
447
395
|
}
|
|
448
396
|
|
|
449
|
-
|
|
397
|
+
auto& ctx = it->second;
|
|
398
|
+
if (!ctx || !ctx->ctx) {
|
|
399
|
+
LOGE("Invalid context or llama context is null");
|
|
400
|
+
throw_java_exception(env, "java/lang/RuntimeException", "Invalid context");
|
|
401
|
+
return nullptr;
|
|
402
|
+
}
|
|
403
|
+
|
|
404
|
+
// Extract parameters from JSObject using compatible API
|
|
405
|
+
jclass jsObjectClass = env->GetObjectClass(params);
|
|
450
406
|
|
|
451
|
-
//
|
|
452
|
-
|
|
407
|
+
// Try to get method IDs and handle exceptions
|
|
408
|
+
jmethodID getStringMethod = nullptr;
|
|
409
|
+
jmethodID getIntegerMethod = nullptr;
|
|
410
|
+
jmethodID getDoubleMethod = nullptr;
|
|
453
411
|
|
|
454
|
-
//
|
|
455
|
-
|
|
456
|
-
|
|
412
|
+
// Clear any pending exceptions first
|
|
413
|
+
if (env->ExceptionCheck()) {
|
|
414
|
+
env->ExceptionClear();
|
|
415
|
+
}
|
|
457
416
|
|
|
458
|
-
|
|
459
|
-
|
|
417
|
+
try {
|
|
418
|
+
getStringMethod = env->GetMethodID(jsObjectClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;");
|
|
419
|
+
if (env->ExceptionCheck()) {
|
|
420
|
+
env->ExceptionClear();
|
|
421
|
+
getStringMethod = nullptr;
|
|
422
|
+
}
|
|
423
|
+
|
|
424
|
+
getIntegerMethod = env->GetMethodID(jsObjectClass, "getInteger", "(Ljava/lang/String;)Ljava/lang/Integer;");
|
|
425
|
+
if (env->ExceptionCheck()) {
|
|
426
|
+
env->ExceptionClear();
|
|
427
|
+
getIntegerMethod = nullptr;
|
|
428
|
+
}
|
|
429
|
+
|
|
430
|
+
getDoubleMethod = env->GetMethodID(jsObjectClass, "getDouble", "(Ljava/lang/String;)Ljava/lang/Double;");
|
|
431
|
+
if (env->ExceptionCheck()) {
|
|
432
|
+
env->ExceptionClear();
|
|
433
|
+
getDoubleMethod = nullptr;
|
|
434
|
+
}
|
|
435
|
+
} catch (...) {
|
|
436
|
+
LOGE("Exception getting JSObject method IDs");
|
|
437
|
+
if (env->ExceptionCheck()) {
|
|
438
|
+
env->ExceptionClear();
|
|
439
|
+
}
|
|
440
|
+
}
|
|
441
|
+
|
|
442
|
+
// Get prompt with safe method calls
|
|
443
|
+
std::string prompt_str = "Once upon a time";
|
|
444
|
+
jint n_predict = 50;
|
|
445
|
+
jdouble temperature = 0.7;
|
|
446
|
+
|
|
447
|
+
if (getStringMethod) {
|
|
448
|
+
jstring promptKey = jni_utils::string_to_jstring(env, "prompt");
|
|
449
|
+
jstring promptObj = (jstring)env->CallObjectMethod(params, getStringMethod, promptKey);
|
|
450
|
+
if (promptObj && !env->ExceptionCheck()) {
|
|
451
|
+
prompt_str = jni_utils::jstring_to_string(env, promptObj);
|
|
452
|
+
} else if (env->ExceptionCheck()) {
|
|
453
|
+
env->ExceptionClear();
|
|
454
|
+
}
|
|
455
|
+
}
|
|
456
|
+
|
|
457
|
+
// Get n_predict with safe method calls
|
|
458
|
+
if (getIntegerMethod) {
|
|
459
|
+
jstring nPredictKey = jni_utils::string_to_jstring(env, "n_predict");
|
|
460
|
+
jobject nPredictObj = env->CallObjectMethod(params, getIntegerMethod, nPredictKey);
|
|
461
|
+
if (nPredictObj && !env->ExceptionCheck()) {
|
|
462
|
+
n_predict = env->CallIntMethod(nPredictObj, env->GetMethodID(env->FindClass("java/lang/Integer"), "intValue", "()I"));
|
|
463
|
+
if (env->ExceptionCheck()) {
|
|
464
|
+
env->ExceptionClear();
|
|
465
|
+
n_predict = 50; // fallback
|
|
466
|
+
}
|
|
467
|
+
} else if (env->ExceptionCheck()) {
|
|
468
|
+
env->ExceptionClear();
|
|
469
|
+
}
|
|
470
|
+
}
|
|
471
|
+
|
|
472
|
+
// Get temperature with safe method calls
|
|
473
|
+
if (getDoubleMethod) {
|
|
474
|
+
jstring temperatureKey = jni_utils::string_to_jstring(env, "temperature");
|
|
475
|
+
jobject tempObj = env->CallObjectMethod(params, getDoubleMethod, temperatureKey);
|
|
476
|
+
if (tempObj && !env->ExceptionCheck()) {
|
|
477
|
+
temperature = env->CallDoubleMethod(tempObj, env->GetMethodID(env->FindClass("java/lang/Double"), "doubleValue", "()D"));
|
|
478
|
+
if (env->ExceptionCheck()) {
|
|
479
|
+
env->ExceptionClear();
|
|
480
|
+
temperature = 0.7; // fallback
|
|
481
|
+
}
|
|
482
|
+
} else if (env->ExceptionCheck()) {
|
|
483
|
+
env->ExceptionClear();
|
|
484
|
+
}
|
|
485
|
+
}
|
|
486
|
+
|
|
487
|
+
LOGI("Completion params - prompt: %s, n_predict: %d, temperature: %.2f",
|
|
488
|
+
prompt_str.c_str(), n_predict, temperature);
|
|
489
|
+
|
|
490
|
+
// Set sampling parameters based on extracted values
|
|
491
|
+
ctx->params.sampling.temp = temperature;
|
|
492
|
+
ctx->params.sampling.top_k = 40; // Default value
|
|
493
|
+
ctx->params.sampling.top_p = 0.95f; // Default value
|
|
494
|
+
ctx->params.sampling.penalty_repeat = 1.1f; // Default value (correct field name)
|
|
495
|
+
ctx->params.n_predict = n_predict;
|
|
496
|
+
ctx->params.prompt = prompt_str;
|
|
497
|
+
|
|
498
|
+
LOGI("Updated context sampling params - temp: %.2f, top_k: %d, top_p: %.2f",
|
|
499
|
+
ctx->params.sampling.temp, ctx->params.sampling.top_k, ctx->params.sampling.top_p);
|
|
500
|
+
|
|
501
|
+
// Tokenize the prompt
|
|
502
|
+
capllama::llama_cap_tokenize_result tokenize_result = ctx->tokenize(prompt_str, {});
|
|
503
|
+
std::vector<llama_token> prompt_tokens = tokenize_result.tokens;
|
|
504
|
+
|
|
505
|
+
LOGI("Tokenized prompt into %zu tokens", prompt_tokens.size());
|
|
506
|
+
|
|
507
|
+
// Initialize completion context if not already done
|
|
508
|
+
if (!ctx->completion) {
|
|
509
|
+
LOGI("Initializing completion context for the first time");
|
|
510
|
+
|
|
511
|
+
// Validate parent context before creating completion
|
|
512
|
+
if (!ctx->ctx || !ctx->model) {
|
|
513
|
+
LOGE("Parent context is invalid - missing llama context or model");
|
|
514
|
+
throw_java_exception(env, "java/lang/RuntimeException", "Parent context is not properly initialized");
|
|
515
|
+
return nullptr;
|
|
516
|
+
}
|
|
517
|
+
|
|
518
|
+
try {
|
|
519
|
+
LOGI("Creating llama_cap_context_completion...");
|
|
520
|
+
LOGI("Parent context pointer: %p", ctx.get());
|
|
521
|
+
LOGI("Parent context->ctx: %p", ctx->ctx);
|
|
522
|
+
LOGI("Parent context->model: %p", ctx->model);
|
|
523
|
+
|
|
524
|
+
// Additional safety checks before constructor
|
|
525
|
+
if (!ctx.get()) {
|
|
526
|
+
LOGE("Parent context pointer is null");
|
|
527
|
+
throw_java_exception(env, "java/lang/RuntimeException", "Parent context pointer is null");
|
|
528
|
+
return nullptr;
|
|
529
|
+
}
|
|
530
|
+
|
|
531
|
+
ctx->completion = new capllama::llama_cap_context_completion(ctx.get());
|
|
532
|
+
|
|
533
|
+
if (!ctx->completion) {
|
|
534
|
+
LOGE("Failed to create completion context - constructor returned null");
|
|
535
|
+
throw_java_exception(env, "java/lang/RuntimeException", "Failed to create completion context");
|
|
536
|
+
return nullptr;
|
|
537
|
+
}
|
|
538
|
+
|
|
539
|
+
LOGI("Completion context created successfully at: %p", ctx->completion);
|
|
540
|
+
|
|
541
|
+
LOGI("Initializing sampling for completion context...");
|
|
542
|
+
LOGI("Parent context params before initSampling - model: %p, params: %p", ctx->model, &(ctx->params));
|
|
543
|
+
LOGI("Parent context sampling params - temperature: %.2f, top_k: %d, top_p: %.2f",
|
|
544
|
+
ctx->params.sampling.temp, ctx->params.sampling.top_k, ctx->params.sampling.top_p);
|
|
545
|
+
|
|
546
|
+
bool sampling_result = false;
|
|
547
|
+
try {
|
|
548
|
+
sampling_result = ctx->completion->initSampling();
|
|
549
|
+
LOGI("initSampling completed, result: %s", sampling_result ? "true" : "false");
|
|
550
|
+
LOGI("Sampler pointer after init: %p", ctx->completion->ctx_sampling);
|
|
551
|
+
} catch (const std::exception& e) {
|
|
552
|
+
LOGE("Exception in initSampling: %s", e.what());
|
|
553
|
+
delete ctx->completion;
|
|
554
|
+
ctx->completion = nullptr;
|
|
555
|
+
throw_java_exception(env, "java/lang/RuntimeException",
|
|
556
|
+
("Failed to initialize sampling: " + std::string(e.what())).c_str());
|
|
557
|
+
return nullptr;
|
|
558
|
+
} catch (...) {
|
|
559
|
+
LOGE("Unknown exception in initSampling");
|
|
560
|
+
delete ctx->completion;
|
|
561
|
+
ctx->completion = nullptr;
|
|
562
|
+
throw_java_exception(env, "java/lang/RuntimeException", "Unknown error in sampling initialization");
|
|
563
|
+
return nullptr;
|
|
564
|
+
}
|
|
565
|
+
|
|
566
|
+
if (!sampling_result || !ctx->completion->ctx_sampling) {
|
|
567
|
+
LOGE("Failed to initialize sampling - result: %s, sampler: %p",
|
|
568
|
+
sampling_result ? "true" : "false", ctx->completion->ctx_sampling);
|
|
569
|
+
delete ctx->completion;
|
|
570
|
+
ctx->completion = nullptr;
|
|
571
|
+
throw_java_exception(env, "java/lang/RuntimeException", "Failed to initialize sampling context");
|
|
572
|
+
return nullptr;
|
|
573
|
+
}
|
|
574
|
+
|
|
575
|
+
LOGI("Completion context initialized successfully");
|
|
576
|
+
} catch (const std::exception& e) {
|
|
577
|
+
LOGE("Exception during completion context creation: %s", e.what());
|
|
578
|
+
if (ctx->completion) {
|
|
579
|
+
delete ctx->completion;
|
|
580
|
+
ctx->completion = nullptr;
|
|
581
|
+
}
|
|
582
|
+
throw_java_exception(env, "java/lang/RuntimeException",
|
|
583
|
+
("Failed to create completion context: " + std::string(e.what())).c_str());
|
|
584
|
+
return nullptr;
|
|
585
|
+
} catch (...) {
|
|
586
|
+
LOGE("Unknown exception during completion context creation");
|
|
587
|
+
if (ctx->completion) {
|
|
588
|
+
delete ctx->completion;
|
|
589
|
+
ctx->completion = nullptr;
|
|
590
|
+
}
|
|
591
|
+
throw_java_exception(env, "java/lang/RuntimeException", "Unknown error during completion context creation");
|
|
592
|
+
return nullptr;
|
|
593
|
+
}
|
|
594
|
+
}
|
|
595
|
+
|
|
596
|
+
// Set up sampling parameters
|
|
597
|
+
// Note: For now, we'll use the completion context's default parameters
|
|
598
|
+
// TODO: Update sampling parameters with user values
|
|
599
|
+
//
|
|
600
|
+
// Declare variables outside try block so they're accessible later
|
|
601
|
+
std::string generated_text;
|
|
602
|
+
int tokens_generated = 0;
|
|
603
|
+
|
|
604
|
+
try {
|
|
605
|
+
LOGI("Rewinding completion context...");
|
|
606
|
+
try {
|
|
607
|
+
ctx->completion->rewind();
|
|
608
|
+
LOGI("Rewind completed successfully");
|
|
609
|
+
} catch (const std::exception& e) {
|
|
610
|
+
LOGE("Exception in rewind: %s", e.what());
|
|
611
|
+
throw;
|
|
612
|
+
}
|
|
613
|
+
|
|
614
|
+
LOGI("Loading prompt into completion context...");
|
|
615
|
+
try {
|
|
616
|
+
// Validate sampler is properly initialized before loadPrompt
|
|
617
|
+
if (!ctx->completion->ctx_sampling) {
|
|
618
|
+
LOGE("Sampler context is null - reinitializing");
|
|
619
|
+
if (!ctx->completion->initSampling()) {
|
|
620
|
+
LOGE("Failed to reinitialize sampling");
|
|
621
|
+
throw std::runtime_error("Sampler initialization failed");
|
|
622
|
+
}
|
|
623
|
+
LOGI("Sampler reinitialized successfully");
|
|
624
|
+
}
|
|
625
|
+
|
|
626
|
+
ctx->completion->loadPrompt({});
|
|
627
|
+
LOGI("loadPrompt completed successfully");
|
|
628
|
+
} catch (const std::exception& e) {
|
|
629
|
+
LOGE("Exception in loadPrompt: %s", e.what());
|
|
630
|
+
throw;
|
|
631
|
+
}
|
|
632
|
+
|
|
633
|
+
LOGI("Beginning completion generation...");
|
|
634
|
+
try {
|
|
635
|
+
ctx->completion->beginCompletion();
|
|
636
|
+
LOGI("beginCompletion completed successfully");
|
|
637
|
+
} catch (const std::exception& e) {
|
|
638
|
+
LOGE("Exception in beginCompletion: %s", e.what());
|
|
639
|
+
throw;
|
|
640
|
+
}
|
|
641
|
+
|
|
642
|
+
LOGI("Starting token generation loop (max tokens: %d)...", n_predict);
|
|
643
|
+
|
|
644
|
+
while (tokens_generated < n_predict && !ctx->completion->is_interrupted) {
|
|
645
|
+
try {
|
|
646
|
+
LOGI("Generating token %d...", tokens_generated + 1);
|
|
647
|
+
auto token_output = ctx->completion->nextToken();
|
|
648
|
+
|
|
649
|
+
// Check for end-of-sequence (simplified check)
|
|
650
|
+
if (token_output.tok == 2) { // Most models use 2 as EOS token
|
|
651
|
+
LOGI("Reached EOS token, stopping generation");
|
|
652
|
+
break;
|
|
653
|
+
}
|
|
654
|
+
|
|
655
|
+
// Convert token to text
|
|
656
|
+
std::string token_text = capllama::tokens_to_output_formatted_string(ctx->ctx, token_output.tok);
|
|
657
|
+
generated_text += token_text;
|
|
658
|
+
tokens_generated++;
|
|
659
|
+
|
|
660
|
+
LOGI("Generated token %d (ID: %d): %s", tokens_generated, token_output.tok, token_text.c_str());
|
|
661
|
+
|
|
662
|
+
} catch (const std::exception& e) {
|
|
663
|
+
LOGE("Exception during token generation %d: %s", tokens_generated + 1, e.what());
|
|
664
|
+
break;
|
|
665
|
+
} catch (...) {
|
|
666
|
+
LOGE("Unknown exception during token generation %d", tokens_generated + 1);
|
|
667
|
+
break;
|
|
668
|
+
}
|
|
669
|
+
}
|
|
670
|
+
|
|
671
|
+
LOGI("Token generation completed. Generated %d tokens.", tokens_generated);
|
|
672
|
+
|
|
673
|
+
// End completion
|
|
674
|
+
LOGI("Ending completion...");
|
|
675
|
+
ctx->completion->endCompletion();
|
|
676
|
+
|
|
677
|
+
} catch (const std::exception& e) {
|
|
678
|
+
LOGE("Exception during completion process: %s", e.what());
|
|
679
|
+
try {
|
|
680
|
+
ctx->completion->endCompletion();
|
|
681
|
+
} catch (...) {
|
|
682
|
+
LOGE("Failed to properly end completion after exception");
|
|
683
|
+
}
|
|
684
|
+
throw_java_exception(env, "java/lang/RuntimeException",
|
|
685
|
+
("Completion process failed: " + std::string(e.what())).c_str());
|
|
686
|
+
return nullptr;
|
|
687
|
+
} catch (...) {
|
|
688
|
+
LOGE("Unknown exception during completion process");
|
|
689
|
+
try {
|
|
690
|
+
ctx->completion->endCompletion();
|
|
691
|
+
} catch (...) {
|
|
692
|
+
LOGE("Failed to properly end completion after unknown exception");
|
|
693
|
+
}
|
|
694
|
+
throw_java_exception(env, "java/lang/RuntimeException", "Unknown error during completion process");
|
|
695
|
+
return nullptr;
|
|
696
|
+
}
|
|
697
|
+
|
|
698
|
+
LOGI("Completion finished. Generated %d tokens: %s", tokens_generated, generated_text.c_str());
|
|
699
|
+
|
|
700
|
+
// Create result HashMap
|
|
701
|
+
jclass hashMapClass = env->FindClass("java/util/HashMap");
|
|
702
|
+
jmethodID hashMapConstructor = env->GetMethodID(hashMapClass, "<init>", "()V");
|
|
703
|
+
jmethodID putMethod = env->GetMethodID(hashMapClass, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
|
|
704
|
+
|
|
705
|
+
jobject resultMap = env->NewObject(hashMapClass, hashMapConstructor);
|
|
706
|
+
|
|
707
|
+
// Add completion results
|
|
708
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
709
|
+
jni_utils::string_to_jstring(env, "text"), jni_utils::string_to_jstring(env, generated_text));
|
|
710
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
711
|
+
jni_utils::string_to_jstring(env, "content"), jni_utils::string_to_jstring(env, generated_text));
|
|
712
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
713
|
+
jni_utils::string_to_jstring(env, "reasoning_content"), jni_utils::string_to_jstring(env, ""));
|
|
714
|
+
|
|
715
|
+
// Create empty tool_calls array
|
|
716
|
+
jclass arrayListClass = env->FindClass("java/util/ArrayList");
|
|
717
|
+
jmethodID arrayListConstructor = env->GetMethodID(arrayListClass, "<init>", "()V");
|
|
718
|
+
jobject emptyToolCalls = env->NewObject(arrayListClass, arrayListConstructor);
|
|
719
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
720
|
+
jni_utils::string_to_jstring(env, "tool_calls"), emptyToolCalls);
|
|
721
|
+
|
|
722
|
+
// Add token counts and status
|
|
723
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
724
|
+
jni_utils::string_to_jstring(env, "tokens_predicted"),
|
|
725
|
+
env->NewObject(env->FindClass("java/lang/Integer"),
|
|
726
|
+
env->GetMethodID(env->FindClass("java/lang/Integer"), "<init>", "(I)V"), tokens_generated));
|
|
727
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
728
|
+
jni_utils::string_to_jstring(env, "tokens_evaluated"),
|
|
729
|
+
env->NewObject(env->FindClass("java/lang/Integer"),
|
|
730
|
+
env->GetMethodID(env->FindClass("java/lang/Integer"), "<init>", "(I)V"), (jint)prompt_tokens.size()));
|
|
731
|
+
|
|
732
|
+
// Add completion status flags
|
|
733
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
734
|
+
jni_utils::string_to_jstring(env, "truncated"),
|
|
735
|
+
env->NewObject(env->FindClass("java/lang/Boolean"),
|
|
736
|
+
env->GetMethodID(env->FindClass("java/lang/Boolean"), "<init>", "(Z)V"), JNI_FALSE));
|
|
737
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
738
|
+
jni_utils::string_to_jstring(env, "stopped_eos"),
|
|
739
|
+
env->NewObject(env->FindClass("java/lang/Boolean"),
|
|
740
|
+
env->GetMethodID(env->FindClass("java/lang/Boolean"), "<init>", "(Z)V"),
|
|
741
|
+
tokens_generated < n_predict ? JNI_TRUE : JNI_FALSE));
|
|
742
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
743
|
+
jni_utils::string_to_jstring(env, "stopped_limit"),
|
|
744
|
+
env->NewObject(env->FindClass("java/lang/Boolean"),
|
|
745
|
+
env->GetMethodID(env->FindClass("java/lang/Boolean"), "<init>", "(Z)V"),
|
|
746
|
+
tokens_generated >= n_predict ? JNI_TRUE : JNI_FALSE));
|
|
747
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
748
|
+
jni_utils::string_to_jstring(env, "context_full"),
|
|
749
|
+
env->NewObject(env->FindClass("java/lang/Boolean"),
|
|
750
|
+
env->GetMethodID(env->FindClass("java/lang/Boolean"), "<init>", "(Z)V"), JNI_FALSE));
|
|
751
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
752
|
+
jni_utils::string_to_jstring(env, "interrupted"),
|
|
753
|
+
env->NewObject(env->FindClass("java/lang/Boolean"),
|
|
754
|
+
env->GetMethodID(env->FindClass("java/lang/Boolean"), "<init>", "(Z)V"), JNI_FALSE));
|
|
755
|
+
|
|
756
|
+
// Add empty strings for stop reasons
|
|
757
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
758
|
+
jni_utils::string_to_jstring(env, "stopped_word"), jni_utils::string_to_jstring(env, ""));
|
|
759
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
760
|
+
jni_utils::string_to_jstring(env, "stopping_word"), jni_utils::string_to_jstring(env, ""));
|
|
761
|
+
|
|
762
|
+
// Add timing information (basic)
|
|
763
|
+
jobject timingsMap = env->NewObject(hashMapClass, hashMapConstructor);
|
|
764
|
+
env->CallObjectMethod(timingsMap, putMethod,
|
|
765
|
+
jni_utils::string_to_jstring(env, "prompt_n"),
|
|
766
|
+
env->NewObject(env->FindClass("java/lang/Integer"),
|
|
767
|
+
env->GetMethodID(env->FindClass("java/lang/Integer"), "<init>", "(I)V"), (jint)prompt_tokens.size()));
|
|
768
|
+
env->CallObjectMethod(timingsMap, putMethod,
|
|
769
|
+
jni_utils::string_to_jstring(env, "predicted_n"),
|
|
770
|
+
env->NewObject(env->FindClass("java/lang/Integer"),
|
|
771
|
+
env->GetMethodID(env->FindClass("java/lang/Integer"), "<init>", "(I)V"), tokens_generated));
|
|
772
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
773
|
+
jni_utils::string_to_jstring(env, "timings"), timingsMap);
|
|
774
|
+
|
|
775
|
+
LOGI("Completion result created successfully");
|
|
776
|
+
return resultMap;
|
|
460
777
|
|
|
461
778
|
} catch (const std::exception& e) {
|
|
462
779
|
LOGE("Exception in completion: %s", e.what());
|
|
@@ -495,7 +812,7 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_getFormattedChatNative(
|
|
|
495
812
|
std::string messages_str = jstring_to_string(env, messages);
|
|
496
813
|
std::string template_str = jstring_to_string(env, chat_template);
|
|
497
814
|
|
|
498
|
-
|
|
815
|
+
capllama::llama_cap_context* context = it->second.get();
|
|
499
816
|
|
|
500
817
|
// Format chat using the context's method
|
|
501
818
|
std::string result = context->getFormattedChat(messages_str, template_str);
|
|
@@ -515,7 +832,7 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_toggleNativeLogNative(
|
|
|
515
832
|
JNIEnv* env, jobject thiz, jboolean enabled) {
|
|
516
833
|
|
|
517
834
|
try {
|
|
518
|
-
rnllama::rnllama_verbose = jboolean_to_bool(enabled);
|
|
835
|
+
// rnllama::rnllama_verbose = jboolean_to_bool(enabled); // This line is removed as per the edit hint
|
|
519
836
|
LOGI("Native logging %s", enabled ? "enabled" : "disabled");
|
|
520
837
|
return bool_to_jboolean(true);
|
|
521
838
|
} catch (const std::exception& e) {
|
|
@@ -525,7 +842,431 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_toggleNativeLogNative(
|
|
|
525
842
|
}
|
|
526
843
|
}
|
|
527
844
|
|
|
845
|
+
// JNI entry point: locates a GGUF model file and returns basic metadata about it.
//
// The supplied path is tried first; afterwards the bare file name is looked up
// in a fixed list of app-private and shared-storage directories. The first
// candidate that is at least 1 MiB and starts with the "GGUF" magic is used.
//
// @param env        JNI environment of the calling thread.
// @param thiz       Calling Java object (unused).
// @param model_path Path (or bare file name) of the model to inspect.
// @return A java.util.HashMap with keys "path" (String), "size" (Long, bytes),
//         "desc" (String), and "nEmbd"/"nParams" (Integer, always 0 here), or
//         nullptr with a pending java.lang.RuntimeException on failure.
JNIEXPORT jobject JNICALL
Java_ai_annadata_plugin_capacitor_LlamaCpp_modelInfoNative(
    JNIEnv* env, jobject thiz, jstring model_path) {
    (void)thiz;

    try {
        const std::string model_path_str = jstring_to_string(env, model_path);
        LOGI("Getting model info for: %s", model_path_str.c_str());

        // Extract filename from path
        const size_t last_slash = model_path_str.find_last_of('/');
        const std::string filename = (last_slash == std::string::npos)
            ? model_path_str
            : model_path_str.substr(last_slash + 1);
        LOGI("Extracted filename for model info: %s", filename.c_str());

        // List all possible paths we should check (same as initContextNative)
        const std::vector<std::string> paths_to_check = {
            model_path_str, // Try the original path first
            "/data/data/ai.annadata.llamacpp/files/" + filename,
            "/data/data/ai.annadata.llamacpp/files/Documents/" + filename,
            "/storage/emulated/0/Android/data/ai.annadata.llamacpp/files/" + filename,
            "/storage/emulated/0/Android/data/ai.annadata.llamacpp/files/Documents/" + filename,
            "/storage/emulated/0/Documents/" + filename,
            "/storage/emulated/0/Download/" + filename
        };

        // Anything smaller than this is treated as a truncated/corrupted download.
        const std::streamsize min_model_bytes = 1024 * 1024;

        // Renders one header byte as an unsigned value so it can be logged as hex;
        // printing raw bytes with %c could embed a NUL and cut the log line short.
        const auto as_hex = [](char c) {
            return static_cast<unsigned>(static_cast<unsigned char>(c));
        };

        std::string full_model_path;
        std::streamsize file_size = 0;
        uint32_t version = 0;
        bool file_found = false;

        // Probe every candidate once: size, magic and version are all read from
        // the same open stream, so the file is never re-opened or re-validated.
        for (const auto& path : paths_to_check) {
            LOGI("Checking path for model info: %s", path.c_str());
            std::ifstream candidate(path, std::ios::binary);
            if (!candidate.good()) {
                LOGI("File not found at: %s", path.c_str());
                continue;
            }

            candidate.seekg(0, std::ios::end);
            const std::streamsize candidate_size = candidate.tellg();
            candidate.seekg(0, std::ios::beg);

            // Validate file size (a failed tellg() yields -1 and is rejected here too)
            if (candidate_size < min_model_bytes) {
                LOGE("Model file is too small, likely corrupted: %s", path.c_str());
                continue; // Try next path
            }

            // Check if it's a valid GGUF file by reading the magic number
            char magic[4] = {0, 0, 0, 0};
            if (!candidate.read(magic, sizeof(magic))) {
                continue;
            }
            if (std::string(magic, sizeof(magic)) != "GGUF") {
                LOGI("File does not appear to be a GGUF file (magic: %02x %02x %02x %02x) at: %s",
                     as_hex(magic[0]), as_hex(magic[1]), as_hex(magic[2]), as_hex(magic[3]),
                     path.c_str());
                continue;
            }
            LOGI("Valid GGUF file detected for model info at: %s", path.c_str());

            // The 32-bit version field directly follows the magic; it is read in
            // host byte order.
            if (!candidate.read(reinterpret_cast<char*>(&version), sizeof(version))) {
                LOGE("Failed to read GGUF version from: %s", path.c_str());
                throw_java_exception(env, "java/lang/RuntimeException", "Failed to read GGUF version");
                return nullptr;
            }

            full_model_path = path;
            file_size = candidate_size;
            file_found = true;
            break;
        }

        if (!file_found) {
            LOGE("Model file not found in any of the checked paths");
            throw_java_exception(env, "java/lang/RuntimeException", "Model file not found");
            return nullptr;
        }

        // Create Java HashMap; class and constructor handles are resolved once.
        jclass hashMapClass = env->FindClass("java/util/HashMap");
        jmethodID hashMapConstructor = env->GetMethodID(hashMapClass, "<init>", "()V");
        jmethodID putMethod = env->GetMethodID(hashMapClass, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");

        jclass longClass = env->FindClass("java/lang/Long");
        jmethodID longConstructor = env->GetMethodID(longClass, "<init>", "(J)V");
        jclass integerClass = env->FindClass("java/lang/Integer");
        jmethodID integerConstructor = env->GetMethodID(integerClass, "<init>", "(I)V");

        jobject hashMap = env->NewObject(hashMapClass, hashMapConstructor);

        // Add model info to HashMap
        env->CallObjectMethod(hashMap, putMethod,
            string_to_jstring(env, "path"),
            string_to_jstring(env, full_model_path));

        env->CallObjectMethod(hashMap, putMethod,
            string_to_jstring(env, "size"),
            env->NewObject(longClass, longConstructor, static_cast<jlong>(file_size)));

        env->CallObjectMethod(hashMap, putMethod,
            string_to_jstring(env, "desc"),
            string_to_jstring(env, "GGUF Model (v" + std::to_string(version) + ")"));

        // nEmbd / nParams are placeholders: this function only inspects the file
        // header and does not load the model.
        env->CallObjectMethod(hashMap, putMethod,
            string_to_jstring(env, "nEmbd"),
            env->NewObject(integerClass, integerConstructor, static_cast<jint>(0)));

        env->CallObjectMethod(hashMap, putMethod,
            string_to_jstring(env, "nParams"),
            env->NewObject(integerClass, integerConstructor, static_cast<jint>(0)));

        // Explicit casts keep the format specifiers correct on every ABI.
        LOGI("Model info retrieved successfully from %s: size=%lld, version=%u",
             full_model_path.c_str(), static_cast<long long>(file_size),
             static_cast<unsigned>(version));
        return hashMap;

    } catch (const std::exception& e) {
        LOGE("Exception in modelInfo: %s", e.what());
        throw_java_exception(env, "java/lang/RuntimeException", e.what());
        return nullptr;
    }
}
|
|
992
|
+
|
|
993
|
+
|
|
994
|
+
|
|
995
|
+
// JNI entry point: prepares the on-device destination path for a model download.
//
// No network transfer happens here; the function only makes sure the Models
// directory exists and returns the path the file should be written to.
//
// @param env      JNI environment of the calling thread.
// @param thiz     Calling Java object (unused).
// @param url      Source URL of the model (not used by this function).
// @param filename File name (relative to the Models directory) to store the model under.
// @return The absolute destination path as a Java String, or nullptr with a
//         pending java.lang.RuntimeException if the name is empty, contains a
//         ".." component, or the directory cannot be created.
JNIEXPORT jstring JNICALL
Java_ai_annadata_plugin_capacitor_LlamaCpp_downloadModelNative(
    JNIEnv* env, jobject thiz, jstring url, jstring filename) {
    (void)thiz;
    (void)url;

    try {
        const std::string filename_str = jstring_to_string(env, filename);

        LOGI("Preparing download path for model: %s", filename_str.c_str());

        // The name is supplied by the caller and is appended to a fixed directory,
        // so refuse anything that would resolve outside of that directory.
        bool escapes_models_dir = false;
        for (const auto& component : std::filesystem::path(filename_str)) {
            if (component.string() == "..") {
                escapes_models_dir = true;
                break;
            }
        }
        if (filename_str.empty() || escapes_models_dir) {
            LOGE("Rejecting invalid model filename: %s", filename_str.c_str());
            throw_java_exception(env, "java/lang/RuntimeException", "Invalid model filename");
            return nullptr;
        }

        // Determine local storage path (use external storage for large files)
        const std::string dir_path = "/storage/emulated/0/Android/data/ai.annadata.llamacpp/files/Models/";
        const std::string local_path = dir_path + filename_str;

        // Create directory if it doesn't exist; a failure throws
        // std::filesystem::filesystem_error, which the handler below reports.
        std::filesystem::create_directories(dir_path);

        LOGI("Download path prepared: %s", local_path.c_str());

        return string_to_jstring(env, local_path);

    } catch (const std::exception& e) {
        LOGE("Exception in downloadModel: %s", e.what());
        throw_java_exception(env, "java/lang/RuntimeException", e.what());
        return nullptr;
    }
}
|
|
1022
|
+
|
|
1023
|
+
// JNI entry point: reports download progress for a URL.
//
// Placeholder implementation: the returned java.util.HashMap always contains
// "progress" = 0.0 (Double), "completed" = false and "failed" = false (Boolean).
// Returns nullptr with a pending java.lang.RuntimeException if a C++ exception
// is caught.
JNIEXPORT jobject JNICALL
Java_ai_annadata_plugin_capacitor_LlamaCpp_getDownloadProgressNative(
    JNIEnv* env, jobject thiz, jstring url) {

    try {
        // Resolve the map type and the members needed to fill it.
        jclass mapCls = env->FindClass("java/util/HashMap");
        jmethodID mapCtor = env->GetMethodID(mapCls, "<init>", "()V");
        jmethodID mapPut = env->GetMethodID(mapCls, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");

        jobject progressInfo = env->NewObject(mapCls, mapCtor);

        // "progress": boxed double, fixed at zero for now.
        jclass doubleCls = env->FindClass("java/lang/Double");
        jmethodID doubleCtor = env->GetMethodID(doubleCls, "<init>", "(D)V");
        jobject boxedProgress = env->NewObject(doubleCls, doubleCtor, 0.0);
        env->CallObjectMethod(progressInfo, mapPut,
            string_to_jstring(env, "progress"), boxedProgress);

        // "completed" / "failed": boxed booleans, both fixed at false for now.
        jclass booleanCls = env->FindClass("java/lang/Boolean");
        jmethodID booleanCtor = env->GetMethodID(booleanCls, "<init>", "(Z)V");

        jobject boxedCompleted = env->NewObject(booleanCls, booleanCtor, false);
        env->CallObjectMethod(progressInfo, mapPut,
            string_to_jstring(env, "completed"), boxedCompleted);

        jobject boxedFailed = env->NewObject(booleanCls, booleanCtor, false);
        env->CallObjectMethod(progressInfo, mapPut,
            string_to_jstring(env, "failed"), boxedFailed);

        return progressInfo;

    } catch (const std::exception& ex) {
        LOGE("Exception in getDownloadProgress: %s", ex.what());
        throw_java_exception(env, "java/lang/RuntimeException", ex.what());
        return nullptr;
    }
}
|
|
1064
|
+
|
|
1065
|
+
// JNI entry point: requests cancellation of a download.
//
// Placeholder implementation: nothing is cancelled on the native side, so the
// call always answers JNI_FALSE. The exception handler mirrors the other entry
// points in this file.
JNIEXPORT jboolean JNICALL
Java_ai_annadata_plugin_capacitor_LlamaCpp_cancelDownloadNative(
    JNIEnv* env, jobject thiz, jstring url) {

    try {
        // No native download state exists yet, hence "not cancelled".
        const jboolean cancelled = JNI_FALSE;
        return cancelled;

    } catch (const std::exception& ex) {
        LOGE("Exception in cancelDownload: %s", ex.what());
        throw_java_exception(env, "java/lang/RuntimeException", ex.what());
    }

    return JNI_FALSE;
}
|
|
1080
|
+
|
|
1081
|
+
// JNI entry point: lists the *.gguf files stored in the app's Models directory.
//
// @param env  JNI environment of the calling thread.
// @param thiz Calling Java object (unused).
// @return A java.util.ArrayList of java.util.HashMap entries, each with keys
//         "name" (String), "path" (String) and "size" (Long, bytes). The list is
//         empty when the directory does not exist. Returns nullptr with a
//         pending java.lang.RuntimeException if a C++ exception is caught
//         (for example a std::filesystem error while scanning).
JNIEXPORT jobject JNICALL
Java_ai_annadata_plugin_capacitor_LlamaCpp_getAvailableModelsNative(
    JNIEnv* env, jobject thiz) {
    (void)thiz;

    try {
        const std::string models_dir = "/storage/emulated/0/Android/data/ai.annadata.llamacpp/files/Models/";

        // Create Java ArrayList
        jclass arrayListClass = env->FindClass("java/util/ArrayList");
        jmethodID arrayListConstructor = env->GetMethodID(arrayListClass, "<init>", "()V");
        jmethodID addMethod = env->GetMethodID(arrayListClass, "add", "(Ljava/lang/Object;)Z");

        jobject arrayList = env->NewObject(arrayListClass, arrayListConstructor);

        if (!std::filesystem::exists(models_dir)) {
            return arrayList;
        }

        // Class and method handles do not change between entries, so resolve
        // them once instead of once per file.
        jclass hashMapClass = env->FindClass("java/util/HashMap");
        jmethodID hashMapConstructor = env->GetMethodID(hashMapClass, "<init>", "()V");
        jmethodID putMethod = env->GetMethodID(hashMapClass, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
        jclass longClass = env->FindClass("java/lang/Long");
        jmethodID longConstructor = env->GetMethodID(longClass, "<init>", "(J)V");

        // Stores key -> value in `map`, then drops the local references to both
        // so the number of live local references stays constant however many
        // models are listed.
        const auto put_and_release = [&](jobject map, const char* key, jobject value) {
            auto jkey = string_to_jstring(env, key);
            env->CallObjectMethod(map, putMethod, jkey, value);
            env->DeleteLocalRef(jkey);
            env->DeleteLocalRef(value);
        };

        for (const auto& entry : std::filesystem::directory_iterator(models_dir)) {
            const bool is_model = entry.is_regular_file() && entry.path().extension() == ".gguf";
            if (!is_model) {
                continue;
            }

            const std::string filename = entry.path().filename().string();
            const std::string full_path = entry.path().string();
            // Keep the full-width type returned by file_size(); narrowing to
            // size_t would truncate files over 4 GiB on 32-bit targets.
            const auto file_size = entry.file_size();

            // Create model info HashMap
            jobject modelInfo = env->NewObject(hashMapClass, hashMapConstructor);

            put_and_release(modelInfo, "name", string_to_jstring(env, filename));
            put_and_release(modelInfo, "path", string_to_jstring(env, full_path));
            put_and_release(modelInfo, "size",
                env->NewObject(longClass, longConstructor, static_cast<jlong>(file_size)));

            // Add to ArrayList; the list now holds the map, so the local
            // reference can be released.
            env->CallBooleanMethod(arrayList, addMethod, modelInfo);
            env->DeleteLocalRef(modelInfo);
        }

        return arrayList;

    } catch (const std::exception& e) {
        LOGE("Exception in getAvailableModels: %s", e.what());
        throw_java_exception(env, "java/lang/RuntimeException", e.what());
        return nullptr;
    }
}
|
|
1137
|
+
|
|
1138
|
+
// MARK: - Tokenization methods
|
|
1139
|
+
|
|
1140
|
+
// JNI entry point: tokenizes `text` with the context registered under `contextId`.
//
// @param env        JNI environment of the calling thread.
// @param thiz       Calling Java object (unused).
// @param contextId  Key of the context in the `contexts` registry.
// @param text       Text to tokenize.
// @param imagePaths Currently ignored: the tokenizer is always invoked with an
//                   empty second argument and the image-related result fields
//                   are returned empty.
// @return A java.util.HashMap with "tokens" (ArrayList<Integer>), "has_images"
//         (Boolean, always false), and empty ArrayLists under "bitmap_hashes",
//         "chunk_pos" and "chunk_pos_images"; or nullptr with a pending
//         java.lang.RuntimeException when the context is missing/invalid or a
//         C++ exception is caught.
JNIEXPORT jobject JNICALL
Java_ai_annadata_plugin_capacitor_LlamaCpp_tokenizeNative(
    JNIEnv* env, jobject thiz, jlong contextId, jstring text, jobjectArray imagePaths) {
    (void)thiz;
    (void)imagePaths;

    try {
        // jlong is 64-bit on every ABI while `long` is not, so log it as long long.
        LOGI("Tokenizing with context ID: %lld", static_cast<long long>(contextId));

        const std::string text_str = jni_utils::jstring_to_string(env, text);
        LOGI("Text to tokenize: %s", text_str.c_str());

        // Find the context
        auto it = contexts.find(contextId);
        if (it == contexts.end()) {
            LOGE("Context not found: %lld", static_cast<long long>(contextId));
            throw_java_exception(env, "java/lang/RuntimeException", "Context not found");
            return nullptr;
        }

        auto& ctx = it->second;
        if (!ctx || !ctx->ctx) {
            LOGE("Invalid context or llama context is null");
            throw_java_exception(env, "java/lang/RuntimeException", "Invalid context");
            return nullptr;
        }

        // Tokenize the text using the context's tokenize method
        capllama::llama_cap_tokenize_result tokenize_result = ctx->tokenize(text_str, {});
        // Bind to the result's vector instead of copying it.
        const auto& tokens = tokenize_result.tokens;

        LOGI("Tokenized %zu tokens", tokens.size());

        // Create Java HashMap for result
        jclass hashMapClass = env->FindClass("java/util/HashMap");
        jmethodID hashMapConstructor = env->GetMethodID(hashMapClass, "<init>", "()V");
        jmethodID putMethod = env->GetMethodID(hashMapClass, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");

        jobject resultMap = env->NewObject(hashMapClass, hashMapConstructor);

        // Create Java ArrayList for tokens
        jclass arrayListClass = env->FindClass("java/util/ArrayList");
        jmethodID arrayListConstructor = env->GetMethodID(arrayListClass, "<init>", "()V");
        jmethodID addMethod = env->GetMethodID(arrayListClass, "add", "(Ljava/lang/Object;)Z");

        jobject tokensArray = env->NewObject(arrayListClass, arrayListConstructor);

        // Add tokens to ArrayList
        jclass integerClass = env->FindClass("java/lang/Integer");
        jmethodID integerConstructor = env->GetMethodID(integerClass, "<init>", "(I)V");

        for (llama_token token : tokens) {
            jobject jToken = env->NewObject(integerClass, integerConstructor, static_cast<jint>(token));
            env->CallBooleanMethod(tokensArray, addMethod, jToken);
            // Release per-token references so long inputs do not pile up local refs.
            env->DeleteLocalRef(jToken);
        }

        // Create empty arrays for other fields
        jobject emptyBitmapHashes = env->NewObject(arrayListClass, arrayListConstructor);
        jobject emptyChunkPos = env->NewObject(arrayListClass, arrayListConstructor);
        jobject emptyChunkPosImages = env->NewObject(arrayListClass, arrayListConstructor);

        jclass booleanClass = env->FindClass("java/lang/Boolean");
        jmethodID booleanConstructor = env->GetMethodID(booleanClass, "<init>", "(Z)V");

        // Put all data into result map
        env->CallObjectMethod(resultMap, putMethod,
            jni_utils::string_to_jstring(env, "tokens"), tokensArray);
        env->CallObjectMethod(resultMap, putMethod,
            jni_utils::string_to_jstring(env, "has_images"),
            env->NewObject(booleanClass, booleanConstructor, JNI_FALSE));
        env->CallObjectMethod(resultMap, putMethod,
            jni_utils::string_to_jstring(env, "bitmap_hashes"), emptyBitmapHashes);
        env->CallObjectMethod(resultMap, putMethod,
            jni_utils::string_to_jstring(env, "chunk_pos"), emptyChunkPos);
        env->CallObjectMethod(resultMap, putMethod,
            jni_utils::string_to_jstring(env, "chunk_pos_images"), emptyChunkPosImages);

        LOGI("Tokenization completed successfully");
        return resultMap;

    } catch (const std::exception& e) {
        LOGE("Exception in tokenize: %s", e.what());
        throw_java_exception(env, "java/lang/RuntimeException", e.what());
        return nullptr;
    }
}
|
|
1223
|
+
|
|
1224
|
+
// JNI entry point: converts a Java int[] of token ids back into text using the
// context registered under `contextId`.
//
// @param env       JNI environment of the calling thread.
// @param thiz      Calling Java object (unused).
// @param contextId Key of the context in the `contexts` registry.
// @param tokens    Token ids to detokenize; must not be null (may be empty).
// @return The detokenized text as a Java String, or nullptr with a pending
//         java.lang.RuntimeException when the context is missing/invalid, the
//         token array is null, or a C++ exception is caught.
JNIEXPORT jstring JNICALL
Java_ai_annadata_plugin_capacitor_LlamaCpp_detokenizeNative(
    JNIEnv* env, jobject thiz, jlong contextId, jintArray tokens) {
    (void)thiz;

    try {
        // jlong is 64-bit on every ABI while `long` is not, so log it as long long.
        LOGI("Detokenizing with context ID: %lld", static_cast<long long>(contextId));

        // Find the context
        auto it = contexts.find(contextId);
        if (it == contexts.end()) {
            LOGE("Context not found: %lld", static_cast<long long>(contextId));
            throw_java_exception(env, "java/lang/RuntimeException", "Context not found");
            return nullptr;
        }

        auto& ctx = it->second;
        if (!ctx || !ctx->ctx) {
            LOGE("Invalid context or llama context is null");
            throw_java_exception(env, "java/lang/RuntimeException", "Invalid context");
            return nullptr;
        }

        // A null array must not reach the JNI array accessors below.
        if (tokens == nullptr) {
            LOGE("Token array is null");
            throw_java_exception(env, "java/lang/RuntimeException", "Token array is null");
            return nullptr;
        }

        // Convert Java int array to C++ vector. GetIntArrayRegion copies straight
        // into our buffer, so there is no element pointer that could come back
        // null and nothing to release afterwards.
        const jsize length = env->GetArrayLength(tokens);
        std::vector<jint> raw_tokens(static_cast<size_t>(length));
        if (length > 0) {
            env->GetIntArrayRegion(tokens, 0, length, raw_tokens.data());
        }
        std::vector<llama_token> llamaTokens(raw_tokens.begin(), raw_tokens.end());

        // Detokenize using llama.cpp
        std::string result = capllama::tokens_to_str(ctx->ctx, llamaTokens.begin(), llamaTokens.end());

        LOGI("Detokenized to: %s", result.c_str());

        // NOTE(review): if string_to_jstring relies on NewStringUTF, a result that
        // ends in a partial multi-byte sequence may be rejected by the VM - confirm
        // how the helper encodes its input.
        return jni_utils::string_to_jstring(env, result);

    } catch (const std::exception& e) {
        LOGE("Exception in detokenize: %s", e.what());
        throw_java_exception(env, "java/lang/RuntimeException", e.what());
        return nullptr;
    }
}
|
|
529
1270
|
|
|
530
1271
|
} // extern "C"
|
|
531
1272
|
|