llama-cpp-capacitor 0.0.13 → 0.0.21

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. package/LlamaCpp.podspec +17 -17
  2. package/Package.swift +27 -27
  3. package/README.md +717 -574
  4. package/android/build.gradle +88 -69
  5. package/android/src/main/AndroidManifest.xml +2 -2
  6. package/android/src/main/CMakeLists-arm64.txt +131 -0
  7. package/android/src/main/CMakeLists-x86_64.txt +135 -0
  8. package/android/src/main/CMakeLists.txt +35 -52
  9. package/android/src/main/java/ai/annadata/plugin/capacitor/LlamaCpp.java +956 -717
  10. package/android/src/main/java/ai/annadata/plugin/capacitor/LlamaCppPlugin.java +710 -590
  11. package/android/src/main/jni-utils.h +7 -7
  12. package/android/src/main/jni.cpp +868 -127
  13. package/cpp/{rn-completion.cpp → cap-completion.cpp} +202 -24
  14. package/cpp/{rn-completion.h → cap-completion.h} +22 -11
  15. package/cpp/{rn-llama.cpp → cap-llama.cpp} +81 -27
  16. package/cpp/{rn-llama.h → cap-llama.h} +32 -20
  17. package/cpp/{rn-mtmd.hpp → cap-mtmd.hpp} +15 -15
  18. package/cpp/{rn-tts.cpp → cap-tts.cpp} +12 -12
  19. package/cpp/{rn-tts.h → cap-tts.h} +14 -14
  20. package/cpp/ggml-cpu/ggml-cpu-impl.h +30 -0
  21. package/dist/docs.json +100 -3
  22. package/dist/esm/definitions.d.ts +45 -2
  23. package/dist/esm/definitions.js.map +1 -1
  24. package/dist/esm/index.d.ts +22 -0
  25. package/dist/esm/index.js +66 -3
  26. package/dist/esm/index.js.map +1 -1
  27. package/dist/plugin.cjs.js +71 -3
  28. package/dist/plugin.cjs.js.map +1 -1
  29. package/dist/plugin.js +71 -3
  30. package/dist/plugin.js.map +1 -1
  31. package/ios/Sources/LlamaCppPlugin/LlamaCpp.swift +596 -596
  32. package/ios/Sources/LlamaCppPlugin/LlamaCppPlugin.swift +591 -514
  33. package/ios/Tests/LlamaCppPluginTests/LlamaCppPluginTests.swift +15 -15
  34. package/package.json +111 -110
@@ -1,16 +1,21 @@
1
1
  #include "jni-utils.h"
2
- #include "rn-llama.h"
2
+ #include "cap-llama.h"
3
+ #include "cap-completion.h"
3
4
  #include <android/log.h>
4
5
  #include <cstring>
5
6
  #include <memory>
6
7
  #include <fstream> // Added for file existence and size checks
7
8
  #include <signal.h> // Added for signal handling
8
9
  #include <sys/signal.h> // Added for sigaction
10
+ #include <thread> // For background downloads
11
+ #include <atomic> // For thread-safe progress tracking
12
+ #include <filesystem> // For file operations
13
+ #include <mutex> // For thread synchronization
9
14
 
10
15
  // Add missing symbol
11
- namespace rnllama {
12
- bool rnllama_verbose = false;
13
- }
16
+ // namespace rnllama {
17
+ // bool rnllama_verbose = false;
18
+ // }
14
19
 
15
20
  #define LOG_TAG "LlamaCpp"
16
21
  #define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
@@ -130,75 +135,67 @@ jclass find_class(JNIEnv* env, const char* name) {
130
135
  return clazz;
131
136
  }
132
137
 
133
- // Global context storage
134
- static std::map<jlong, std::unique_ptr<rnllama::llama_rn_context>> contexts;
138
+ // Convert llama_cap_context to jobject
139
+ jobject llama_context_to_jobject(JNIEnv* env, const capllama::llama_cap_context* context);
140
+
141
+ // Convert jobject to llama_cap_context
142
+ capllama::llama_cap_context* jobject_to_llama_context(JNIEnv* env, jobject obj);
143
+
144
+ // Convert completion result to jobject
145
+ jobject completion_result_to_jobject(JNIEnv* env, const capllama::completion_token_output& result);
146
+
147
+ // Convert tokenize result to jobject
148
+ jobject tokenize_result_to_jobject(JNIEnv* env, const capllama::llama_cap_tokenize_result& result);
149
+
150
+ // Global context storage - fix namespace
151
+ static std::map<jlong, std::unique_ptr<capllama::llama_cap_context>> contexts;
135
152
  static jlong next_context_id = 1;
136
153
 
