llama-cpp-capacitor 0.0.12 → 0.0.21
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LlamaCpp.podspec +17 -17
- package/Package.swift +27 -27
- package/README.md +717 -574
- package/android/build.gradle +88 -69
- package/android/src/main/AndroidManifest.xml +2 -2
- package/android/src/main/CMakeLists-arm64.txt +131 -0
- package/android/src/main/CMakeLists-x86_64.txt +135 -0
- package/android/src/main/CMakeLists.txt +35 -52
- package/android/src/main/java/ai/annadata/plugin/capacitor/LlamaCpp.java +956 -717
- package/android/src/main/java/ai/annadata/plugin/capacitor/LlamaCppPlugin.java +710 -590
- package/android/src/main/jni-utils.h +7 -7
- package/android/src/main/jni.cpp +952 -159
- package/cpp/{rn-completion.cpp → cap-completion.cpp} +202 -24
- package/cpp/{rn-completion.h → cap-completion.h} +22 -11
- package/cpp/{rn-llama.cpp → cap-llama.cpp} +81 -27
- package/cpp/{rn-llama.h → cap-llama.h} +32 -20
- package/cpp/{rn-mtmd.hpp → cap-mtmd.hpp} +15 -15
- package/cpp/{rn-tts.cpp → cap-tts.cpp} +12 -12
- package/cpp/{rn-tts.h → cap-tts.h} +14 -14
- package/cpp/ggml-cpu/ggml-cpu-impl.h +30 -0
- package/dist/docs.json +100 -3
- package/dist/esm/definitions.d.ts +45 -2
- package/dist/esm/definitions.js.map +1 -1
- package/dist/esm/index.d.ts +22 -0
- package/dist/esm/index.js +66 -3
- package/dist/esm/index.js.map +1 -1
- package/dist/plugin.cjs.js +71 -3
- package/dist/plugin.cjs.js.map +1 -1
- package/dist/plugin.js +71 -3
- package/dist/plugin.js.map +1 -1
- package/ios/Sources/LlamaCppPlugin/LlamaCpp.swift +596 -596
- package/ios/Sources/LlamaCppPlugin/LlamaCppPlugin.swift +591 -514
- package/ios/Tests/LlamaCppPluginTests/LlamaCppPluginTests.swift +15 -15
- package/package.json +111 -110
package/android/src/main/jni.cpp
CHANGED
|
@@ -1,14 +1,21 @@
|
|
|
1
1
|
#include "jni-utils.h"
|
|
2
|
-
#include "
|
|
2
|
+
#include "cap-llama.h"
|
|
3
|
+
#include "cap-completion.h"
|
|
3
4
|
#include <android/log.h>
|
|
4
5
|
#include <cstring>
|
|
5
6
|
#include <memory>
|
|
6
7
|
#include <fstream> // Added for file existence and size checks
|
|
8
|
+
#include <signal.h> // Added for signal handling
|
|
9
|
+
#include <sys/signal.h> // Added for sigaction
|
|
10
|
+
#include <thread> // For background downloads
|
|
11
|
+
#include <atomic> // For thread-safe progress tracking
|
|
12
|
+
#include <filesystem> // For file operations
|
|
13
|
+
#include <mutex> // For thread synchronization
|
|
7
14
|
|
|
8
15
|
// Add missing symbol
|
|
9
|
-
namespace rnllama {
|
|
10
|
-
|
|
11
|
-
}
|
|
16
|
+
// namespace rnllama {
|
|
17
|
+
// bool rnllama_verbose = false;
|
|
18
|
+
// }
|
|
12
19
|
|
|
13
20
|
#define LOG_TAG "LlamaCpp"
|
|
14
21
|
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
|
|
@@ -128,81 +135,92 @@ jclass find_class(JNIEnv* env, const char* name) {
|
|
|
128
135
|
return clazz;
|
|
129
136
|
}
|
|
130
137
|
|
|
131
|
-
//
|
|
132
|
-
|
|
138
|
+
// Convert llama_cap_context to jobject
|
|
139
|
+
jobject llama_context_to_jobject(JNIEnv* env, const capllama::llama_cap_context* context);
|
|
140
|
+
|
|
141
|
+
// Convert jobject to llama_cap_context
|
|
142
|
+
capllama::llama_cap_context* jobject_to_llama_context(JNIEnv* env, jobject obj);
|
|
143
|
+
|
|
144
|
+
// Convert completion result to jobject
|
|
145
|
+
jobject completion_result_to_jobject(JNIEnv* env, const capllama::completion_token_output& result);
|
|
146
|
+
|
|
147
|
+
// Convert tokenize result to jobject
|
|
148
|
+
jobject tokenize_result_to_jobject(JNIEnv* env, const capllama::llama_cap_tokenize_result& result);
|
|
149
|
+
|
|
150
|
+
// Global context storage - fix namespace
|
|
151
|
+
static std::map<jlong, std::unique_ptr<capllama::llama_cap_context>> contexts;
|
|
133
152
|
static jlong next_context_id = 1;
|
|
134
153
|
|
|
154
|
+
// Download progress tracking (simplified for now)
|
|
155
|
+
// This can be enhanced later to track actual download progress
|
|
156
|
+
|
|
135
157
|
extern "C" {
|
|
136
158
|
|
|
137
159
|
JNIEXPORT jlong JNICALL
|
|
138
160
|
Java_ai_annadata_plugin_capacitor_LlamaCpp_initContextNative(
|
|
139
|
-
JNIEnv*
|
|
161
|
+
JNIEnv *env, jobject thiz, jstring modelPath, jobjectArray searchPaths, jobject params) {
|
|
140
162
|
|
|
141
163
|
try {
|
|
142
|
-
std::string model_path_str = jstring_to_string(env,
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
std::vector<std::string> paths_to_check
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
164
|
+
std::string model_path_str = jstring_to_string(env, modelPath);
|
|
165
|
+
|
|
166
|
+
// Get search paths from Java
|
|
167
|
+
jsize pathCount = env->GetArrayLength(searchPaths);
|
|
168
|
+
std::vector<std::string> paths_to_check;
|
|
169
|
+
|
|
170
|
+
// Add the original path first
|
|
171
|
+
paths_to_check.push_back(model_path_str);
|
|
172
|
+
|
|
173
|
+
// Add all search paths from Java
|
|
174
|
+
for (jsize i = 0; i < pathCount; i++) {
|
|
175
|
+
jstring pathJString = (jstring)env->GetObjectArrayElement(searchPaths, i);
|
|
176
|
+
std::string path = jstring_to_string(env, pathJString);
|
|
177
|
+
paths_to_check.push_back(path);
|
|
178
|
+
env->DeleteLocalRef(pathJString);
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
// Rest of the existing logic remains the same...
|
|
156
182
|
std::string full_model_path;
|
|
157
183
|
bool file_found = false;
|
|
158
184
|
|
|
159
185
|
for (const auto& path : paths_to_check) {
|
|
160
186
|
LOGI("Checking path: %s", path.c_str());
|
|
161
|
-
std::
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
LOGI("Found file at: %s, size: %ld bytes", path.c_str(), file_size);
|
|
167
|
-
|
|
168
|
-
// Validate file size
|
|
169
|
-
if (file_size < 1024 * 1024) { // Less than 1MB
|
|
170
|
-
LOGE("Model file is too small, likely corrupted: %s", path.c_str());
|
|
171
|
-
continue; // Try next path
|
|
172
|
-
}
|
|
173
|
-
|
|
174
|
-
// Check if it's a valid GGUF file by reading the magic number
|
|
175
|
-
std::ifstream magic_file(path, std::ios::binary);
|
|
176
|
-
if (magic_file.good()) {
|
|
177
|
-
char magic[4];
|
|
178
|
-
if (magic_file.read(magic, 4)) {
|
|
179
|
-
if (magic[0] == 'G' && magic[1] == 'G' && magic[2] == 'U' && magic[3] == 'F') {
|
|
180
|
-
LOGI("Valid GGUF file detected at: %s", path.c_str());
|
|
181
|
-
full_model_path = path;
|
|
182
|
-
file_found = true;
|
|
183
|
-
break;
|
|
184
|
-
} else {
|
|
185
|
-
LOGI("File does not appear to be a GGUF file (magic: %c%c%c%c) at: %s",
|
|
186
|
-
magic[0], magic[1], magic[2], magic[3], path.c_str());
|
|
187
|
-
}
|
|
188
|
-
}
|
|
189
|
-
magic_file.close();
|
|
190
|
-
}
|
|
187
|
+
if (std::filesystem::exists(path)) {
|
|
188
|
+
full_model_path = path;
|
|
189
|
+
file_found = true;
|
|
190
|
+
LOGI("Found model file at: %s", path.c_str());
|
|
191
|
+
break;
|
|
191
192
|
} else {
|
|
192
|
-
|
|
193
|
+
LOGE("Path not found: %s", path.c_str());
|
|
193
194
|
}
|
|
194
|
-
file_check.close();
|
|
195
195
|
}
|
|
196
|
-
|
|
196
|
+
|
|
197
197
|
if (!file_found) {
|
|
198
|
-
LOGE("Model file not found in any of the
|
|
199
|
-
throw_java_exception(env, "java/lang/RuntimeException", "Model file not found in any expected location");
|
|
198
|
+
LOGE("Model file not found in any of the search paths");
|
|
200
199
|
return -1;
|
|
201
200
|
}
|
|
201
|
+
|
|
202
|
+
// Additional model validation
|
|
203
|
+
LOGI("Performing additional model validation...");
|
|
204
|
+
std::ifstream validation_file(full_model_path, std::ios::binary);
|
|
205
|
+
if (validation_file.good()) {
|
|
206
|
+
// Read first 8 bytes to check GGUF version
|
|
207
|
+
char header[8];
|
|
208
|
+
if (validation_file.read(header, 8)) {
|
|
209
|
+
uint32_t version = *reinterpret_cast<uint32_t*>(header + 4);
|
|
210
|
+
LOGI("GGUF version: %u", version);
|
|
211
|
+
|
|
212
|
+
// Check if version is reasonable (should be > 0 and < 1000)
|
|
213
|
+
if (version == 0 || version > 1000) {
|
|
214
|
+
LOGE("Suspicious GGUF version: %u", version);
|
|
215
|
+
LOGI("This might indicate a corrupted or incompatible model file");
|
|
216
|
+
}
|
|
217
|
+
}
|
|
218
|
+
validation_file.close();
|
|
219
|
+
}
|
|
202
220
|
|
|
203
|
-
// Create new context
|
|
204
|
-
auto context = std::make_unique<
|
|
205
|
-
LOGI("Created
|
|
221
|
+
// Create new context - fix namespace
|
|
222
|
+
auto context = std::make_unique<capllama::llama_cap_context>();
|
|
223
|
+
LOGI("Created llama_cap_context");
|
|
206
224
|
|
|
207
225
|
// Initialize common parameters
|
|
208
226
|
common_params cparams;
|
|
@@ -219,54 +237,46 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_initContextNative(
|
|
|
219
237
|
cparams.chat_template = "";
|
|
220
238
|
cparams.embedding = false;
|
|
221
239
|
cparams.cont_batching = false;
|
|
222
|
-
cparams.
|
|
223
|
-
cparams.grammar = "";
|
|
224
|
-
cparams.grammar_penalty.clear();
|
|
240
|
+
cparams.n_parallel = 1;
|
|
225
241
|
cparams.antiprompt.clear();
|
|
226
|
-
cparams.lora_adapter.clear();
|
|
227
|
-
cparams.lora_base = "";
|
|
228
|
-
cparams.mul_mat_q = true;
|
|
229
|
-
cparams.f16_kv = true;
|
|
230
|
-
cparams.logits_all = false;
|
|
231
242
|
cparams.vocab_only = false;
|
|
232
243
|
cparams.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
|
|
233
|
-
cparams.rope_scaling_factor = 0.0f;
|
|
234
|
-
cparams.rope_scaling_orig_ctx_len = 0;
|
|
235
244
|
cparams.yarn_ext_factor = -1.0f;
|
|
236
245
|
cparams.yarn_attn_factor = 1.0f;
|
|
237
246
|
cparams.yarn_beta_fast = 32.0f;
|
|
238
247
|
cparams.yarn_beta_slow = 1.0f;
|
|
239
248
|
cparams.yarn_orig_ctx = 0;
|
|
240
|
-
cparams.offload_kqv = true;
|
|
241
249
|
cparams.flash_attn = false;
|
|
242
|
-
cparams.flash_attn_kernel = false;
|
|
243
|
-
cparams.flash_attn_causal = true;
|
|
244
|
-
cparams.mmproj = "";
|
|
245
|
-
cparams.image = "";
|
|
246
|
-
cparams.export = "";
|
|
247
|
-
cparams.export_path = "";
|
|
248
|
-
cparams.seed = -1;
|
|
249
250
|
cparams.n_keep = 0;
|
|
250
|
-
cparams.n_discard = -1;
|
|
251
|
-
cparams.n_draft = 0;
|
|
252
251
|
cparams.n_chunks = -1;
|
|
253
|
-
cparams.n_parallel = 1;
|
|
254
252
|
cparams.n_sequences = 1;
|
|
255
|
-
cparams.p_accept = 0.5f;
|
|
256
|
-
cparams.p_split = 0.1f;
|
|
257
|
-
cparams.n_gqa = 8;
|
|
258
|
-
cparams.rms_norm_eps = 5e-6f;
|
|
259
253
|
cparams.model_alias = "unknown";
|
|
260
|
-
|
|
261
|
-
cparams.ubatch_seq_len_max = 1;
|
|
262
|
-
|
|
254
|
+
|
|
263
255
|
LOGI("Initialized common parameters, attempting to load model from: %s", full_model_path.c_str());
|
|
264
256
|
LOGI("Model parameters: n_ctx=%d, n_batch=%d, n_gpu_layers=%d",
|
|
265
257
|
cparams.n_ctx, cparams.n_batch, cparams.n_gpu_layers);
|
|
266
258
|
|
|
267
|
-
// Try to load the model with error handling
|
|
259
|
+
// Try to load the model with error handling and signal protection
|
|
268
260
|
bool load_success = false;
|
|
261
|
+
|
|
262
|
+
// Set up signal handler to catch segmentation faults
|
|
263
|
+
struct sigaction old_action;
|
|
264
|
+
struct sigaction new_action;
|
|
265
|
+
new_action.sa_handler = [](int sig) {
|
|
266
|
+
LOGE("Segmentation fault caught during model loading");
|
|
267
|
+
// Restore default handler and re-raise signal
|
|
268
|
+
signal(sig, SIG_DFL);
|
|
269
|
+
raise(sig);
|
|
270
|
+
};
|
|
271
|
+
new_action.sa_flags = SA_RESETHAND;
|
|
272
|
+
sigemptyset(&new_action.sa_mask);
|
|
273
|
+
|
|
274
|
+
if (sigaction(SIGSEGV, &new_action, &old_action) == 0) {
|
|
275
|
+
LOGI("Signal handler installed for segmentation fault protection");
|
|
276
|
+
}
|
|
277
|
+
|
|
269
278
|
try {
|
|
279
|
+
LOGI("Attempting to load model with standard parameters...");
|
|
270
280
|
load_success = context->loadModel(cparams);
|
|
271
281
|
} catch (const std::exception& e) {
|
|
272
282
|
LOGE("Exception during model loading: %s", e.what());
|
|
@@ -276,77 +286,64 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_initContextNative(
|
|
|
276
286
|
load_success = false;
|
|
277
287
|
}
|
|
278
288
|
|
|
289
|
+
// Restore original signal handler
|
|
290
|
+
sigaction(SIGSEGV, &old_action, nullptr);
|
|
291
|
+
|
|
279
292
|
if (!load_success) {
|
|
280
293
|
LOGE("context->loadModel() returned false - model loading failed");
|
|
281
294
|
|
|
282
|
-
// Try with minimal parameters as fallback
|
|
283
|
-
LOGI("Trying with minimal parameters...");
|
|
284
|
-
common_params
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
minimal_params.flash_attn = false;
|
|
316
|
-
minimal_params.flash_attn_kernel = false;
|
|
317
|
-
minimal_params.flash_attn_causal = true;
|
|
318
|
-
minimal_params.mmproj = "";
|
|
319
|
-
minimal_params.image = "";
|
|
320
|
-
minimal_params.export = "";
|
|
321
|
-
minimal_params.export_path = "";
|
|
322
|
-
minimal_params.seed = -1;
|
|
323
|
-
minimal_params.n_keep = 0;
|
|
324
|
-
minimal_params.n_discard = -1;
|
|
325
|
-
minimal_params.n_draft = 0;
|
|
326
|
-
minimal_params.n_chunks = -1;
|
|
327
|
-
minimal_params.n_parallel = 1;
|
|
328
|
-
minimal_params.n_sequences = 1;
|
|
329
|
-
minimal_params.p_accept = 0.5f;
|
|
330
|
-
minimal_params.p_split = 0.1f;
|
|
331
|
-
minimal_params.n_gqa = 8;
|
|
332
|
-
minimal_params.rms_norm_eps = 5e-6f;
|
|
333
|
-
minimal_params.model_alias = "unknown";
|
|
334
|
-
minimal_params.ubatch_size = 256;
|
|
335
|
-
minimal_params.ubatch_seq_len_max = 1;
|
|
295
|
+
// Try with ultra-minimal parameters as fallback
|
|
296
|
+
LOGI("Trying with ultra-minimal parameters...");
|
|
297
|
+
common_params ultra_minimal_params;
|
|
298
|
+
ultra_minimal_params.model.path = full_model_path;
|
|
299
|
+
ultra_minimal_params.n_ctx = 256; // Very small context
|
|
300
|
+
ultra_minimal_params.n_batch = 128; // Very small batch
|
|
301
|
+
ultra_minimal_params.n_gpu_layers = 0;
|
|
302
|
+
ultra_minimal_params.use_mmap = false; // Disable mmap to avoid memory issues
|
|
303
|
+
ultra_minimal_params.use_mlock = false;
|
|
304
|
+
ultra_minimal_params.numa = LM_GGML_NUMA_STRATEGY_DISABLED;
|
|
305
|
+
ultra_minimal_params.ctx_shift = false;
|
|
306
|
+
ultra_minimal_params.chat_template = "";
|
|
307
|
+
ultra_minimal_params.embedding = false;
|
|
308
|
+
ultra_minimal_params.cont_batching = false;
|
|
309
|
+
ultra_minimal_params.n_parallel = 1;
|
|
310
|
+
ultra_minimal_params.antiprompt.clear();
|
|
311
|
+
ultra_minimal_params.vocab_only = false;
|
|
312
|
+
ultra_minimal_params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
|
|
313
|
+
ultra_minimal_params.yarn_ext_factor = -1.0f;
|
|
314
|
+
ultra_minimal_params.yarn_attn_factor = 1.0f;
|
|
315
|
+
ultra_minimal_params.yarn_beta_fast = 32.0f;
|
|
316
|
+
ultra_minimal_params.yarn_beta_slow = 1.0f;
|
|
317
|
+
ultra_minimal_params.yarn_orig_ctx = 0;
|
|
318
|
+
ultra_minimal_params.flash_attn = false;
|
|
319
|
+
ultra_minimal_params.n_keep = 0;
|
|
320
|
+
ultra_minimal_params.n_chunks = -1;
|
|
321
|
+
ultra_minimal_params.n_sequences = 1;
|
|
322
|
+
ultra_minimal_params.model_alias = "unknown";
|
|
323
|
+
|
|
324
|
+
// Set up signal handler again for ultra-minimal attempt
|
|
325
|
+
if (sigaction(SIGSEGV, &new_action, &old_action) == 0) {
|
|
326
|
+
LOGI("Signal handler reinstalled for ultra-minimal attempt");
|
|
327
|
+
}
|
|
336
328
|
|
|
337
329
|
try {
|
|
338
|
-
load_success = context->loadModel(
|
|
330
|
+
load_success = context->loadModel(ultra_minimal_params);
|
|
339
331
|
} catch (const std::exception& e) {
|
|
340
|
-
LOGE("Exception during minimal model loading: %s", e.what());
|
|
332
|
+
LOGE("Exception during ultra-minimal model loading: %s", e.what());
|
|
341
333
|
load_success = false;
|
|
342
334
|
} catch (...) {
|
|
343
|
-
LOGE("Unknown exception during minimal model loading");
|
|
335
|
+
LOGE("Unknown exception during ultra-minimal model loading");
|
|
344
336
|
load_success = false;
|
|
345
337
|
}
|
|
346
338
|
|
|
339
|
+
// Restore original signal handler
|
|
340
|
+
sigaction(SIGSEGV, &old_action, nullptr);
|
|
341
|
+
|
|
347
342
|
if (!load_success) {
|
|
348
|
-
LOGE("Model loading failed even with minimal parameters");
|
|
349
|
-
throw_java_exception(env, "java/lang/RuntimeException",
|
|
343
|
+
LOGE("Model loading failed even with ultra-minimal parameters");
|
|
344
|
+
throw_java_exception(env, "java/lang/RuntimeException",
|
|
345
|
+
"Failed to load model - model appears to be corrupted or incompatible with this llama.cpp version. "
|
|
346
|
+
"Try downloading a fresh copy of the model file.");
|
|
350
347
|
return -1;
|
|
351
348
|
}
|
|
352
349
|
}
|
|
@@ -383,28 +380,400 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_releaseContextNative(
|
|
|
383
380
|
}
|
|
384
381
|
}
|
|
385
382
|
|
|
386
|
-
JNIEXPORT
|
|
383
|
+
JNIEXPORT jobject JNICALL
|
|
387
384
|
Java_ai_annadata_plugin_capacitor_LlamaCpp_completionNative(
|
|
388
|
-
JNIEnv* env, jobject thiz, jlong context_id,
|
|
385
|
+
JNIEnv* env, jobject thiz, jlong context_id, jobject params) {
|
|
389
386
|
|
|
390
387
|
try {
|
|
388
|
+
LOGI("Starting completion for context: %ld", context_id);
|
|
389
|
+
|
|
391
390
|
auto it = contexts.find(context_id);
|
|
392
391
|
if (it == contexts.end()) {
|
|
392
|
+
LOGE("Context not found: %ld", context_id);
|
|
393
393
|
throw_java_exception(env, "java/lang/IllegalArgumentException", "Invalid context ID");
|
|
394
394
|
return nullptr;
|
|
395
395
|
}
|
|
396
396
|
|
|
397
|
-
|
|
397
|
+
auto& ctx = it->second;
|
|
398
|
+
if (!ctx || !ctx->ctx) {
|
|
399
|
+
LOGE("Invalid context or llama context is null");
|
|
400
|
+
throw_java_exception(env, "java/lang/RuntimeException", "Invalid context");
|
|
401
|
+
return nullptr;
|
|
402
|
+
}
|
|
398
403
|
|
|
399
|
-
//
|
|
400
|
-
|
|
404
|
+
// Extract parameters from JSObject using compatible API
|
|
405
|
+
jclass jsObjectClass = env->GetObjectClass(params);
|
|
401
406
|
|
|
402
|
-
//
|
|
403
|
-
|
|
404
|
-
|
|
407
|
+
// Try to get method IDs and handle exceptions
|
|
408
|
+
jmethodID getStringMethod = nullptr;
|
|
409
|
+
jmethodID getIntegerMethod = nullptr;
|
|
410
|
+
jmethodID getDoubleMethod = nullptr;
|
|
405
411
|
|
|
406
|
-
|
|
407
|
-
|
|
412
|
+
// Clear any pending exceptions first
|
|
413
|
+
if (env->ExceptionCheck()) {
|
|
414
|
+
env->ExceptionClear();
|
|
415
|
+
}
|
|
416
|
+
|
|
417
|
+
try {
|
|
418
|
+
getStringMethod = env->GetMethodID(jsObjectClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;");
|
|
419
|
+
if (env->ExceptionCheck()) {
|
|
420
|
+
env->ExceptionClear();
|
|
421
|
+
getStringMethod = nullptr;
|
|
422
|
+
}
|
|
423
|
+
|
|
424
|
+
getIntegerMethod = env->GetMethodID(jsObjectClass, "getInteger", "(Ljava/lang/String;)Ljava/lang/Integer;");
|
|
425
|
+
if (env->ExceptionCheck()) {
|
|
426
|
+
env->ExceptionClear();
|
|
427
|
+
getIntegerMethod = nullptr;
|
|
428
|
+
}
|
|
429
|
+
|
|
430
|
+
getDoubleMethod = env->GetMethodID(jsObjectClass, "getDouble", "(Ljava/lang/String;)Ljava/lang/Double;");
|
|
431
|
+
if (env->ExceptionCheck()) {
|
|
432
|
+
env->ExceptionClear();
|
|
433
|
+
getDoubleMethod = nullptr;
|
|
434
|
+
}
|
|
435
|
+
} catch (...) {
|
|
436
|
+
LOGE("Exception getting JSObject method IDs");
|
|
437
|
+
if (env->ExceptionCheck()) {
|
|
438
|
+
env->ExceptionClear();
|
|
439
|
+
}
|
|
440
|
+
}
|
|
441
|
+
|
|
442
|
+
// Get prompt with safe method calls
|
|
443
|
+
std::string prompt_str = "Once upon a time";
|
|
444
|
+
jint n_predict = 50;
|
|
445
|
+
jdouble temperature = 0.7;
|
|
446
|
+
|
|
447
|
+
if (getStringMethod) {
|
|
448
|
+
jstring promptKey = jni_utils::string_to_jstring(env, "prompt");
|
|
449
|
+
jstring promptObj = (jstring)env->CallObjectMethod(params, getStringMethod, promptKey);
|
|
450
|
+
if (promptObj && !env->ExceptionCheck()) {
|
|
451
|
+
prompt_str = jni_utils::jstring_to_string(env, promptObj);
|
|
452
|
+
} else if (env->ExceptionCheck()) {
|
|
453
|
+
env->ExceptionClear();
|
|
454
|
+
}
|
|
455
|
+
}
|
|
456
|
+
|
|
457
|
+
// Get n_predict with safe method calls
|
|
458
|
+
if (getIntegerMethod) {
|
|
459
|
+
jstring nPredictKey = jni_utils::string_to_jstring(env, "n_predict");
|
|
460
|
+
jobject nPredictObj = env->CallObjectMethod(params, getIntegerMethod, nPredictKey);
|
|
461
|
+
if (nPredictObj && !env->ExceptionCheck()) {
|
|
462
|
+
n_predict = env->CallIntMethod(nPredictObj, env->GetMethodID(env->FindClass("java/lang/Integer"), "intValue", "()I"));
|
|
463
|
+
if (env->ExceptionCheck()) {
|
|
464
|
+
env->ExceptionClear();
|
|
465
|
+
n_predict = 50; // fallback
|
|
466
|
+
}
|
|
467
|
+
} else if (env->ExceptionCheck()) {
|
|
468
|
+
env->ExceptionClear();
|
|
469
|
+
}
|
|
470
|
+
}
|
|
471
|
+
|
|
472
|
+
// Get temperature with safe method calls
|
|
473
|
+
if (getDoubleMethod) {
|
|
474
|
+
jstring temperatureKey = jni_utils::string_to_jstring(env, "temperature");
|
|
475
|
+
jobject tempObj = env->CallObjectMethod(params, getDoubleMethod, temperatureKey);
|
|
476
|
+
if (tempObj && !env->ExceptionCheck()) {
|
|
477
|
+
temperature = env->CallDoubleMethod(tempObj, env->GetMethodID(env->FindClass("java/lang/Double"), "doubleValue", "()D"));
|
|
478
|
+
if (env->ExceptionCheck()) {
|
|
479
|
+
env->ExceptionClear();
|
|
480
|
+
temperature = 0.7; // fallback
|
|
481
|
+
}
|
|
482
|
+
} else if (env->ExceptionCheck()) {
|
|
483
|
+
env->ExceptionClear();
|
|
484
|
+
}
|
|
485
|
+
}
|
|
486
|
+
|
|
487
|
+
LOGI("Completion params - prompt: %s, n_predict: %d, temperature: %.2f",
|
|
488
|
+
prompt_str.c_str(), n_predict, temperature);
|
|
489
|
+
|
|
490
|
+
// Set sampling parameters based on extracted values
|
|
491
|
+
ctx->params.sampling.temp = temperature;
|
|
492
|
+
ctx->params.sampling.top_k = 40; // Default value
|
|
493
|
+
ctx->params.sampling.top_p = 0.95f; // Default value
|
|
494
|
+
ctx->params.sampling.penalty_repeat = 1.1f; // Default value (correct field name)
|
|
495
|
+
ctx->params.n_predict = n_predict;
|
|
496
|
+
ctx->params.prompt = prompt_str;
|
|
497
|
+
|
|
498
|
+
LOGI("Updated context sampling params - temp: %.2f, top_k: %d, top_p: %.2f",
|
|
499
|
+
ctx->params.sampling.temp, ctx->params.sampling.top_k, ctx->params.sampling.top_p);
|
|
500
|
+
|
|
501
|
+
// Tokenize the prompt
|
|
502
|
+
capllama::llama_cap_tokenize_result tokenize_result = ctx->tokenize(prompt_str, {});
|
|
503
|
+
std::vector<llama_token> prompt_tokens = tokenize_result.tokens;
|
|
504
|
+
|
|
505
|
+
LOGI("Tokenized prompt into %zu tokens", prompt_tokens.size());
|
|
506
|
+
|
|
507
|
+
// Initialize completion context if not already done
|
|
508
|
+
if (!ctx->completion) {
|
|
509
|
+
LOGI("Initializing completion context for the first time");
|
|
510
|
+
|
|
511
|
+
// Validate parent context before creating completion
|
|
512
|
+
if (!ctx->ctx || !ctx->model) {
|
|
513
|
+
LOGE("Parent context is invalid - missing llama context or model");
|
|
514
|
+
throw_java_exception(env, "java/lang/RuntimeException", "Parent context is not properly initialized");
|
|
515
|
+
return nullptr;
|
|
516
|
+
}
|
|
517
|
+
|
|
518
|
+
try {
|
|
519
|
+
LOGI("Creating llama_cap_context_completion...");
|
|
520
|
+
LOGI("Parent context pointer: %p", ctx.get());
|
|
521
|
+
LOGI("Parent context->ctx: %p", ctx->ctx);
|
|
522
|
+
LOGI("Parent context->model: %p", ctx->model);
|
|
523
|
+
|
|
524
|
+
// Additional safety checks before constructor
|
|
525
|
+
if (!ctx.get()) {
|
|
526
|
+
LOGE("Parent context pointer is null");
|
|
527
|
+
throw_java_exception(env, "java/lang/RuntimeException", "Parent context pointer is null");
|
|
528
|
+
return nullptr;
|
|
529
|
+
}
|
|
530
|
+
|
|
531
|
+
ctx->completion = new capllama::llama_cap_context_completion(ctx.get());
|
|
532
|
+
|
|
533
|
+
if (!ctx->completion) {
|
|
534
|
+
LOGE("Failed to create completion context - constructor returned null");
|
|
535
|
+
throw_java_exception(env, "java/lang/RuntimeException", "Failed to create completion context");
|
|
536
|
+
return nullptr;
|
|
537
|
+
}
|
|
538
|
+
|
|
539
|
+
LOGI("Completion context created successfully at: %p", ctx->completion);
|
|
540
|
+
|
|
541
|
+
LOGI("Initializing sampling for completion context...");
|
|
542
|
+
LOGI("Parent context params before initSampling - model: %p, params: %p", ctx->model, &(ctx->params));
|
|
543
|
+
LOGI("Parent context sampling params - temperature: %.2f, top_k: %d, top_p: %.2f",
|
|
544
|
+
ctx->params.sampling.temp, ctx->params.sampling.top_k, ctx->params.sampling.top_p);
|
|
545
|
+
|
|
546
|
+
bool sampling_result = false;
|
|
547
|
+
try {
|
|
548
|
+
sampling_result = ctx->completion->initSampling();
|
|
549
|
+
LOGI("initSampling completed, result: %s", sampling_result ? "true" : "false");
|
|
550
|
+
LOGI("Sampler pointer after init: %p", ctx->completion->ctx_sampling);
|
|
551
|
+
} catch (const std::exception& e) {
|
|
552
|
+
LOGE("Exception in initSampling: %s", e.what());
|
|
553
|
+
delete ctx->completion;
|
|
554
|
+
ctx->completion = nullptr;
|
|
555
|
+
throw_java_exception(env, "java/lang/RuntimeException",
|
|
556
|
+
("Failed to initialize sampling: " + std::string(e.what())).c_str());
|
|
557
|
+
return nullptr;
|
|
558
|
+
} catch (...) {
|
|
559
|
+
LOGE("Unknown exception in initSampling");
|
|
560
|
+
delete ctx->completion;
|
|
561
|
+
ctx->completion = nullptr;
|
|
562
|
+
throw_java_exception(env, "java/lang/RuntimeException", "Unknown error in sampling initialization");
|
|
563
|
+
return nullptr;
|
|
564
|
+
}
|
|
565
|
+
|
|
566
|
+
if (!sampling_result || !ctx->completion->ctx_sampling) {
|
|
567
|
+
LOGE("Failed to initialize sampling - result: %s, sampler: %p",
|
|
568
|
+
sampling_result ? "true" : "false", ctx->completion->ctx_sampling);
|
|
569
|
+
delete ctx->completion;
|
|
570
|
+
ctx->completion = nullptr;
|
|
571
|
+
throw_java_exception(env, "java/lang/RuntimeException", "Failed to initialize sampling context");
|
|
572
|
+
return nullptr;
|
|
573
|
+
}
|
|
574
|
+
|
|
575
|
+
LOGI("Completion context initialized successfully");
|
|
576
|
+
} catch (const std::exception& e) {
|
|
577
|
+
LOGE("Exception during completion context creation: %s", e.what());
|
|
578
|
+
if (ctx->completion) {
|
|
579
|
+
delete ctx->completion;
|
|
580
|
+
ctx->completion = nullptr;
|
|
581
|
+
}
|
|
582
|
+
throw_java_exception(env, "java/lang/RuntimeException",
|
|
583
|
+
("Failed to create completion context: " + std::string(e.what())).c_str());
|
|
584
|
+
return nullptr;
|
|
585
|
+
} catch (...) {
|
|
586
|
+
LOGE("Unknown exception during completion context creation");
|
|
587
|
+
if (ctx->completion) {
|
|
588
|
+
delete ctx->completion;
|
|
589
|
+
ctx->completion = nullptr;
|
|
590
|
+
}
|
|
591
|
+
throw_java_exception(env, "java/lang/RuntimeException", "Unknown error during completion context creation");
|
|
592
|
+
return nullptr;
|
|
593
|
+
}
|
|
594
|
+
}
|
|
595
|
+
|
|
596
|
+
// Set up sampling parameters
|
|
597
|
+
// Note: For now, we'll use the completion context's default parameters
|
|
598
|
+
// TODO: Update sampling parameters with user values
|
|
599
|
+
//
|
|
600
|
+
// Declare variables outside try block so they're accessible later
|
|
601
|
+
std::string generated_text;
|
|
602
|
+
int tokens_generated = 0;
|
|
603
|
+
|
|
604
|
+
try {
|
|
605
|
+
LOGI("Rewinding completion context...");
|
|
606
|
+
try {
|
|
607
|
+
ctx->completion->rewind();
|
|
608
|
+
LOGI("Rewind completed successfully");
|
|
609
|
+
} catch (const std::exception& e) {
|
|
610
|
+
LOGE("Exception in rewind: %s", e.what());
|
|
611
|
+
throw;
|
|
612
|
+
}
|
|
613
|
+
|
|
614
|
+
LOGI("Loading prompt into completion context...");
|
|
615
|
+
try {
|
|
616
|
+
// Validate sampler is properly initialized before loadPrompt
|
|
617
|
+
if (!ctx->completion->ctx_sampling) {
|
|
618
|
+
LOGE("Sampler context is null - reinitializing");
|
|
619
|
+
if (!ctx->completion->initSampling()) {
|
|
620
|
+
LOGE("Failed to reinitialize sampling");
|
|
621
|
+
throw std::runtime_error("Sampler initialization failed");
|
|
622
|
+
}
|
|
623
|
+
LOGI("Sampler reinitialized successfully");
|
|
624
|
+
}
|
|
625
|
+
|
|
626
|
+
ctx->completion->loadPrompt({});
|
|
627
|
+
LOGI("loadPrompt completed successfully");
|
|
628
|
+
} catch (const std::exception& e) {
|
|
629
|
+
LOGE("Exception in loadPrompt: %s", e.what());
|
|
630
|
+
throw;
|
|
631
|
+
}
|
|
632
|
+
|
|
633
|
+
LOGI("Beginning completion generation...");
|
|
634
|
+
try {
|
|
635
|
+
ctx->completion->beginCompletion();
|
|
636
|
+
LOGI("beginCompletion completed successfully");
|
|
637
|
+
} catch (const std::exception& e) {
|
|
638
|
+
LOGE("Exception in beginCompletion: %s", e.what());
|
|
639
|
+
throw;
|
|
640
|
+
}
|
|
641
|
+
|
|
642
|
+
LOGI("Starting token generation loop (max tokens: %d)...", n_predict);
|
|
643
|
+
|
|
644
|
+
while (tokens_generated < n_predict && !ctx->completion->is_interrupted) {
|
|
645
|
+
try {
|
|
646
|
+
LOGI("Generating token %d...", tokens_generated + 1);
|
|
647
|
+
auto token_output = ctx->completion->nextToken();
|
|
648
|
+
|
|
649
|
+
// Check for end-of-sequence (simplified check)
|
|
650
|
+
if (token_output.tok == 2) { // Most models use 2 as EOS token
|
|
651
|
+
LOGI("Reached EOS token, stopping generation");
|
|
652
|
+
break;
|
|
653
|
+
}
|
|
654
|
+
|
|
655
|
+
// Convert token to text
|
|
656
|
+
std::string token_text = capllama::tokens_to_output_formatted_string(ctx->ctx, token_output.tok);
|
|
657
|
+
generated_text += token_text;
|
|
658
|
+
tokens_generated++;
|
|
659
|
+
|
|
660
|
+
LOGI("Generated token %d (ID: %d): %s", tokens_generated, token_output.tok, token_text.c_str());
|
|
661
|
+
|
|
662
|
+
} catch (const std::exception& e) {
|
|
663
|
+
LOGE("Exception during token generation %d: %s", tokens_generated + 1, e.what());
|
|
664
|
+
break;
|
|
665
|
+
} catch (...) {
|
|
666
|
+
LOGE("Unknown exception during token generation %d", tokens_generated + 1);
|
|
667
|
+
break;
|
|
668
|
+
}
|
|
669
|
+
}
|
|
670
|
+
|
|
671
|
+
LOGI("Token generation completed. Generated %d tokens.", tokens_generated);
|
|
672
|
+
|
|
673
|
+
// End completion
|
|
674
|
+
LOGI("Ending completion...");
|
|
675
|
+
ctx->completion->endCompletion();
|
|
676
|
+
|
|
677
|
+
} catch (const std::exception& e) {
|
|
678
|
+
LOGE("Exception during completion process: %s", e.what());
|
|
679
|
+
try {
|
|
680
|
+
ctx->completion->endCompletion();
|
|
681
|
+
} catch (...) {
|
|
682
|
+
LOGE("Failed to properly end completion after exception");
|
|
683
|
+
}
|
|
684
|
+
throw_java_exception(env, "java/lang/RuntimeException",
|
|
685
|
+
("Completion process failed: " + std::string(e.what())).c_str());
|
|
686
|
+
return nullptr;
|
|
687
|
+
} catch (...) {
|
|
688
|
+
LOGE("Unknown exception during completion process");
|
|
689
|
+
try {
|
|
690
|
+
ctx->completion->endCompletion();
|
|
691
|
+
} catch (...) {
|
|
692
|
+
LOGE("Failed to properly end completion after unknown exception");
|
|
693
|
+
}
|
|
694
|
+
throw_java_exception(env, "java/lang/RuntimeException", "Unknown error during completion process");
|
|
695
|
+
return nullptr;
|
|
696
|
+
}
|
|
697
|
+
|
|
698
|
+
LOGI("Completion finished. Generated %d tokens: %s", tokens_generated, generated_text.c_str());
|
|
699
|
+
|
|
700
|
+
// Create result HashMap
|
|
701
|
+
jclass hashMapClass = env->FindClass("java/util/HashMap");
|
|
702
|
+
jmethodID hashMapConstructor = env->GetMethodID(hashMapClass, "<init>", "()V");
|
|
703
|
+
jmethodID putMethod = env->GetMethodID(hashMapClass, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
|
|
704
|
+
|
|
705
|
+
jobject resultMap = env->NewObject(hashMapClass, hashMapConstructor);
|
|
706
|
+
|
|
707
|
+
// Add completion results
|
|
708
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
709
|
+
jni_utils::string_to_jstring(env, "text"), jni_utils::string_to_jstring(env, generated_text));
|
|
710
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
711
|
+
jni_utils::string_to_jstring(env, "content"), jni_utils::string_to_jstring(env, generated_text));
|
|
712
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
713
|
+
jni_utils::string_to_jstring(env, "reasoning_content"), jni_utils::string_to_jstring(env, ""));
|
|
714
|
+
|
|
715
|
+
// Create empty tool_calls array
|
|
716
|
+
jclass arrayListClass = env->FindClass("java/util/ArrayList");
|
|
717
|
+
jmethodID arrayListConstructor = env->GetMethodID(arrayListClass, "<init>", "()V");
|
|
718
|
+
jobject emptyToolCalls = env->NewObject(arrayListClass, arrayListConstructor);
|
|
719
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
720
|
+
jni_utils::string_to_jstring(env, "tool_calls"), emptyToolCalls);
|
|
721
|
+
|
|
722
|
+
// Add token counts and status
|
|
723
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
724
|
+
jni_utils::string_to_jstring(env, "tokens_predicted"),
|
|
725
|
+
env->NewObject(env->FindClass("java/lang/Integer"),
|
|
726
|
+
env->GetMethodID(env->FindClass("java/lang/Integer"), "<init>", "(I)V"), tokens_generated));
|
|
727
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
728
|
+
jni_utils::string_to_jstring(env, "tokens_evaluated"),
|
|
729
|
+
env->NewObject(env->FindClass("java/lang/Integer"),
|
|
730
|
+
env->GetMethodID(env->FindClass("java/lang/Integer"), "<init>", "(I)V"), (jint)prompt_tokens.size()));
|
|
731
|
+
|
|
732
|
+
// Add completion status flags
|
|
733
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
734
|
+
jni_utils::string_to_jstring(env, "truncated"),
|
|
735
|
+
env->NewObject(env->FindClass("java/lang/Boolean"),
|
|
736
|
+
env->GetMethodID(env->FindClass("java/lang/Boolean"), "<init>", "(Z)V"), JNI_FALSE));
|
|
737
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
738
|
+
jni_utils::string_to_jstring(env, "stopped_eos"),
|
|
739
|
+
env->NewObject(env->FindClass("java/lang/Boolean"),
|
|
740
|
+
env->GetMethodID(env->FindClass("java/lang/Boolean"), "<init>", "(Z)V"),
|
|
741
|
+
tokens_generated < n_predict ? JNI_TRUE : JNI_FALSE));
|
|
742
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
743
|
+
jni_utils::string_to_jstring(env, "stopped_limit"),
|
|
744
|
+
env->NewObject(env->FindClass("java/lang/Boolean"),
|
|
745
|
+
env->GetMethodID(env->FindClass("java/lang/Boolean"), "<init>", "(Z)V"),
|
|
746
|
+
tokens_generated >= n_predict ? JNI_TRUE : JNI_FALSE));
|
|
747
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
748
|
+
jni_utils::string_to_jstring(env, "context_full"),
|
|
749
|
+
env->NewObject(env->FindClass("java/lang/Boolean"),
|
|
750
|
+
env->GetMethodID(env->FindClass("java/lang/Boolean"), "<init>", "(Z)V"), JNI_FALSE));
|
|
751
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
752
|
+
jni_utils::string_to_jstring(env, "interrupted"),
|
|
753
|
+
env->NewObject(env->FindClass("java/lang/Boolean"),
|
|
754
|
+
env->GetMethodID(env->FindClass("java/lang/Boolean"), "<init>", "(Z)V"), JNI_FALSE));
|
|
755
|
+
|
|
756
|
+
// Add empty strings for stop reasons
|
|
757
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
758
|
+
jni_utils::string_to_jstring(env, "stopped_word"), jni_utils::string_to_jstring(env, ""));
|
|
759
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
760
|
+
jni_utils::string_to_jstring(env, "stopping_word"), jni_utils::string_to_jstring(env, ""));
|
|
761
|
+
|
|
762
|
+
// Add timing information (basic)
|
|
763
|
+
jobject timingsMap = env->NewObject(hashMapClass, hashMapConstructor);
|
|
764
|
+
env->CallObjectMethod(timingsMap, putMethod,
|
|
765
|
+
jni_utils::string_to_jstring(env, "prompt_n"),
|
|
766
|
+
env->NewObject(env->FindClass("java/lang/Integer"),
|
|
767
|
+
env->GetMethodID(env->FindClass("java/lang/Integer"), "<init>", "(I)V"), (jint)prompt_tokens.size()));
|
|
768
|
+
env->CallObjectMethod(timingsMap, putMethod,
|
|
769
|
+
jni_utils::string_to_jstring(env, "predicted_n"),
|
|
770
|
+
env->NewObject(env->FindClass("java/lang/Integer"),
|
|
771
|
+
env->GetMethodID(env->FindClass("java/lang/Integer"), "<init>", "(I)V"), tokens_generated));
|
|
772
|
+
env->CallObjectMethod(resultMap, putMethod,
|
|
773
|
+
jni_utils::string_to_jstring(env, "timings"), timingsMap);
|
|
774
|
+
|
|
775
|
+
LOGI("Completion result created successfully");
|
|
776
|
+
return resultMap;
|
|
408
777
|
|
|
409
778
|
} catch (const std::exception& e) {
|
|
410
779
|
LOGE("Exception in completion: %s", e.what());
|
|
@@ -443,7 +812,7 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_getFormattedChatNative(
|
|
|
443
812
|
std::string messages_str = jstring_to_string(env, messages);
|
|
444
813
|
std::string template_str = jstring_to_string(env, chat_template);
|
|
445
814
|
|
|
446
|
-
|
|
815
|
+
capllama::llama_cap_context* context = it->second.get();
|
|
447
816
|
|
|
448
817
|
// Format chat using the context's method
|
|
449
818
|
std::string result = context->getFormattedChat(messages_str, template_str);
|
|
@@ -463,7 +832,7 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_toggleNativeLogNative(
|
|
|
463
832
|
JNIEnv* env, jobject thiz, jboolean enabled) {
|
|
464
833
|
|
|
465
834
|
try {
|
|
466
|
-
rnllama::rnllama_verbose = jboolean_to_bool(enabled);
|
|
835
|
+
// rnllama::rnllama_verbose = jboolean_to_bool(enabled); // This line is removed as per the edit hint
|
|
467
836
|
LOGI("Native logging %s", enabled ? "enabled" : "disabled");
|
|
468
837
|
return bool_to_jboolean(true);
|
|
469
838
|
} catch (const std::exception& e) {
|
|
@@ -473,7 +842,431 @@ Java_ai_annadata_plugin_capacitor_LlamaCpp_toggleNativeLogNative(
|
|
|
473
842
|
}
|
|
474
843
|
}
|
|
475
844
|
|
|
845
|
+
/**
 * JNI entry point that locates a GGUF model file and reports basic facts
 * read from its header.
 *
 * The requested path is tried first, followed by a fixed list of directories
 * combined with the final component of that path. The first candidate that is
 * at least 1 MiB large and starts with the "GGUF" magic is used; its size and
 * header version are captured during that same probe, so the file is opened
 * only once.
 *
 * @param env        JNI environment of the calling thread.
 * @param thiz       Calling Java object (unused).
 * @param model_path Path (or bare file name) of the model to inspect.
 * @return A java.util.HashMap with the keys "path" (String), "size" (Long),
 *         "desc" (String), "nEmbd" (Integer, always 0 here) and "nParams"
 *         (Integer, always 0 here), or nullptr with a Java exception pending
 *         when no usable file is found or an error occurs.
 */
JNIEXPORT jobject JNICALL
Java_ai_annadata_plugin_capacitor_LlamaCpp_modelInfoNative(
    JNIEnv* env, jobject thiz, jstring model_path) {

    try {
        const std::string model_path_str = jstring_to_string(env, model_path);
        LOGI("Getting model info for: %s", model_path_str.c_str());

        // Extract filename from path
        std::string filename = model_path_str;
        const size_t last_slash = model_path_str.find_last_of('/');
        if (last_slash != std::string::npos) {
            filename = model_path_str.substr(last_slash + 1);
        }
        LOGI("Extracted filename for model info: %s", filename.c_str());

        // List all possible paths we should check (same as initContextNative)
        const std::vector<std::string> paths_to_check = {
            model_path_str,  // Try the original path first
            "/data/data/ai.annadata.llamacpp/files/" + filename,
            "/data/data/ai.annadata.llamacpp/files/Documents/" + filename,
            "/storage/emulated/0/Android/data/ai.annadata.llamacpp/files/" + filename,
            "/storage/emulated/0/Android/data/ai.annadata.llamacpp/files/Documents/" + filename,
            "/storage/emulated/0/Documents/" + filename,
            "/storage/emulated/0/Download/" + filename
        };

        // Results of the probe; filled in by the first candidate that validates.
        std::string full_model_path;
        std::streamsize file_size = 0;
        uint32_t version = 0;
        bool file_found = false;

        for (const auto& path : paths_to_check) {
            LOGI("Checking path for model info: %s", path.c_str());
            std::ifstream candidate(path, std::ios::binary);
            if (!candidate.good()) {
                LOGI("File not found at: %s", path.c_str());
                continue;
            }

            candidate.seekg(0, std::ios::end);
            const std::streamsize candidate_size = candidate.tellg();
            candidate.seekg(0, std::ios::beg);

            // Validate file size (tellg() failure yields -1 and is rejected here too)
            if (candidate_size < 1024 * 1024) {  // Less than 1MB
                LOGE("Model file is too small, likely corrupted: %s", path.c_str());
                continue;  // Try next path
            }

            // Check if it's a valid GGUF file by reading the magic number
            char magic[4] = {};
            if (!candidate.read(magic, sizeof(magic))) {
                LOGE("Failed to read magic number from: %s", path.c_str());
                continue;
            }
            if (magic[0] != 'G' || magic[1] != 'G' || magic[2] != 'U' || magic[3] != 'F') {
                // Logged as hex: the bytes are arbitrary and a NUL would cut a %c-formatted line short.
                LOGI("File does not appear to be a GGUF file (magic: %02x%02x%02x%02x) at: %s",
                     static_cast<unsigned>(static_cast<unsigned char>(magic[0])),
                     static_cast<unsigned>(static_cast<unsigned char>(magic[1])),
                     static_cast<unsigned>(static_cast<unsigned char>(magic[2])),
                     static_cast<unsigned>(static_cast<unsigned char>(magic[3])),
                     path.c_str());
                continue;
            }

            // Read GGUF version (the 32-bit field directly after the magic)
            uint32_t candidate_version = 0;
            if (!candidate.read(reinterpret_cast<char*>(&candidate_version), sizeof(candidate_version))) {
                LOGE("Failed to read GGUF version from: %s", path.c_str());
                throw_java_exception(env, "java/lang/RuntimeException", "Failed to read GGUF version");
                return nullptr;
            }

            LOGI("Valid GGUF file detected for model info at: %s", path.c_str());
            full_model_path = path;
            file_size = candidate_size;
            version = candidate_version;
            file_found = true;
            break;
        }

        if (!file_found) {
            LOGE("Model file not found in any of the checked paths");
            throw_java_exception(env, "java/lang/RuntimeException", "Model file not found");
            return nullptr;
        }

        // Resolve the Java classes once; FindClass leaves an exception pending on failure.
        jclass hashMapClass = env->FindClass("java/util/HashMap");
        jclass longClass = env->FindClass("java/lang/Long");
        jclass integerClass = env->FindClass("java/lang/Integer");
        if (hashMapClass == nullptr || longClass == nullptr || integerClass == nullptr) {
            return nullptr;
        }

        jmethodID hashMapConstructor = env->GetMethodID(hashMapClass, "<init>", "()V");
        jmethodID putMethod = env->GetMethodID(hashMapClass, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
        jmethodID longConstructor = env->GetMethodID(longClass, "<init>", "(J)V");
        jmethodID integerConstructor = env->GetMethodID(integerClass, "<init>", "(I)V");
        if (hashMapConstructor == nullptr || putMethod == nullptr ||
            longConstructor == nullptr || integerConstructor == nullptr) {
            return nullptr;
        }

        // Create Java HashMap
        jobject hashMap = env->NewObject(hashMapClass, hashMapConstructor);
        if (hashMap == nullptr) {
            return nullptr;
        }

        // Add model info to HashMap
        env->CallObjectMethod(hashMap, putMethod,
            string_to_jstring(env, "path"),
            string_to_jstring(env, full_model_path));

        env->CallObjectMethod(hashMap, putMethod,
            string_to_jstring(env, "size"),
            env->NewObject(longClass, longConstructor, static_cast<jlong>(file_size)));

        env->CallObjectMethod(hashMap, putMethod,
            string_to_jstring(env, "desc"),
            string_to_jstring(env, "GGUF Model (v" + std::to_string(version) + ")"));

        env->CallObjectMethod(hashMap, putMethod,
            string_to_jstring(env, "nEmbd"),
            env->NewObject(integerClass, integerConstructor, 0));  // Will be filled by actual model loading

        env->CallObjectMethod(hashMap, putMethod,
            string_to_jstring(env, "nParams"),
            env->NewObject(integerClass, integerConstructor, 0));  // Will be filled by actual model loading

        // Explicit casts keep the format specifiers correct on every ABI.
        LOGI("Model info retrieved successfully from %s: size=%lld, version=%u",
             full_model_path.c_str(), static_cast<long long>(file_size), static_cast<unsigned>(version));
        return hashMap;

    } catch (const std::exception& e) {
        LOGE("Exception in modelInfo: %s", e.what());
        throw_java_exception(env, "java/lang/RuntimeException", e.what());
        return nullptr;
    } catch (...) {
        // A C++ exception must never propagate into the JVM.
        LOGE("Unknown exception in modelInfo");
        throw_java_exception(env, "java/lang/RuntimeException", "Unknown error in modelInfo");
        return nullptr;
    }
}
|
|
992
|
+
|
|
993
|
+
|
|
476
994
|
|
|
995
|
+
/**
 * JNI entry point that prepares the on-device destination for a model
 * download. No data is transferred here: the Models directory is created if
 * necessary and the destination path is returned.
 *
 * Only the final path component of `filename` is used, so a name containing
 * "../" cannot point outside the Models directory.
 *
 * @param env      JNI environment of the calling thread.
 * @param thiz     Calling Java object (unused).
 * @param url      Source URL; converted but not otherwise used by this function.
 * @param filename Name the downloaded model should be stored under.
 * @return The destination path as a Java String, or nullptr with a
 *         java.lang.RuntimeException pending on failure.
 */
JNIEXPORT jstring JNICALL
Java_ai_annadata_plugin_capacitor_LlamaCpp_downloadModelNative(
    JNIEnv* env, jobject thiz, jstring url, jstring filename) {

    try {
        // NOTE(review): the URL is not used natively; the conversion is kept in case
        // jstring_to_string is relied on to reject a bad argument — confirm.
        [[maybe_unused]] const std::string url_str = jstring_to_string(env, url);
        const std::string filename_str = jstring_to_string(env, filename);

        LOGI("Preparing download path for model: %s", filename_str.c_str());

        // Keep only the last path component so the caller cannot escape the Models directory.
        const size_t last_slash = filename_str.find_last_of('/');
        const std::string safe_name =
            (last_slash == std::string::npos) ? filename_str : filename_str.substr(last_slash + 1);
        if (safe_name.empty() || safe_name == "." || safe_name == "..") {
            LOGE("Invalid model filename: %s", filename_str.c_str());
            throw_java_exception(env, "java/lang/RuntimeException", "Invalid model filename");
            return nullptr;
        }

        // Determine local storage path (use external storage for large files)
        const std::string dir_path = "/storage/emulated/0/Android/data/ai.annadata.llamacpp/files/Models/";
        const std::string local_path = dir_path + safe_name;

        // Create directory if it doesn't exist
        std::filesystem::create_directories(dir_path);

        LOGI("Download path prepared: %s", local_path.c_str());

        return string_to_jstring(env, local_path);

    } catch (const std::exception& e) {
        LOGE("Exception in downloadModel: %s", e.what());
        throw_java_exception(env, "java/lang/RuntimeException", e.what());
        return nullptr;
    } catch (...) {
        // A C++ exception must never propagate into the JVM.
        LOGE("Unknown exception in downloadModel");
        throw_java_exception(env, "java/lang/RuntimeException", "Unknown error in downloadModel");
        return nullptr;
    }
}
|
|
1022
|
+
|
|
1023
|
+
/**
 * JNI entry point reporting download progress.
 *
 * Downloads are handled on the Java side for now, so this returns a fixed
 * placeholder snapshot and does not look at `url`.
 *
 * @param env  JNI environment of the calling thread.
 * @param thiz Calling Java object (unused).
 * @param url  URL of the download being queried (currently ignored).
 * @return A java.util.HashMap with "progress" (Double 0.0), "completed"
 *         (Boolean false) and "failed" (Boolean false); nullptr with a
 *         java.lang.RuntimeException pending if a C++ exception is caught.
 */
JNIEXPORT jobject JNICALL
Java_ai_annadata_plugin_capacitor_LlamaCpp_getDownloadProgressNative(
    JNIEnv* env, jobject thiz, jstring url) {

    try {
        // Container for the placeholder snapshot.
        jclass mapCls = env->FindClass("java/util/HashMap");
        jmethodID mapCtor = env->GetMethodID(mapCls, "<init>", "()V");
        jmethodID mapPut = env->GetMethodID(mapCls, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
        jobject snapshot = env->NewObject(mapCls, mapCtor);

        // "progress" -> 0.0
        jclass doubleCls = env->FindClass("java/lang/Double");
        jmethodID doubleCtor = env->GetMethodID(doubleCls, "<init>", "(D)V");
        jobject progressValue = env->NewObject(doubleCls, doubleCtor, 0.0);
        env->CallObjectMethod(snapshot, mapPut, string_to_jstring(env, "progress"), progressValue);

        // "completed" and "failed" -> false, each boxed in its own Boolean.
        jclass booleanCls = env->FindClass("java/lang/Boolean");
        jmethodID booleanCtor = env->GetMethodID(booleanCls, "<init>", "(Z)V");

        jobject completedValue = env->NewObject(booleanCls, booleanCtor, false);
        env->CallObjectMethod(snapshot, mapPut, string_to_jstring(env, "completed"), completedValue);

        jobject failedValue = env->NewObject(booleanCls, booleanCtor, false);
        env->CallObjectMethod(snapshot, mapPut, string_to_jstring(env, "failed"), failedValue);

        return snapshot;

    } catch (const std::exception& ex) {
        LOGE("Exception in getDownloadProgress: %s", ex.what());
        throw_java_exception(env, "java/lang/RuntimeException", ex.what());
        return nullptr;
    }
}
|
|
1064
|
+
|
|
1065
|
+
/**
 * JNI entry point for cancelling a download.
 *
 * Cancellation is handled on the Java side for now, so this always answers
 * JNI_FALSE and does not look at `url`.
 *
 * @param env  JNI environment of the calling thread.
 * @param thiz Calling Java object (unused).
 * @param url  URL of the download to cancel (currently ignored).
 * @return Always JNI_FALSE.
 */
JNIEXPORT jboolean JNICALL
Java_ai_annadata_plugin_capacitor_LlamaCpp_cancelDownloadNative(
    JNIEnv* env, jobject thiz, jstring url) {

    try {
        // Placeholder result until native code can actually cancel a download.
        const jboolean cancelled = JNI_FALSE;
        return cancelled;

    } catch (const std::exception& ex) {
        LOGE("Exception in cancelDownload: %s", ex.what());
        throw_java_exception(env, "java/lang/RuntimeException", ex.what());
        return JNI_FALSE;
    }
}
|
|
1080
|
+
|
|
1081
|
+
/**
 * JNI entry point listing the ".gguf" files in the Models directory.
 *
 * Each regular file with a ".gguf" extension becomes one java.util.HashMap
 * with "name" (String), "path" (String) and "size" (Long). Per-file JNI local
 * references are released after every entry so a large directory cannot
 * exhaust the local reference table.
 *
 * @param env  JNI environment of the calling thread.
 * @param thiz Calling Java object (unused).
 * @return A java.util.ArrayList of per-model maps (empty when the directory
 *         does not exist), or nullptr with a Java exception pending on failure.
 */
JNIEXPORT jobject JNICALL
Java_ai_annadata_plugin_capacitor_LlamaCpp_getAvailableModelsNative(
    JNIEnv* env, jobject thiz) {

    // Scoped JNI local frame: every local reference created while it is alive
    // is released when it goes out of scope, including on a C++ exception.
    struct LocalFrame {
        JNIEnv* env;
        bool pushed;
        LocalFrame(JNIEnv* e, jint capacity) : env(e), pushed(e->PushLocalFrame(capacity) == 0) {}
        ~LocalFrame() { if (pushed) env->PopLocalFrame(nullptr); }
        LocalFrame(const LocalFrame&) = delete;
        LocalFrame& operator=(const LocalFrame&) = delete;
    };

    try {
        const std::string models_dir = "/storage/emulated/0/Android/data/ai.annadata.llamacpp/files/Models/";

        // Resolve classes and methods once instead of once per file.
        jclass arrayListClass = env->FindClass("java/util/ArrayList");
        jclass hashMapClass = env->FindClass("java/util/HashMap");
        jclass longClass = env->FindClass("java/lang/Long");
        if (arrayListClass == nullptr || hashMapClass == nullptr || longClass == nullptr) {
            return nullptr;  // FindClass left a Java exception pending
        }

        jmethodID arrayListConstructor = env->GetMethodID(arrayListClass, "<init>", "()V");
        jmethodID addMethod = env->GetMethodID(arrayListClass, "add", "(Ljava/lang/Object;)Z");
        jmethodID hashMapConstructor = env->GetMethodID(hashMapClass, "<init>", "()V");
        jmethodID putMethod = env->GetMethodID(hashMapClass, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
        jmethodID longConstructor = env->GetMethodID(longClass, "<init>", "(J)V");
        if (arrayListConstructor == nullptr || addMethod == nullptr || hashMapConstructor == nullptr ||
            putMethod == nullptr || longConstructor == nullptr) {
            return nullptr;  // GetMethodID left a Java exception pending
        }

        // Create Java ArrayList
        jobject arrayList = env->NewObject(arrayListClass, arrayListConstructor);
        if (arrayList == nullptr) {
            return nullptr;
        }

        // is_directory (not exists) so a stray regular file at this path yields an empty list.
        if (std::filesystem::is_directory(models_dir)) {
            for (const auto& entry : std::filesystem::directory_iterator(models_dir)) {
                if (!entry.is_regular_file() || entry.path().extension() != ".gguf") {
                    continue;
                }

                const std::string filename = entry.path().filename().string();
                const std::string full_path = entry.path().string();
                const auto file_size = entry.file_size();  // std::uintmax_t: no truncation on 32-bit ABIs

                LocalFrame frame(env, 16);
                if (!frame.pushed) {
                    return nullptr;  // OutOfMemoryError is pending
                }

                // Create model info HashMap
                jobject modelInfo = env->NewObject(hashMapClass, hashMapConstructor);

                env->CallObjectMethod(modelInfo, putMethod,
                    string_to_jstring(env, "name"),
                    string_to_jstring(env, filename));

                env->CallObjectMethod(modelInfo, putMethod,
                    string_to_jstring(env, "path"),
                    string_to_jstring(env, full_path));

                env->CallObjectMethod(modelInfo, putMethod,
                    string_to_jstring(env, "size"),
                    env->NewObject(longClass, longConstructor, static_cast<jlong>(file_size)));

                // Add to ArrayList (the list lives outside the frame and keeps the map reachable)
                env->CallBooleanMethod(arrayList, addMethod, modelInfo);
            }
        }

        return arrayList;

    } catch (const std::exception& e) {
        LOGE("Exception in getAvailableModels: %s", e.what());
        throw_java_exception(env, "java/lang/RuntimeException", e.what());
        return nullptr;
    } catch (...) {
        // A C++ exception must never propagate into the JVM.
        LOGE("Unknown exception in getAvailableModels");
        throw_java_exception(env, "java/lang/RuntimeException", "Unknown error in getAvailableModels");
        return nullptr;
    }
}
|
|
1137
|
+
|
|
1138
|
+
// MARK: - Tokenization methods
|
|
1139
|
+
|
|
1140
|
+
/**
 * JNI entry point that tokenizes text with the tokenizer of an existing
 * context.
 *
 * @param env        JNI environment of the calling thread.
 * @param thiz       Calling Java object (unused).
 * @param contextId  Key of the context in the global `contexts` map.
 * @param text       Text to tokenize.
 * @param imagePaths Image paths supplied by the caller (currently ignored).
 * @return A java.util.HashMap with "tokens" (ArrayList of Integer),
 *         "has_images" (Boolean, always false here) and the empty lists
 *         "bitmap_hashes", "chunk_pos" and "chunk_pos_images"; nullptr with a
 *         java.lang.RuntimeException pending when the context is missing or
 *         invalid, or when an error occurs.
 */
JNIEXPORT jobject JNICALL
Java_ai_annadata_plugin_capacitor_LlamaCpp_tokenizeNative(
    JNIEnv* env, jobject thiz, jlong contextId, jstring text, jobjectArray imagePaths) {

    try {
        // jlong is 64-bit on every ABI; cast so the format specifier always matches.
        LOGI("Tokenizing with context ID: %lld", static_cast<long long>(contextId));

        const std::string text_str = jni_utils::jstring_to_string(env, text);
        LOGI("Text to tokenize: %s", text_str.c_str());

        // Find the context
        auto it = contexts.find(contextId);
        if (it == contexts.end()) {
            LOGE("Context not found: %lld", static_cast<long long>(contextId));
            throw_java_exception(env, "java/lang/RuntimeException", "Context not found");
            return nullptr;
        }

        auto& ctx = it->second;
        if (!ctx || !ctx->ctx) {
            LOGE("Invalid context or llama context is null");
            throw_java_exception(env, "java/lang/RuntimeException", "Invalid context");
            return nullptr;
        }

        // Tokenize the text using the context's tokenize method.
        // NOTE(review): imagePaths is not forwarded, so the image-related fields
        // below are always empty — confirm whether that is intended.
        const capllama::llama_cap_tokenize_result tokenize_result = ctx->tokenize(text_str, {});
        const std::vector<llama_token>& tokens = tokenize_result.tokens;  // bind, do not copy

        LOGI("Tokenized %zu tokens", tokens.size());

        // Create Java HashMap for result
        jclass hashMapClass = env->FindClass("java/util/HashMap");
        jmethodID hashMapConstructor = env->GetMethodID(hashMapClass, "<init>", "()V");
        jmethodID putMethod = env->GetMethodID(hashMapClass, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");

        jobject resultMap = env->NewObject(hashMapClass, hashMapConstructor);

        // Create Java ArrayList for tokens
        jclass arrayListClass = env->FindClass("java/util/ArrayList");
        jmethodID arrayListConstructor = env->GetMethodID(arrayListClass, "<init>", "()V");
        jmethodID addMethod = env->GetMethodID(arrayListClass, "add", "(Ljava/lang/Object;)Z");

        jobject tokensArray = env->NewObject(arrayListClass, arrayListConstructor);

        // Add tokens to ArrayList
        jclass integerClass = env->FindClass("java/lang/Integer");
        jmethodID integerConstructor = env->GetMethodID(integerClass, "<init>", "(I)V");

        for (const llama_token token : tokens) {
            jobject jToken = env->NewObject(integerClass, integerConstructor, static_cast<jint>(token));
            env->CallBooleanMethod(tokensArray, addMethod, jToken);
            // Release per token so long inputs do not exhaust the local reference table.
            env->DeleteLocalRef(jToken);
        }

        // Create empty arrays for other fields
        jobject emptyBitmapHashes = env->NewObject(arrayListClass, arrayListConstructor);
        jobject emptyChunkPos = env->NewObject(arrayListClass, arrayListConstructor);
        jobject emptyChunkPosImages = env->NewObject(arrayListClass, arrayListConstructor);

        jclass booleanClass = env->FindClass("java/lang/Boolean");
        jmethodID booleanConstructor = env->GetMethodID(booleanClass, "<init>", "(Z)V");

        // Put all data into result map
        env->CallObjectMethod(resultMap, putMethod,
            jni_utils::string_to_jstring(env, "tokens"), tokensArray);
        env->CallObjectMethod(resultMap, putMethod,
            jni_utils::string_to_jstring(env, "has_images"),
            env->NewObject(booleanClass, booleanConstructor, JNI_FALSE));
        env->CallObjectMethod(resultMap, putMethod,
            jni_utils::string_to_jstring(env, "bitmap_hashes"), emptyBitmapHashes);
        env->CallObjectMethod(resultMap, putMethod,
            jni_utils::string_to_jstring(env, "chunk_pos"), emptyChunkPos);
        env->CallObjectMethod(resultMap, putMethod,
            jni_utils::string_to_jstring(env, "chunk_pos_images"), emptyChunkPosImages);

        LOGI("Tokenization completed successfully");
        return resultMap;

    } catch (const std::exception& e) {
        LOGE("Exception in tokenize: %s", e.what());
        throw_java_exception(env, "java/lang/RuntimeException", e.what());
        return nullptr;
    } catch (...) {
        // A C++ exception must never propagate into the JVM.
        LOGE("Unknown exception in tokenize");
        throw_java_exception(env, "java/lang/RuntimeException", "Unknown error in tokenize");
        return nullptr;
    }
}
|
|
1223
|
+
|
|
1224
|
+
/**
 * JNI entry point that converts token ids back to text using an existing
 * context.
 *
 * @param env       JNI environment of the calling thread.
 * @param thiz      Calling Java object (unused).
 * @param contextId Key of the context in the global `contexts` map.
 * @param tokens    Token ids to convert; must not be null.
 * @return The detokenized text as a Java String, or nullptr with a Java
 *         exception pending when the context is missing or invalid, `tokens`
 *         is null, or an error occurs.
 */
JNIEXPORT jstring JNICALL
Java_ai_annadata_plugin_capacitor_LlamaCpp_detokenizeNative(
    JNIEnv* env, jobject thiz, jlong contextId, jintArray tokens) {

    try {
        // jlong is 64-bit on every ABI; cast so the format specifier always matches.
        LOGI("Detokenizing with context ID: %lld", static_cast<long long>(contextId));

        // Find the context
        auto it = contexts.find(contextId);
        if (it == contexts.end()) {
            LOGE("Context not found: %lld", static_cast<long long>(contextId));
            throw_java_exception(env, "java/lang/RuntimeException", "Context not found");
            return nullptr;
        }

        auto& ctx = it->second;
        if (!ctx || !ctx->ctx) {
            LOGE("Invalid context or llama context is null");
            throw_java_exception(env, "java/lang/RuntimeException", "Invalid context");
            return nullptr;
        }

        // Passing a null array to the JNI array functions is an error, so reject it up front.
        if (tokens == nullptr) {
            LOGE("Tokens array is null");
            throw_java_exception(env, "java/lang/RuntimeException", "Tokens array is null");
            return nullptr;
        }

        // Copy the Java int array into a C++ buffer. GetIntArrayRegion fills our own
        // storage, so there is nothing to release and nothing to leak on an exception.
        const jsize length = env->GetArrayLength(tokens);
        std::vector<jint> raw(static_cast<size_t>(length));
        if (length > 0) {
            env->GetIntArrayRegion(tokens, 0, length, raw.data());
            if (env->ExceptionCheck()) {
                return nullptr;  // Java exception already pending
            }
        }

        std::vector<llama_token> llamaTokens(raw.begin(), raw.end());

        // Detokenize using llama.cpp
        const std::string result = capllama::tokens_to_str(ctx->ctx, llamaTokens.begin(), llamaTokens.end());

        LOGI("Detokenized to: %s", result.c_str());

        return jni_utils::string_to_jstring(env, result);

    } catch (const std::exception& e) {
        LOGE("Exception in detokenize: %s", e.what());
        throw_java_exception(env, "java/lang/RuntimeException", e.what());
        return nullptr;
    } catch (...) {
        // A C++ exception must never propagate into the JVM.
        LOGE("Unknown exception in detokenize");
        throw_java_exception(env, "java/lang/RuntimeException", "Unknown error in detokenize");
        return nullptr;
    }
}
|
|
477
1270
|
|
|
478
1271
|
} // extern "C"
|
|
479
1272
|
|