154
+ // Download progress tracking (simplified for now)
155
+ // This can be enhanced later to track actual download progress
156
+
137
157
  extern "C" {
138
158
 
139
159
  JNIEXPORT jlong JNICALL
140
160
  Java_ai_annadata_plugin_capacitor_LlamaCpp_initContextNative(
141
- JNIEnv* env, jobject thiz, jstring model_path, jobject params) {
161
+ JNIEnv *env, jobject thiz, jstring modelPath, jobjectArray searchPaths, jobject params) {
142
162
 
143
163
  try {
144
- std::string model_path_str = jstring_to_string(env, model_path);
145
- LOGI("Attempting to load model from path: %s", model_path_str.c_str());
146
-
147
- // List all possible paths we should check
148
- std::vector<std::string> paths_to_check = {
149
- model_path_str,
150
- "/data/data/ai.annadata.app/files/" + model_path_str,
151
- "/data/data/ai.annadata.app/files/Documents/" + model_path_str,
152
- "/storage/emulated/0/Android/data/ai.annadata.app/files/" + model_path_str,
153
- "/storage/emulated/0/Android/data/ai.annadata.app/files/Documents/" + model_path_str,
154
- "/storage/emulated/0/Documents/" + model_path_str
155
- };
156
-
157
- // Check each path and log what we find
164
+ std::string model_path_str = jstring_to_string(env, modelPath);
165
+
166
+ // Get search paths from Java
167
+ jsize pathCount = env->GetArrayLength(searchPaths);
168
+ std::vector<std::string> paths_to_check;
169
+
170
+ // Add the original path first
171
+ paths_to_check.push_back(model_path_str);
172
+
173
+ // Add all search paths from Java
174
+ for (jsize i = 0; i < pathCount; i++) {
175
+ jstring pathJString = (jstring)env->GetObjectArrayElement(searchPaths, i);
176
+ std::string path = jstring_to_string(env, pathJString);
177
+ paths_to_check.push_back(path);
178
+ env->DeleteLocalRef(pathJString);
179
+ }
180
+
181
+ // Rest of the existing logic remains the same...
158
182
  std::string full_model_path;
159
183
  bool file_found = false;
160
184
 
161
185
  for (const auto& path : paths_to_check) {
162
186
  LOGI("Checking path: %s", path.c_str());
163
- std::ifstream file_check(path);
164
- if (file_check.good()) {
165
- file_check.seekg(0, std::ios::end);
166
- std::streamsize file_size = file_check.tellg();
167
- file_check.close();
168
- LOGI("Found file at: %s, size: %ld bytes", path.c_str(), file_size);
169
-
170
- // Validate file size
171
- if (file_size < 1024 * 1024) { // Less than 1MB
172
- LOGE("Model file is too small, likely corrupted: %s", path.c_str());
173
- continue; // Try next path
174
- }
175
-
176
- // Check if it's a valid GGUF file by reading the magic number
177
- std::ifstream magic_file(path, std::ios::binary);
178
- if (magic_file.good()) {
179
- char magic[4];
180
- if (magic_file.read(magic, 4)) {
181
- if (magic[0] == 'G' && magic[1] == 'G' && magic[2] == 'U' && magic[3] == 'F') {
182
- LOGI("Valid GGUF file detected at: %s", path.c_str());
183
- full_model_path = path;
184
- file_found = true;
185
- break;
186
- } else {
187
- LOGI("File does not appear to be a GGUF file (magic: %c%c%c%c) at: %s",
188
- magic[0], magic[1], magic[2], magic[3], path.c_str());
189
- }
190
- }
191
- magic_file.close();
192
- }
187
+ if (std::filesystem::exists(path)) {
188
+ full_model_path = path;
189
+ file_found = true;
190
+ LOGI("Found model file at: %s", path.c_str());
191
+ break;
193
192
  } else {
194
- LOGI("File not found at: %s", path.c_str());
193
+ LOGE("Path not found: %s", path.c_str());
195
194
  }
196
- file_check.close();
197
195
  }
198
-
196
+
199
197
  if (!file_found) {
200
- LOGE("Model file not found in any of the checked paths");
201
- throw_java_exception(env, "java/lang/RuntimeException", "Model file not found in any expected location");
198
+ LOGE("Model file not found in any of the search paths");
202
199
  return -1;
203
200
  }
204
201
 
@@ -221,9 +218,9 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_initContextNative(
221
218
  validation_file.close();
222
219
  }
223
220
 
224
- // Create new context
225
- auto context = std::make_unique<rnllama::llama_rn_context>();
226
- LOGI("Created llama_rn_context");
221
+ // Create new context - fix namespace
222
+ auto context = std::make_unique<capllama::llama_cap_context>();
223
+ LOGI("Created llama_cap_context");
227
224
 
228
225
  // Initialize common parameters
229
226
  common_params cparams;
@@ -240,47 +237,21 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_initContextNative(
240
237
  cparams.chat_template = "";
241
238
  cparams.embedding = false;
242
239
  cparams.cont_batching = false;
243
- cparams.parallel = false;
244
- cparams.grammar = "";
245
- cparams.grammar_penalty.clear();
240
+ cparams.n_parallel = 1;
246
241
  cparams.antiprompt.clear();
247
- cparams.lora_adapter.clear();
248
- cparams.lora_base = "";
249
- cparams.mul_mat_q = true;
250
- cparams.f16_kv = true;
251
- cparams.logits_all = false;
252
242
  cparams.vocab_only = false;
253
243
  cparams.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
254
- cparams.rope_scaling_factor = 0.0f;
255
- cparams.rope_scaling_orig_ctx_len = 0;
256
244
  cparams.yarn_ext_factor = -1.0f;
257
245
  cparams.yarn_attn_factor = 1.0f;
258
246
  cparams.yarn_beta_fast = 32.0f;
259
247
  cparams.yarn_beta_slow = 1.0f;
260
248
  cparams.yarn_orig_ctx = 0;
261
- cparams.offload_kqv = true;
262
249
  cparams.flash_attn = false;
263
- cparams.flash_attn_kernel = false;
264
- cparams.flash_attn_causal = true;
265
- cparams.mmproj = "";
266
- cparams.image = "";
267
- cparams.export = "";
268
- cparams.export_path = "";
269
- cparams.seed = -1;
270
250
  cparams.n_keep = 0;
271
- cparams.n_discard = -1;
272
- cparams.n_draft = 0;
273
251
  cparams.n_chunks = -1;
274
- cparams.n_parallel = 1;
275
252
  cparams.n_sequences = 1;
276
- cparams.p_accept = 0.5f;
277
- cparams.p_split = 0.1f;
278
- cparams.n_gqa = 8;
279
- cparams.rms_norm_eps = 5e-6f;
280
253
  cparams.model_alias = "unknown";
281
- cparams.ubatch_size = 512;
282
- cparams.ubatch_seq_len_max = 1;
283
-
254
+
284
255
  LOGI("Initialized common parameters, attempting to load model from: %s", full_model_path.c_str());
285
256
  LOGI("Model parameters: n_ctx=%d, n_batch=%d, n_gpu_layers=%d",
286
257
  cparams.n_ctx, cparams.n_batch, cparams.n_gpu_layers);
@@ -335,47 +306,21 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_initContextNative(
335
306
  ultra_minimal_params.chat_template = "";
336
307
  ultra_minimal_params.embedding = false;
337
308
  ultra_minimal_params.cont_batching = false;
338
- ultra_minimal_params.parallel = false;
339
- ultra_minimal_params.grammar = "";
340
- ultra_minimal_params.grammar_penalty.clear();
309
+ ultra_minimal_params.n_parallel = 1;
341
310
  ultra_minimal_params.antiprompt.clear();
342
- ultra_minimal_params.lora_adapter.clear();
343
- ultra_minimal_params.lora_base = "";
344
- ultra_minimal_params.mul_mat_q = false; // Disable quantized matrix multiplication
345
- ultra_minimal_params.f16_kv = false; // Disable f16 key-value cache
346
- ultra_minimal_params.logits_all = false;
347
311
  ultra_minimal_params.vocab_only = false;
348
312
  ultra_minimal_params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
349
- ultra_minimal_params.rope_scaling_factor = 0.0f;
350
- ultra_minimal_params.rope_scaling_orig_ctx_len = 0;
351
313
  ultra_minimal_params.yarn_ext_factor = -1.0f;
352
314
  ultra_minimal_params.yarn_attn_factor = 1.0f;
353
315
  ultra_minimal_params.yarn_beta_fast = 32.0f;
354
316
  ultra_minimal_params.yarn_beta_slow = 1.0f;
355
317
  ultra_minimal_params.yarn_orig_ctx = 0;
356
- ultra_minimal_params.offload_kqv = false; // Disable offloading
357
318
  ultra_minimal_params.flash_attn = false;
358
- ultra_minimal_params.flash_attn_kernel = false;
359
- ultra_minimal_params.flash_attn_causal = true;
360
- ultra_minimal_params.mmproj = "";
361
- ultra_minimal_params.image = "";
362
- ultra_minimal_params.export = "";
363
- ultra_minimal_params.export_path = "";
364
- ultra_minimal_params.seed = -1;
365
319
  ultra_minimal_params.n_keep = 0;
366
- ultra_minimal_params.n_discard = -1;
367
- ultra_minimal_params.n_draft = 0;
368
320
  ultra_minimal_params.n_chunks = -1;
369
- ultra_minimal_params.n_parallel = 1;
370
321
  ultra_minimal_params.n_sequences = 1;
371
- ultra_minimal_params.p_accept = 0.5f;
372
- ultra_minimal_params.p_split = 0.1f;
373
- ultra_minimal_params.n_gqa = 8;
374
- ultra_minimal_params.rms_norm_eps = 5e-6f;
375
322
  ultra_minimal_params.model_alias = "unknown";
376
- ultra_minimal_params.ubatch_size = 128;
377
- ultra_minimal_params.ubatch_seq_len_max = 1;
378
-
323
+
379
324
  // Set up signal handler again for ultra-minimal attempt
380
325
  if (sigaction(SIGSEGV, &new_action, &old_action) == 0) {
381
326
  LOGI("Signal handler reinstalled for ultra-minimal attempt");
@@ -435,28 +380,400 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_releaseContextNative(
435
380
  }
436
381
  }
437
382
 
438
- JNIEXPORT jstring JNICALL
383
+ JNIEXPORT jobject JNICALL
439
384
  Java_ai_annadata_plugin_capacitor_LlamaCpp_completionNative(
440
- JNIEnv* env, jobject thiz, jlong context_id, jstring prompt) {
385
+ JNIEnv* env, jobject thiz, jlong context_id, jobject params) {
441
386
 
442
387
  try {
388
+ LOGI("Starting completion for context: %ld", context_id);
389
+
443
390
  auto it = contexts.find(context_id);
444
391
  if (it == contexts.end()) {
392
+ LOGE("Context not found: %ld", context_id);
445
393
  throw_java_exception(env, "java/lang/IllegalArgumentException", "Invalid context ID");
446
394
  return nullptr;
447
395
  }
448
396
 
449
- std::string prompt_str = jstring_to_string(env, prompt);
397
+ auto& ctx = it->second;
398
+ if (!ctx || !ctx->ctx) {
399
+ LOGE("Invalid context or llama context is null");
400
+ throw_java_exception(env, "java/lang/RuntimeException", "Invalid context");
401
+ return nullptr;
402
+ }
403
+
404
+ // Extract parameters from JSObject using compatible API
405
+ jclass jsObjectClass = env->GetObjectClass(params);
450
406
 
451
- // Get the context
452
- rnllama::llama_rn_context* context = it->second.get();
407
+ // Try to get method IDs and handle exceptions
408
+ jmethodID getStringMethod = nullptr;
409
+ jmethodID getIntegerMethod = nullptr;
410
+ jmethodID getDoubleMethod = nullptr;
453
411
 
454
- // For now, return a simple completion
455
- // In a full implementation, this would use the actual llama.cpp completion logic
456
- std::string result = "Generated response for: " + prompt_str;
412
+ // Clear any pending exceptions first
413
+ if (env->ExceptionCheck()) {
414
+ env->ExceptionClear();
415
+ }
457
416
 
458
- LOGI("Completion for context %ld: %s", context_id, prompt_str.c_str());
459
- return string_to_jstring(env, result);
417
+ try {
418
+ getStringMethod = env->GetMethodID(jsObjectClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;");
419
+ if (env->ExceptionCheck()) {
420
+ env->ExceptionClear();
421
+ getStringMethod = nullptr;
422
+ }
423
+
424
+ getIntegerMethod = env->GetMethodID(jsObjectClass, "getInteger", "(Ljava/lang/String;)Ljava/lang/Integer;");
425
+ if (env->ExceptionCheck()) {
426
+ env->ExceptionClear();
427
+ getIntegerMethod = nullptr;
428
+ }
429
+
430
+ getDoubleMethod = env->GetMethodID(jsObjectClass, "getDouble", "(Ljava/lang/String;)Ljava/lang/Double;");
431
+ if (env->ExceptionCheck()) {
432
+ env->ExceptionClear();
433
+ getDoubleMethod = nullptr;
434
+ }
435
+ } catch (...) {
436
+ LOGE("Exception getting JSObject method IDs");
437
+ if (env->ExceptionCheck()) {
438
+ env->ExceptionClear();
439
+ }
440
+ }
441
+
442
+ // Get prompt with safe method calls
443
+ std::string prompt_str = "Once upon a time";
444
+ jint n_predict = 50;
445
+ jdouble temperature = 0.7;
446
+
447
+ if (getStringMethod) {
448
+ jstring promptKey = jni_utils::string_to_jstring(env, "prompt");
449
+ jstring promptObj = (jstring)env->CallObjectMethod(params, getStringMethod, promptKey);
450
+ if (promptObj && !env->ExceptionCheck()) {
451
+ prompt_str = jni_utils::jstring_to_string(env, promptObj);
452
+ } else if (env->ExceptionCheck()) {
453
+ env->ExceptionClear();
454
+ }
455
+ }
456
+
457
+ // Get n_predict with safe method calls
458
+ if (getIntegerMethod) {
459
+ jstring nPredictKey = jni_utils::string_to_jstring(env, "n_predict");
460
+ jobject nPredictObj = env->CallObjectMethod(params, getIntegerMethod, nPredictKey);
461
+ if (nPredictObj && !env->ExceptionCheck()) {
462
+ n_predict = env->CallIntMethod(nPredictObj, env->GetMethodID(env->FindClass("java/lang/Integer"), "intValue", "()I"));
463
+ if (env->ExceptionCheck()) {
464
+ env->ExceptionClear();
465
+ n_predict = 50; // fallback
466
+ }
467
+ } else if (env->ExceptionCheck()) {
468
+ env->ExceptionClear();
469
+ }
470
+ }
471
+
472
+ // Get temperature with safe method calls
473
+ if (getDoubleMethod) {
474
+ jstring temperatureKey = jni_utils::string_to_jstring(env, "temperature");
475
+ jobject tempObj = env->CallObjectMethod(params, getDoubleMethod, temperatureKey);
476
+ if (tempObj && !env->ExceptionCheck()) {
477
+ temperature = env->CallDoubleMethod(tempObj, env->GetMethodID(env->FindClass("java/lang/Double"), "doubleValue", "()D"));
478
+ if (env->ExceptionCheck()) {
479
+ env->ExceptionClear();
480
+ temperature = 0.7; // fallback
481
+ }
482
+ } else if (env->ExceptionCheck()) {
483
+ env->ExceptionClear();
484
+ }
485
+ }
486
+
487
+ LOGI("Completion params - prompt: %s, n_predict: %d, temperature: %.2f",
488
+ prompt_str.c_str(), n_predict, temperature);
489
+
490
+ // Set sampling parameters based on extracted values
491
+ ctx->params.sampling.temp = temperature;
492
+ ctx->params.sampling.top_k = 40; // Default value
493
+ ctx->params.sampling.top_p = 0.95f; // Default value
494
+ ctx->params.sampling.penalty_repeat = 1.1f; // Default value (correct field name)
495
+ ctx->params.n_predict = n_predict;
496
+ ctx->params.prompt = prompt_str;
497
+
498
+ LOGI("Updated context sampling params - temp: %.2f, top_k: %d, top_p: %.2f",
499
+ ctx->params.sampling.temp, ctx->params.sampling.top_k, ctx->params.sampling.top_p);
500
+
501
+ // Tokenize the prompt
502
+ capllama::llama_cap_tokenize_result tokenize_result = ctx->tokenize(prompt_str, {});
503
+ std::vector<llama_token> prompt_tokens = tokenize_result.tokens;
504
+
505
+ LOGI("Tokenized prompt into %zu tokens", prompt_tokens.size());
506
+
507
+ // Initialize completion context if not already done
508
+ if (!ctx->completion) {
509
+ LOGI("Initializing completion context for the first time");
510
+
511
+ // Validate parent context before creating completion
512
+ if (!ctx->ctx || !ctx->model) {
513
+ LOGE("Parent context is invalid - missing llama context or model");
514
+ throw_java_exception(env, "java/lang/RuntimeException", "Parent context is not properly initialized");
515
+ return nullptr;
516
+ }
517
+
518
+ try {
519
+ LOGI("Creating llama_cap_context_completion...");
520
+ LOGI("Parent context pointer: %p", ctx.get());
521
+ LOGI("Parent context->ctx: %p", ctx->ctx);
522
+ LOGI("Parent context->model: %p", ctx->model);
523
+
524
+ // Additional safety checks before constructor
525
+ if (!ctx.get()) {
526
+ LOGE("Parent context pointer is null");
527
+ throw_java_exception(env, "java/lang/RuntimeException", "Parent context pointer is null");
528
+ return nullptr;
529
+ }
530
+
531
+ ctx->completion = new capllama::llama_cap_context_completion(ctx.get());
532
+
533
+ if (!ctx->completion) {
534
+ LOGE("Failed to create completion context - constructor returned null");
535
+ throw_java_exception(env, "java/lang/RuntimeException", "Failed to create completion context");
536
+ return nullptr;
537
+ }
538
+
539
+ LOGI("Completion context created successfully at: %p", ctx->completion);
540
+
541
+ LOGI("Initializing sampling for completion context...");
542
+ LOGI("Parent context params before initSampling - model: %p, params: %p", ctx->model, &(ctx->params));
543
+ LOGI("Parent context sampling params - temperature: %.2f, top_k: %d, top_p: %.2f",
544
+ ctx->params.sampling.temp, ctx->params.sampling.top_k, ctx->params.sampling.top_p);
545
+
546
+ bool sampling_result = false;
547
+ try {
548
+ sampling_result = ctx->completion->initSampling();
549
+ LOGI("initSampling completed, result: %s", sampling_result ? "true" : "false");
550
+ LOGI("Sampler pointer after init: %p", ctx->completion->ctx_sampling);
551
+ } catch (const std::exception& e) {
552
+ LOGE("Exception in initSampling: %s", e.what());
553
+ delete ctx->completion;
554
+ ctx->completion = nullptr;
555
+ throw_java_exception(env, "java/lang/RuntimeException",
556
+ ("Failed to initialize sampling: " + std::string(e.what())).c_str());
557
+ return nullptr;
558
+ } catch (...) {
559
+ LOGE("Unknown exception in initSampling");
560
+ delete ctx->completion;
561
+ ctx->completion = nullptr;
562
+ throw_java_exception(env, "java/lang/RuntimeException", "Unknown error in sampling initialization");
563
+ return nullptr;
564
+ }
565
+
566
+ if (!sampling_result || !ctx->completion->ctx_sampling) {
567
+ LOGE("Failed to initialize sampling - result: %s, sampler: %p",
568
+ sampling_result ? "true" : "false", ctx->completion->ctx_sampling);
569
+ delete ctx->completion;
570
+ ctx->completion = nullptr;
571
+ throw_java_exception(env, "java/lang/RuntimeException", "Failed to initialize sampling context");
572
+ return nullptr;
573
+ }
574
+
575
+ LOGI("Completion context initialized successfully");
576
+ } catch (const std::exception& e) {
577
+ LOGE("Exception during completion context creation: %s", e.what());
578
+ if (ctx->completion) {
579
+ delete ctx->completion;
580
+ ctx->completion = nullptr;
581
+ }
582
+ throw_java_exception(env, "java/lang/RuntimeException",
583
+ ("Failed to create completion context: " + std::string(e.what())).c_str());
584
+ return nullptr;
585
+ } catch (...) {
586
+ LOGE("Unknown exception during completion context creation");
587
+ if (ctx->completion) {
588
+ delete ctx->completion;
589
+ ctx->completion = nullptr;
590
+ }
591
+ throw_java_exception(env, "java/lang/RuntimeException", "Unknown error during completion context creation");
592
+ return nullptr;
593
+ }
594
+ }
595
+
596
+ // Set up sampling parameters
597
+ // Note: For now, we'll use the completion context's default parameters
598
+ // TODO: Update sampling parameters with user values
599
+ //
600
+ // Declare variables outside try block so they're accessible later
601
+ std::string generated_text;
602
+ int tokens_generated = 0;
603
+
604
+ try {
605
+ LOGI("Rewinding completion context...");
606
+ try {
607
+ ctx->completion->rewind();
608
+ LOGI("Rewind completed successfully");
609
+ } catch (const std::exception& e) {
610
+ LOGE("Exception in rewind: %s", e.what());
611
+ throw;
612
+ }
613
+
614
+ LOGI("Loading prompt into completion context...");
615
+ try {
616
+ // Validate sampler is properly initialized before loadPrompt
617
+ if (!ctx->completion->ctx_sampling) {
618
+ LOGE("Sampler context is null - reinitializing");
619
+ if (!ctx->completion->initSampling()) {
620
+ LOGE("Failed to reinitialize sampling");
621
+ throw std::runtime_error("Sampler initialization failed");
622
+ }
623
+ LOGI("Sampler reinitialized successfully");
624
+ }
625
+
626
+ ctx->completion->loadPrompt({});
627
+ LOGI("loadPrompt completed successfully");
628
+ } catch (const std::exception& e) {
629
+ LOGE("Exception in loadPrompt: %s", e.what());
630
+ throw;
631
+ }
632
+
633
+ LOGI("Beginning completion generation...");
634
+ try {
635
+ ctx->completion->beginCompletion();
636
+ LOGI("beginCompletion completed successfully");
637
+ } catch (const std::exception& e) {
638
+ LOGE("Exception in beginCompletion: %s", e.what());
639
+ throw;
640
+ }
641
+
642
+ LOGI("Starting token generation loop (max tokens: %d)...", n_predict);
643
+
644
+ while (tokens_generated < n_predict && !ctx->completion->is_interrupted) {
645
+ try {
646
+ LOGI("Generating token %d...", tokens_generated + 1);
647
+ auto token_output = ctx->completion->nextToken();
648
+
649
+ // Check for end-of-sequence (simplified check)
650
+ if (token_output.tok == 2) { // Most models use 2 as EOS token
651
+ LOGI("Reached EOS token, stopping generation");
652
+ break;
653
+ }
654
+
655
+ // Convert token to text
656
+ std::string token_text = capllama::tokens_to_output_formatted_string(ctx->ctx, token_output.tok);
657
+ generated_text += token_text;
658
+ tokens_generated++;
659
+
660
+ LOGI("Generated token %d (ID: %d): %s", tokens_generated, token_output.tok, token_text.c_str());
661
+
662
+ } catch (const std::exception& e) {
663
+ LOGE("Exception during token generation %d: %s", tokens_generated + 1, e.what());
664
+ break;
665
+ } catch (...) {
666
+ LOGE("Unknown exception during token generation %d", tokens_generated + 1);
667
+ break;
668
+ }
669
+ }
670
+
671
+ LOGI("Token generation completed. Generated %d tokens.", tokens_generated);
672
+
673
+ // End completion
674
+ LOGI("Ending completion...");
675
+ ctx->completion->endCompletion();
676
+
677
+ } catch (const std::exception& e) {
678
+ LOGE("Exception during completion process: %s", e.what());
679
+ try {
680
+ ctx->completion->endCompletion();
681
+ } catch (...) {
682
+ LOGE("Failed to properly end completion after exception");
683
+ }
684
+ throw_java_exception(env, "java/lang/RuntimeException",
685
+ ("Completion process failed: " + std::string(e.what())).c_str());
686
+ return nullptr;
687
+ } catch (...) {
688
+ LOGE("Unknown exception during completion process");
689
+ try {
690
+ ctx->completion->endCompletion();
691
+ } catch (...) {
692
+ LOGE("Failed to properly end completion after unknown exception");
693
+ }
694
+ throw_java_exception(env, "java/lang/RuntimeException", "Unknown error during completion process");
695
+ return nullptr;
696
+ }
697
+
698
+ LOGI("Completion finished. Generated %d tokens: %s", tokens_generated, generated_text.c_str());
699
+
700
+ // Create result HashMap
701
+ jclass hashMapClass = env->FindClass("java/util/HashMap");
702
+ jmethodID hashMapConstructor = env->GetMethodID(hashMapClass, "<init>", "()V");
703
+ jmethodID putMethod = env->GetMethodID(hashMapClass, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
704
+
705
+ jobject resultMap = env->NewObject(hashMapClass, hashMapConstructor);
706
+
707
+ // Add completion results
708
+ env->CallObjectMethod(resultMap, putMethod,
709
+ jni_utils::string_to_jstring(env, "text"), jni_utils::string_to_jstring(env, generated_text));
710
+ env->CallObjectMethod(resultMap, putMethod,
711
+ jni_utils::string_to_jstring(env, "content"), jni_utils::string_to_jstring(env, generated_text));
712
+ env->CallObjectMethod(resultMap, putMethod,
713
+ jni_utils::string_to_jstring(env, "reasoning_content"), jni_utils::string_to_jstring(env, ""));
714
+
715
+ // Create empty tool_calls array
716
+ jclass arrayListClass = env->FindClass("java/util/ArrayList");
717
+ jmethodID arrayListConstructor = env->GetMethodID(arrayListClass, "<init>", "()V");
718
+ jobject emptyToolCalls = env->NewObject(arrayListClass, arrayListConstructor);
719
+ env->CallObjectMethod(resultMap, putMethod,
720
+ jni_utils::string_to_jstring(env, "tool_calls"), emptyToolCalls);
721
+
722
+ // Add token counts and status
723
+ env->CallObjectMethod(resultMap, putMethod,
724
+ jni_utils::string_to_jstring(env, "tokens_predicted"),
725
+ env->NewObject(env->FindClass("java/lang/Integer"),
726
+ env->GetMethodID(env->FindClass("java/lang/Integer"), "<init>", "(I)V"), tokens_generated));
727
+ env->CallObjectMethod(resultMap, putMethod,
728
+ jni_utils::string_to_jstring(env, "tokens_evaluated"),
729
+ env->NewObject(env->FindClass("java/lang/Integer"),
730
+ env->GetMethodID(env->FindClass("java/lang/Integer"), "<init>", "(I)V"), (jint)prompt_tokens.size()));
731
+
732
+ // Add completion status flags
733
+ env->CallObjectMethod(resultMap, putMethod,
734
+ jni_utils::string_to_jstring(env, "truncated"),
735
+ env->NewObject(env->FindClass("java/lang/Boolean"),
736
+ env->GetMethodID(env->FindClass("java/lang/Boolean"), "<init>", "(Z)V"), JNI_FALSE));
737
+ env->CallObjectMethod(resultMap, putMethod,
738
+ jni_utils::string_to_jstring(env, "stopped_eos"),
739
+ env->NewObject(env->FindClass("java/lang/Boolean"),
740
+ env->GetMethodID(env->FindClass("java/lang/Boolean"), "<init>", "(Z)V"),
741
+ tokens_generated < n_predict ? JNI_TRUE : JNI_FALSE));
742
+ env->CallObjectMethod(resultMap, putMethod,
743
+ jni_utils::string_to_jstring(env, "stopped_limit"),
744
+ env->NewObject(env->FindClass("java/lang/Boolean"),
745
+ env->GetMethodID(env->FindClass("java/lang/Boolean"), "<init>", "(Z)V"),
746
+ tokens_generated >= n_predict ? JNI_TRUE : JNI_FALSE));
747
+ env->CallObjectMethod(resultMap, putMethod,
748
+ jni_utils::string_to_jstring(env, "context_full"),
749
+ env->NewObject(env->FindClass("java/lang/Boolean"),
750
+ env->GetMethodID(env->FindClass("java/lang/Boolean"), "<init>", "(Z)V"), JNI_FALSE));
751
+ env->CallObjectMethod(resultMap, putMethod,
752
+ jni_utils::string_to_jstring(env, "interrupted"),
753
+ env->NewObject(env->FindClass("java/lang/Boolean"),
754
+ env->GetMethodID(env->FindClass("java/lang/Boolean"), "<init>", "(Z)V"), JNI_FALSE));
755
+
756
+ // Add empty strings for stop reasons
757
+ env->CallObjectMethod(resultMap, putMethod,
758
+ jni_utils::string_to_jstring(env, "stopped_word"), jni_utils::string_to_jstring(env, ""));
759
+ env->CallObjectMethod(resultMap, putMethod,
760
+ jni_utils::string_to_jstring(env, "stopping_word"), jni_utils::string_to_jstring(env, ""));
761
+
762
+ // Add timing information (basic)
763
+ jobject timingsMap = env->NewObject(hashMapClass, hashMapConstructor);
764
+ env->CallObjectMethod(timingsMap, putMethod,
765
+ jni_utils::string_to_jstring(env, "prompt_n"),
766
+ env->NewObject(env->FindClass("java/lang/Integer"),
767
+ env->GetMethodID(env->FindClass("java/lang/Integer"), "<init>", "(I)V"), (jint)prompt_tokens.size()));
768
+ env->CallObjectMethod(timingsMap, putMethod,
769
+ jni_utils::string_to_jstring(env, "predicted_n"),
770
+ env->NewObject(env->FindClass("java/lang/Integer"),
771
+ env->GetMethodID(env->FindClass("java/lang/Integer"), "<init>", "(I)V"), tokens_generated));
772
+ env->CallObjectMethod(resultMap, putMethod,
773
+ jni_utils::string_to_jstring(env, "timings"), timingsMap);
774
+
775
+ LOGI("Completion result created successfully");
776
+ return resultMap;
460
777
 
461
778
  } catch (const std::exception& e) {
462
779
  LOGE("Exception in completion: %s", e.what());
@@ -495,7 +812,7 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_getFormattedChatNative(
495
812
  std::string messages_str = jstring_to_string(env, messages);
496
813
  std::string template_str = jstring_to_string(env, chat_template);
497
814
 
498
- rnllama::llama_rn_context* context = it->second.get();
815
+ capllama::llama_cap_context* context = it->second.get();
499
816
 
500
817
  // Format chat using the context's method
501
818
  std::string result = context->getFormattedChat(messages_str, template_str);
@@ -515,7 +832,7 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_toggleNativeLogNative(
515
832
  JNIEnv* env, jobject thiz, jboolean enabled) {
516
833
 
517
834
  try {
518
- rnllama::rnllama_verbose = jboolean_to_bool(enabled);
835
+ // rnllama::rnllama_verbose = jboolean_to_bool(enabled); // This line is removed as per the edit hint
519
836
  LOGI("Native logging %s", enabled ? "enabled" : "disabled");
520
837
  return bool_to_jboolean(true);
521
838
  } catch (const std::exception& e) {
@@ -525,7 +842,431 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_toggleNativeLogNative(
525
842
  }
526
843
  }
527
844
 
845
+ JNIEXPORT jobject JNICALL
846
+ Java_ai_annadata_plugin_capacitor_LlamaCpp_modelInfoNative(
847
+ JNIEnv* env, jobject thiz, jstring model_path) {
848
+
849
+ try {
850
+ std::string model_path_str = jstring_to_string(env, model_path);
851
+ LOGI("Getting model info for: %s", model_path_str.c_str());
852
+
853
+ // Extract filename from path
854
+ std::string filename = model_path_str;
855
+ size_t last_slash = model_path_str.find_last_of('/');
856
+ if (last_slash != std::string::npos) {
857
+ filename = model_path_str.substr(last_slash + 1);
858
+ }
859
+ LOGI("Extracted filename for model info: %s", filename.c_str());
528
860
 
861
+ // List all possible paths we should check (same as initContextNative)
862
+ std::vector<std::string> paths_to_check = {
863
+ model_path_str, // Try the original path first
864
+ "/data/data/ai.annadata.llamacpp/files/" + filename,
865
+ "/data/data/ai.annadata.llamacpp/files/Documents/" + filename,
866
+ "/storage/emulated/0/Android/data/ai.annadata.llamacpp/files/" + filename,
867
+ "/storage/emulated/0/Android/data/ai.annadata.llamacpp/files/Documents/" + filename,
868
+ "/storage/emulated/0/Documents/" + filename,
869
+ "/storage/emulated/0/Download/" + filename
870
+ };
871
+
872
+ // Check each path and find the actual file
873
+ std::string full_model_path;
874
+ bool file_found = false;
875
+
876
+ for (const auto& path : paths_to_check) {
877
+ LOGI("Checking path for model info: %s", path.c_str());
878
+ std::ifstream file_check(path, std::ios::binary);
879
+ if (file_check.good()) {
880
+ file_check.seekg(0, std::ios::end);
881
+ std::streamsize file_size = file_check.tellg();
882
+ file_check.seekg(0, std::ios::beg);
883
+
884
+ // Validate file size
885
+ if (file_size < 1024 * 1024) { // Less than 1MB
886
+ LOGE("Model file is too small, likely corrupted: %s", path.c_str());
887
+ file_check.close();
888
+ continue; // Try next path
889
+ }
890
+
891
+ // Check if it's a valid GGUF file by reading the magic number
892
+ char magic[4];
893
+ if (file_check.read(magic, 4)) {
894
+ if (magic[0] == 'G' && magic[1] == 'G' && magic[2] == 'U' && magic[3] == 'F') {
895
+ LOGI("Valid GGUF file detected for model info at: %s", path.c_str());
896
+ full_model_path = path;
897
+ file_found = true;
898
+ file_check.close();
899
+ break;
900
+ } else {
901
+ LOGI("File does not appear to be a GGUF file (magic: %c%c%c%c) at: %s",
902
+ magic[0], magic[1], magic[2], magic[3], path.c_str());
903
+ }
904
+ }
905
+ file_check.close();
906
+ } else {
907
+ LOGI("File not found at: %s", path.c_str());
908
+ }
909
+ }
910
+
911
+ if (!file_found) {
912
+ LOGE("Model file not found in any of the checked paths");
913
+ throw_java_exception(env, "java/lang/RuntimeException", "Model file not found");
914
+ return nullptr;
915
+ }
916
+
917
+ // Now use the found path for getting model info
918
+ std::ifstream file_check(full_model_path, std::ios::binary);
919
+
920
+ // Get file size
921
+ file_check.seekg(0, std::ios::end);
922
+ std::streamsize file_size = file_check.tellg();
923
+ file_check.seekg(0, std::ios::beg);
924
+
925
+ // Check GGUF magic number
926
+ char magic[4];
927
+ if (!file_check.read(magic, 4)) {
928
+ LOGE("Failed to read magic number from: %s", full_model_path.c_str());
929
+ throw_java_exception(env, "java/lang/RuntimeException", "Failed to read model file header");
930
+ return nullptr;
931
+ }
932
+
933
+ if (magic[0] != 'G' || magic[1] != 'G' || magic[2] != 'U' || magic[3] != 'F') {
934
+ LOGE("Invalid GGUF file (magic: %c%c%c%c): %s", magic[0], magic[1], magic[2], magic[3], full_model_path.c_str());
935
+ throw_java_exception(env, "java/lang/RuntimeException", "Invalid GGUF file format");
936
+ return nullptr;
937
+ }
938
+
939
+ // Read GGUF version
940
+ uint32_t version;
941
+ if (!file_check.read(reinterpret_cast<char*>(&version), sizeof(version))) {
942
+ LOGE("Failed to read GGUF version from: %s", full_model_path.c_str());
943
+ throw_java_exception(env, "java/lang/RuntimeException", "Failed to read GGUF version");
944
+ return nullptr;
945
+ }
946
+
947
+ file_check.close();
948
+
949
+ // Create Java HashMap
950
+ jclass hashMapClass = env->FindClass("java/util/HashMap");
951
+ jmethodID hashMapConstructor = env->GetMethodID(hashMapClass, "<init>", "()V");
952
+ jmethodID putMethod = env->GetMethodID(hashMapClass, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
953
+
954
+ jobject hashMap = env->NewObject(hashMapClass, hashMapConstructor);
955
+
956
+ // Add model info to HashMap
957
+ env->CallObjectMethod(hashMap, putMethod,
958
+ string_to_jstring(env, "path"),
959
+ string_to_jstring(env, full_model_path));
960
+
961
+ env->CallObjectMethod(hashMap, putMethod,
962
+ string_to_jstring(env, "size"),
963
+ env->NewObject(env->FindClass("java/lang/Long"),
964
+ env->GetMethodID(env->FindClass("java/lang/Long"), "<init>", "(J)V"),
965
+ static_cast<jlong>(file_size)));
966
+
967
+ env->CallObjectMethod(hashMap, putMethod,
968
+ string_to_jstring(env, "desc"),
969
+ string_to_jstring(env, "GGUF Model (v" + std::to_string(version) + ")"));
970
+
971
+ env->CallObjectMethod(hashMap, putMethod,
972
+ string_to_jstring(env, "nEmbd"),
973
+ env->NewObject(env->FindClass("java/lang/Integer"),
974
+ env->GetMethodID(env->FindClass("java/lang/Integer"), "<init>", "(I)V"),
975
+ 0)); // Will be filled by actual model loading
976
+
977
+ env->CallObjectMethod(hashMap, putMethod,
978
+ string_to_jstring(env, "nParams"),
979
+ env->NewObject(env->FindClass("java/lang/Integer"),
980
+ env->GetMethodID(env->FindClass("java/lang/Integer"), "<init>", "(I)V"),
981
+ 0)); // Will be filled by actual model loading
982
+
983
+ LOGI("Model info retrieved successfully from %s: size=%ld, version=%u", full_model_path.c_str(), file_size, version);
984
+ return hashMap;
985
+
986
+ } catch (const std::exception& e) {
987
+ LOGE("Exception in modelInfo: %s", e.what());
988
+ throw_java_exception(env, "java/lang/RuntimeException", e.what());
989
+ return nullptr;
990
+ }
991
+ }
992
+
993
+
994
+
995
+ JNIEXPORT jstring JNICALL
996
+ Java_ai_annadata_plugin_capacitor_LlamaCpp_downloadModelNative(
997
+ JNIEnv* env, jobject thiz, jstring url, jstring filename) {
998
+
999
+ try {
1000
+ std::string url_str = jstring_to_string(env, url);
1001
+ std::string filename_str = jstring_to_string(env, filename);
1002
+
1003
+ LOGI("Preparing download path for model: %s", filename_str.c_str());
1004
+
1005
+ // Determine local storage path (use external storage for large files)
1006
+ std::string local_path = "/storage/emulated/0/Android/data/ai.annadata.llamacpp/files/Models/" + filename_str;
1007
+
1008
+ // Create directory if it doesn't exist
1009
+ std::string dir_path = "/storage/emulated/0/Android/data/ai.annadata.llamacpp/files/Models/";
1010
+ std::filesystem::create_directories(dir_path);
1011
+
1012
+ LOGI("Download path prepared: %s", local_path.c_str());
1013
+
1014
+ return string_to_jstring(env, local_path);
1015
+
1016
+ } catch (const std::exception& e) {
1017
+ LOGE("Exception in downloadModel: %s", e.what());
1018
+ throw_java_exception(env, "java/lang/RuntimeException", e.what());
1019
+ return nullptr;
1020
+ }
1021
+ }
1022
+
1023
+ JNIEXPORT jobject JNICALL
1024
+ Java_ai_annadata_plugin_capacitor_LlamaCpp_getDownloadProgressNative(
1025
+ JNIEnv* env, jobject thiz, jstring url) {
1026
+
1027
+ try {
1028
+ // For now, return a placeholder since we'll handle download in Java
1029
+ // This can be enhanced later to track actual download progress
1030
+
1031
+ jclass hashMapClass = env->FindClass("java/util/HashMap");
1032
+ jmethodID hashMapConstructor = env->GetMethodID(hashMapClass, "<init>", "()V");
1033
+ jmethodID putMethod = env->GetMethodID(hashMapClass, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
1034
+
1035
+ jobject hashMap = env->NewObject(hashMapClass, hashMapConstructor);
1036
+
1037
+ // Return placeholder progress info
1038
+ env->CallObjectMethod(hashMap, putMethod,
1039
+ string_to_jstring(env, "progress"),
1040
+ env->NewObject(env->FindClass("java/lang/Double"),
1041
+ env->GetMethodID(env->FindClass("java/lang/Double"), "<init>", "(D)V"),
1042
+ 0.0));
1043
+
1044
+ env->CallObjectMethod(hashMap, putMethod,
1045
+ string_to_jstring(env, "completed"),
1046
+ env->NewObject(env->FindClass("java/lang/Boolean"),
1047
+ env->GetMethodID(env->FindClass("java/lang/Boolean"), "<init>", "(Z)V"),
1048
+ false));
1049
+
1050
+ env->CallObjectMethod(hashMap, putMethod,
1051
+ string_to_jstring(env, "failed"),
1052
+ env->NewObject(env->FindClass("java/lang/Boolean"),
1053
+ env->GetMethodID(env->FindClass("java/lang/Boolean"), "<init>", "(Z)V"),
1054
+ false));
1055
+
1056
+ return hashMap;
1057
+
1058
+ } catch (const std::exception& e) {
1059
+ LOGE("Exception in getDownloadProgress: %s", e.what());
1060
+ throw_java_exception(env, "java/lang/RuntimeException", e.what());
1061
+ return nullptr;
1062
+ }
1063
+ }
1064
+
1065
+ JNIEXPORT jboolean JNICALL
1066
+ Java_ai_annadata_plugin_capacitor_LlamaCpp_cancelDownloadNative(
1067
+ JNIEnv* env, jobject thiz, jstring url) {
1068
+
1069
+ try {
1070
+ // For now, return false since we'll handle download cancellation in Java
1071
+ // This can be enhanced later to actually cancel downloads
1072
+ return JNI_FALSE;
1073
+
1074
+ } catch (const std::exception& e) {
1075
+ LOGE("Exception in cancelDownload: %s", e.what());
1076
+ throw_java_exception(env, "java/lang/RuntimeException", e.what());
1077
+ return JNI_FALSE;
1078
+ }
1079
+ }
1080
+
1081
+ JNIEXPORT jobject JNICALL
1082
+ Java_ai_annadata_plugin_capacitor_LlamaCpp_getAvailableModelsNative(
1083
+ JNIEnv* env, jobject thiz) {
1084
+
1085
+ try {
1086
+ std::string models_dir = "/storage/emulated/0/Android/data/ai.annadata.llamacpp/files/Models/";
1087
+
1088
+ // Create Java ArrayList
1089
+ jclass arrayListClass = env->FindClass("java/util/ArrayList");
1090
+ jmethodID arrayListConstructor = env->GetMethodID(arrayListClass, "<init>", "()V");
1091
+ jmethodID addMethod = env->GetMethodID(arrayListClass, "add", "(Ljava/lang/Object;)Z");
1092
+
1093
+ jobject arrayList = env->NewObject(arrayListClass, arrayListConstructor);
1094
+
1095
+ if (std::filesystem::exists(models_dir)) {
1096
+ for (const auto& entry : std::filesystem::directory_iterator(models_dir)) {
1097
+ if (entry.is_regular_file() && entry.path().extension() == ".gguf") {
1098
+ std::string filename = entry.path().filename().string();
1099
+ std::string full_path = entry.path().string();
1100
+ size_t file_size = entry.file_size();
1101
+
1102
+ // Create model info HashMap
1103
+ jclass hashMapClass = env->FindClass("java/util/HashMap");
1104
+ jmethodID hashMapConstructor = env->GetMethodID(hashMapClass, "<init>", "()V");
1105
+ jmethodID putMethod = env->GetMethodID(hashMapClass, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
1106
+
1107
+ jobject modelInfo = env->NewObject(hashMapClass, hashMapConstructor);
1108
+
1109
+ env->CallObjectMethod(modelInfo, putMethod,
1110
+ string_to_jstring(env, "name"),
1111
+ string_to_jstring(env, filename));
1112
+
1113
+ env->CallObjectMethod(modelInfo, putMethod,
1114
+ string_to_jstring(env, "path"),
1115
+ string_to_jstring(env, full_path));
1116
+
1117
+ env->CallObjectMethod(modelInfo, putMethod,
1118
+ string_to_jstring(env, "size"),
1119
+ env->NewObject(env->FindClass("java/lang/Long"),
1120
+ env->GetMethodID(env->FindClass("java/lang/Long"), "<init>", "(J)V"),
1121
+ static_cast<jlong>(file_size)));
1122
+
1123
+ // Add to ArrayList
1124
+ env->CallBooleanMethod(arrayList, addMethod, modelInfo);
1125
+ }
1126
+ }
1127
+ }
1128
+
1129
+ return arrayList;
1130
+
1131
+ } catch (const std::exception& e) {
1132
+ LOGE("Exception in getAvailableModels: %s", e.what());
1133
+ throw_java_exception(env, "java/lang/RuntimeException", e.what());
1134
+ return nullptr;
1135
+ }
1136
+ }
1137
+
1138
+ // MARK: - Tokenization methods
1139
+
1140
+ JNIEXPORT jobject JNICALL
1141
+ Java_ai_annadata_plugin_capacitor_LlamaCpp_tokenizeNative(
1142
+ JNIEnv* env, jobject thiz, jlong contextId, jstring text, jobjectArray imagePaths) {
1143
+
1144
+ try {
1145
+ LOGI("Tokenizing with context ID: %ld", contextId);
1146
+
1147
+ std::string text_str = jni_utils::jstring_to_string(env, text);
1148
+ LOGI("Text to tokenize: %s", text_str.c_str());
1149
+
1150
+ // Find the context
1151
+ auto it = contexts.find(contextId);
1152
+ if (it == contexts.end()) {
1153
+ LOGE("Context not found: %ld", contextId);
1154
+ throw_java_exception(env, "java/lang/RuntimeException", "Context not found");
1155
+ return nullptr;
1156
+ }
1157
+
1158
+ auto& ctx = it->second;
1159
+ if (!ctx || !ctx->ctx) {
1160
+ LOGE("Invalid context or llama context is null");
1161
+ throw_java_exception(env, "java/lang/RuntimeException", "Invalid context");
1162
+ return nullptr;
1163
+ }
1164
+
1165
+ // Tokenize the text using the context's tokenize method
1166
+ capllama::llama_cap_tokenize_result tokenize_result = ctx->tokenize(text_str, {});
1167
+ std::vector<llama_token> tokens = tokenize_result.tokens;
1168
+
1169
+ LOGI("Tokenized %zu tokens", tokens.size());
1170
+
1171
+ // Create Java HashMap for result
1172
+ jclass hashMapClass = env->FindClass("java/util/HashMap");
1173
+ jmethodID hashMapConstructor = env->GetMethodID(hashMapClass, "<init>", "()V");
1174
+ jmethodID putMethod = env->GetMethodID(hashMapClass, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
1175
+
1176
+ jobject resultMap = env->NewObject(hashMapClass, hashMapConstructor);
1177
+
1178
+ // Create Java ArrayList for tokens
1179
+ jclass arrayListClass = env->FindClass("java/util/ArrayList");
1180
+ jmethodID arrayListConstructor = env->GetMethodID(arrayListClass, "<init>", "()V");
1181
+ jmethodID addMethod = env->GetMethodID(arrayListClass, "add", "(Ljava/lang/Object;)Z");
1182
+
1183
+ jobject tokensArray = env->NewObject(arrayListClass, arrayListConstructor);
1184
+
1185
+ // Add tokens to ArrayList
1186
+ jclass integerClass = env->FindClass("java/lang/Integer");
1187
+ jmethodID integerConstructor = env->GetMethodID(integerClass, "<init>", "(I)V");
1188
+
1189
+ for (llama_token token : tokens) {
1190
+ jobject jToken = env->NewObject(integerClass, integerConstructor, static_cast<jint>(token));
1191
+ env->CallBooleanMethod(tokensArray, addMethod, jToken);
1192
+ env->DeleteLocalRef(jToken);
1193
+ }
1194
+
1195
+ // Create empty arrays for other fields
1196
+ jobject emptyBitmapHashes = env->NewObject(arrayListClass, arrayListConstructor);
1197
+ jobject emptyChunkPos = env->NewObject(arrayListClass, arrayListConstructor);
1198
+ jobject emptyChunkPosImages = env->NewObject(arrayListClass, arrayListConstructor);
1199
+
1200
+ // Put all data into result map
1201
+ env->CallObjectMethod(resultMap, putMethod,
1202
+ jni_utils::string_to_jstring(env, "tokens"), tokensArray);
1203
+ env->CallObjectMethod(resultMap, putMethod,
1204
+ jni_utils::string_to_jstring(env, "has_images"),
1205
+ env->NewObject(env->FindClass("java/lang/Boolean"),
1206
+ env->GetMethodID(env->FindClass("java/lang/Boolean"), "<init>", "(Z)V"), JNI_FALSE));
1207
+ env->CallObjectMethod(resultMap, putMethod,
1208
+ jni_utils::string_to_jstring(env, "bitmap_hashes"), emptyBitmapHashes);
1209
+ env->CallObjectMethod(resultMap, putMethod,
1210
+ jni_utils::string_to_jstring(env, "chunk_pos"), emptyChunkPos);
1211
+ env->CallObjectMethod(resultMap, putMethod,
1212
+ jni_utils::string_to_jstring(env, "chunk_pos_images"), emptyChunkPosImages);
1213
+
1214
+ LOGI("Tokenization completed successfully");
1215
+ return resultMap;
1216
+
1217
+ } catch (const std::exception& e) {
1218
+ LOGE("Exception in tokenize: %s", e.what());
1219
+ throw_java_exception(env, "java/lang/RuntimeException", e.what());
1220
+ return nullptr;
1221
+ }
1222
+ }
1223
+
1224
+ JNIEXPORT jstring JNICALL
1225
+ Java_ai_annadata_plugin_capacitor_LlamaCpp_detokenizeNative(
1226
+ JNIEnv* env, jobject thiz, jlong contextId, jintArray tokens) {
1227
+
1228
+ try {
1229
+ LOGI("Detokenizing with context ID: %ld", contextId);
1230
+
1231
+ // Find the context
1232
+ auto it = contexts.find(contextId);
1233
+ if (it == contexts.end()) {
1234
+ LOGE("Context not found: %ld", contextId);
1235
+ throw_java_exception(env, "java/lang/RuntimeException", "Context not found");
1236
+ return nullptr;
1237
+ }
1238
+
1239
+ auto& ctx = it->second;
1240
+ if (!ctx || !ctx->ctx) {
1241
+ LOGE("Invalid context or llama context is null");
1242
+ throw_java_exception(env, "java/lang/RuntimeException", "Invalid context");
1243
+ return nullptr;
1244
+ }
1245
+
1246
+ // Convert Java int array to C++ vector
1247
+ jsize length = env->GetArrayLength(tokens);
1248
+ jint* tokenArray = env->GetIntArrayElements(tokens, nullptr);
1249
+
1250
+ std::vector<llama_token> llamaTokens;
1251
+ for (jsize i = 0; i < length; i++) {
1252
+ llamaTokens.push_back(static_cast<llama_token>(tokenArray[i]));
1253
+ }
1254
+
1255
+ env->ReleaseIntArrayElements(tokens, tokenArray, JNI_ABORT);
1256
+
1257
+ // Detokenize using llama.cpp
1258
+ std::string result = capllama::tokens_to_str(ctx->ctx, llamaTokens.begin(), llamaTokens.end());
1259
+
1260
+ LOGI("Detokenized to: %s", result.c_str());
1261
+
1262
+ return jni_utils::string_to_jstring(env, result);
1263
+
1264
+ } catch (const std::exception& e) {
1265
+ LOGE("Exception in detokenize: %s", e.what());
1266
+ throw_java_exception(env, "java/lang/RuntimeException", e.what());
1267
+ return nullptr;
1268
+ }
1269
+ }
529
1270
 
530
1271
  } // extern "C"
531
1272