@agorapete/wllama 3.5.1-q2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.gitmodules +3 -0
- package/.prettierignore +38 -0
- package/AGENTS.md +1 -0
- package/CMakeLists.txt +131 -0
- package/LICENCE +21 -0
- package/README-dev.md +178 -0
- package/README.md +225 -0
- package/README_banner.png +0 -0
- package/assets/screenshot_0.png +0 -0
- package/cpp/generate_glue_prototype.js +115 -0
- package/cpp/glue.hpp +664 -0
- package/cpp/test_glue.cpp +80 -0
- package/cpp/wllama-context.h +1172 -0
- package/cpp/wllama-fs.h +148 -0
- package/cpp/wllama.cpp +187 -0
- package/cpp/wllama.h +6 -0
- package/esm/cache-manager.d.ts +130 -0
- package/esm/debug.d.ts +28 -0
- package/esm/glue/glue.d.ts +22 -0
- package/esm/glue/messages.d.ts +146 -0
- package/esm/huggingface.d.ts +31 -0
- package/esm/index.cjs +3406 -0
- package/esm/index.d.ts +8 -0
- package/esm/index.js +3387 -0
- package/esm/index.min.js +1 -0
- package/esm/index.min.js.map +1 -0
- package/esm/model-manager.d.ts +136 -0
- package/esm/storage/cos.d.ts +36 -0
- package/esm/storage/index.d.ts +33 -0
- package/esm/storage/opfs.d.ts +12 -0
- package/esm/types/oai-compat.d.ts +278 -0
- package/esm/types/types.d.ts +112 -0
- package/esm/utils.d.ts +119 -0
- package/esm/wasm/source-map.d.ts +1 -0
- package/esm/wasm/wllama.wasm +0 -0
- package/esm/wasm-from-cdn.d.ts +8 -0
- package/esm/wllama.d.ts +397 -0
- package/esm/worker.d.ts +92 -0
- package/esm/workers-code/generated.d.ts +4 -0
- package/guides/intro-v2.md +132 -0
- package/guides/intro-v3.1.md +40 -0
- package/guides/intro-v3.md +230 -0
- package/index.ts +1 -0
- package/package.json +71 -0
- package/scripts/bisect_test.sh +33 -0
- package/scripts/build_hf_space.sh +26 -0
- package/scripts/build_source_map.js +269 -0
- package/scripts/build_wasm.sh +19 -0
- package/scripts/build_worker.sh +38 -0
- package/scripts/check_debug_build.js +30 -0
- package/scripts/check_package_size.js +25 -0
- package/scripts/docker-compose.yml +76 -0
- package/scripts/generate_wasm_from_cdn.js +24 -0
- package/scripts/http_server.js +44 -0
- package/scripts/post_build.sh +32 -0
- package/src/cache-manager.ts +358 -0
- package/src/debug.ts +111 -0
- package/src/glue/glue.ts +291 -0
- package/src/glue/messages.ts +773 -0
- package/src/huggingface.ts +151 -0
- package/src/index.ts +8 -0
- package/src/mjs.test.ts +44 -0
- package/src/model-manager.test.ts +200 -0
- package/src/model-manager.ts +359 -0
- package/src/storage/cos.test.ts +83 -0
- package/src/storage/cos.ts +171 -0
- package/src/storage/index.ts +40 -0
- package/src/storage/opfs.ts +119 -0
- package/src/types/oai-compat.ts +342 -0
- package/src/types/types.ts +133 -0
- package/src/utils.test.ts +231 -0
- package/src/utils.ts +403 -0
- package/src/wasm/source-map.ts +7 -0
- package/src/wasm/wllama.js +1 -0
- package/src/wasm/wllama.wasm +0 -0
- package/src/wasm-from-cdn.ts +13 -0
- package/src/wllama.test.ts +392 -0
- package/src/wllama.ts +1138 -0
- package/src/wllama.wgpu.test.ts +62 -0
- package/src/worker.ts +443 -0
- package/src/workers-code/generated.ts +11 -0
- package/src/workers-code/llama-cpp.js +511 -0
- package/src/workers-code/opfs-utils.js +150 -0
- package/tsconfig.build.json +34 -0
- package/tsup.config.ts +23 -0
- package/vitest.config.ts +61 -0
|
@@ -0,0 +1,1172 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include <iostream>
|
|
4
|
+
#include <vector>
|
|
5
|
+
#include <string>
|
|
6
|
+
#include <sstream>
|
|
7
|
+
#include <stdio.h>
|
|
8
|
+
#include <cmath>
|
|
9
|
+
#include <fstream>
|
|
10
|
+
|
|
11
|
+
#include "llama.h"
|
|
12
|
+
#include "common.h"
|
|
13
|
+
#include "sampling.h"
|
|
14
|
+
#include "chat.h"
|
|
15
|
+
#include "fit.h"
|
|
16
|
+
#include "log.h"
|
|
17
|
+
#include "download.h"
|
|
18
|
+
#include "wllama.h"
|
|
19
|
+
|
|
20
|
+
#include "server-context.h"
|
|
21
|
+
#include "server-queue.h"
|
|
22
|
+
|
|
23
|
+
#include "ggml-cpu.h"
|
|
24
|
+
#include "ggml-backend.h"
|
|
25
|
+
|
|
26
|
+
#include "glue.hpp"
|
|
27
|
+
|
|
28
|
+
#ifdef WLLAMA_TEST_BACKEND
// Real implementation is compiled in from the test-backend-ops sources.
int main_test_backend_ops(int argc, char **argv);
#else
// Stub used when test-backend-ops is not part of the build: writes an
// "@@ERROR@@"-prefixed message to stderr and returns the sentinel -1000.
int main_test_backend_ops(int /* argc */, char ** /* argv */)
{
    static const char not_enabled_msg[] =
        "@@ERROR@@test-backend-ops is not enabled, please refer to README-dev.md for how to build it\n";
    fputs(not_enabled_msg, stderr);
    return -1000;
}
#endif
|
|
37
|
+
|
|
38
|
+
// Declares a local `req` of type `msg_typename` and deserializes it from the raw
// request bytes. Requires a variable named `req_raw` to be in scope at the
// expansion site; introduces the names `req` and `inbuf` into that scope.
#define PARSE_REQ(msg_typename) \
    msg_typename req; \
    glue_inbuf inbuf(req_raw); \
    req.handler.deserialize(inbuf);
|
|
42
|
+
|
|
43
|
+
// Debug-only switch used by internal wllama tests. It is selected through the
// special pooling_type strings "test_stack_trace_abort" / "test_stack_trace_oob"
// (see pooling_type_from_str); NONE means normal operation.
enum TEST_STACK_TRACE
{
    TEST_STACK_TRACE_NONE = 0,  // default
    TEST_STACK_TRACE_ABORT = 1, // requested via "test_stack_trace_abort"
    TEST_STACK_TRACE_OOB = 2    // requested via "test_stack_trace_oob"
};

// Current debug mode; starts disabled.
static TEST_STACK_TRACE test_stack_trace{TEST_STACK_TRACE_NONE};
|
|
51
|
+
extern "C" void __real_abort(void);
|
|
52
|
+
extern "C" void __wrap_abort(void)
|
|
53
|
+
{
|
|
54
|
+
char buf[4096];
|
|
55
|
+
emscripten_get_callstack(EM_LOG_JS_STACK | EM_LOG_NO_PATHS, buf, sizeof(buf));
|
|
56
|
+
for (size_t i = 0; i < sizeof(buf); i++)
|
|
57
|
+
{
|
|
58
|
+
if (buf[i] == '\n')
|
|
59
|
+
buf[i] = '|';
|
|
60
|
+
}
|
|
61
|
+
fprintf(stderr, "@@STACK@@%s\n", buf);
|
|
62
|
+
fflush(stderr);
|
|
63
|
+
__real_abort();
|
|
64
|
+
}
|
|
65
|
+
|
|
66
|
+
static bool force_single_thread = false;
|
|
67
|
+
extern "C" struct ggml_cplan __real_ggml_graph_plan(
|
|
68
|
+
const struct ggml_cgraph *cgraph,
|
|
69
|
+
int n_threads,
|
|
70
|
+
struct ggml_threadpool *threadpool);
|
|
71
|
+
|
|
72
|
+
extern "C" struct ggml_cplan __wrap_ggml_graph_plan(
|
|
73
|
+
const struct ggml_cgraph *cgraph,
|
|
74
|
+
int n_threads,
|
|
75
|
+
struct ggml_threadpool *threadpool)
|
|
76
|
+
{
|
|
77
|
+
return __real_ggml_graph_plan(cgraph, force_single_thread ? 1 : n_threads, threadpool);
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
// Copies the bytes of `input` into a new std::vector<char>. No NUL terminator is
// appended, and embedded NUL bytes are preserved.
// The parameter is a const reference because the string is only read; this also
// lets callers pass const strings and temporaries.
inline std::vector<char> convert_string_to_buf(const std::string &input)
{
    return std::vector<char>(input.begin(), input.end());
}
|
|
87
|
+
|
|
88
|
+
// Maps a KV-cache type name to the corresponding ggml_type.
// Accepted names: "f32", "f16", "q8_0", "q4_0", "q4_1", "q5_0", "q5_1".
// Throws std::runtime_error("Invalid cache type: <name>") for anything else.
inline static ggml_type kv_cache_type_from_str(const std::string &s)
{
    static const std::pair<const char *, ggml_type> known_types[] = {
        {"f32", GGML_TYPE_F32},
        {"f16", GGML_TYPE_F16},
        {"q8_0", GGML_TYPE_Q8_0},
        {"q4_0", GGML_TYPE_Q4_0},
        {"q4_1", GGML_TYPE_Q4_1},
        {"q5_0", GGML_TYPE_Q5_0},
        {"q5_1", GGML_TYPE_Q5_1},
    };
    for (const auto &[name, type] : known_types)
    {
        if (s == name)
            return type;
    }
    throw std::runtime_error("Invalid cache type: " + s);
}
|
|
106
|
+
|
|
107
|
+
inline static enum llama_pooling_type pooling_type_from_str(const std::string &s)
|
|
108
|
+
{
|
|
109
|
+
// legacy values
|
|
110
|
+
if (s == "LLAMA_POOLING_TYPE_UNSPECIFIED")
|
|
111
|
+
return LLAMA_POOLING_TYPE_UNSPECIFIED;
|
|
112
|
+
if (s == "LLAMA_POOLING_TYPE_NONE")
|
|
113
|
+
return LLAMA_POOLING_TYPE_NONE;
|
|
114
|
+
if (s == "LLAMA_POOLING_TYPE_MEAN")
|
|
115
|
+
return LLAMA_POOLING_TYPE_MEAN;
|
|
116
|
+
if (s == "LLAMA_POOLING_TYPE_CLS")
|
|
117
|
+
return LLAMA_POOLING_TYPE_CLS;
|
|
118
|
+
// new values
|
|
119
|
+
if (s == "unspecified")
|
|
120
|
+
return LLAMA_POOLING_TYPE_UNSPECIFIED;
|
|
121
|
+
if (s == "none")
|
|
122
|
+
return LLAMA_POOLING_TYPE_NONE;
|
|
123
|
+
if (s == "mean")
|
|
124
|
+
return LLAMA_POOLING_TYPE_MEAN;
|
|
125
|
+
if (s == "cls")
|
|
126
|
+
return LLAMA_POOLING_TYPE_CLS;
|
|
127
|
+
if (s == "last")
|
|
128
|
+
return LLAMA_POOLING_TYPE_LAST;
|
|
129
|
+
if (s == "rank")
|
|
130
|
+
return LLAMA_POOLING_TYPE_RANK;
|
|
131
|
+
// for internal wllama testing
|
|
132
|
+
if (s == "test_stack_trace_abort")
|
|
133
|
+
{
|
|
134
|
+
test_stack_trace = TEST_STACK_TRACE_ABORT;
|
|
135
|
+
return LLAMA_POOLING_TYPE_NONE;
|
|
136
|
+
}
|
|
137
|
+
if (s == "test_stack_trace_oob")
|
|
138
|
+
{
|
|
139
|
+
test_stack_trace = TEST_STACK_TRACE_OOB;
|
|
140
|
+
return LLAMA_POOLING_TYPE_NONE;
|
|
141
|
+
}
|
|
142
|
+
throw std::runtime_error("Invalid pooling type: " + s);
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
// Maps a RoPE scaling-type name to llama_rope_scaling_type.
// Accepts the legacy enum-style names ("LLAMA_ROPE_SCALING_TYPE_*") and, to match
// pooling_type_from_str and llama.cpp's --rope-scaling option, the lowercase names
// "unspecified", "none", "linear" and "yarn".
// Throws std::runtime_error("Invalid RoPE scaling type: <name>") for anything else.
inline static llama_rope_scaling_type rope_scaling_type_from_str(const std::string &s)
{
    // legacy values
    if (s == "LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED")
        return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
    if (s == "LLAMA_ROPE_SCALING_TYPE_NONE")
        return LLAMA_ROPE_SCALING_TYPE_NONE;
    if (s == "LLAMA_ROPE_SCALING_TYPE_LINEAR")
        return LLAMA_ROPE_SCALING_TYPE_LINEAR;
    if (s == "LLAMA_ROPE_SCALING_TYPE_YARN")
        return LLAMA_ROPE_SCALING_TYPE_YARN;
    // new values (lowercase, consistent with the other *_from_str helpers)
    if (s == "unspecified")
        return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
    if (s == "none")
        return LLAMA_ROPE_SCALING_TYPE_NONE;
    if (s == "linear")
        return LLAMA_ROPE_SCALING_TYPE_LINEAR;
    if (s == "yarn")
        return LLAMA_ROPE_SCALING_TYPE_YARN;
    throw std::runtime_error("Invalid RoPE scaling type: " + s);
}
|
|
157
|
+
|
|
158
|
+
// Maps a reasoning-format name ("none", "deepseek-legacy", "deepseek") to
// common_reasoning_format.
// Throws std::runtime_error("Invalid reasoning format: <name>") for anything else.
inline static common_reasoning_format reasoning_format_from_str(const std::string &s)
{
    static const std::pair<const char *, common_reasoning_format> formats[] = {
        {"none", COMMON_REASONING_FORMAT_NONE},
        {"deepseek-legacy", COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY},
        {"deepseek", COMMON_REASONING_FORMAT_DEEPSEEK},
    };
    for (const auto &[name, format] : formats)
    {
        if (s == name)
            return format;
    }
    throw std::runtime_error("Invalid reasoning format: " + s);
}
|
|
168
|
+
|
|
169
|
+
// Builds a llama_model_kv_override from a metadata key and a "TYPE:value" string.
// TYPE is one of:
//   int   - value parsed with std::stoll
//   float - value parsed with std::stod
//   bool  - "true"/"1" or "false"/"0"
//   str   - value copied verbatim
// Throws std::runtime_error when:
//   - the key or a string value does not fit the fixed-size buffers of
//     llama_model_kv_override (silent truncation would produce an override that
//     never matches, or a corrupted value);
//   - the ':' separator is missing or the type is unknown;
//   - a numeric value is malformed, out of range, or has trailing characters.
inline static llama_model_kv_override parse_kv_override(const std::string &key, const std::string &val_str)
{
    llama_model_kv_override kvo{};

    if (key.size() >= sizeof(kvo.key))
        throw std::runtime_error("kv_override key is too long: " + key);
    key.copy(kvo.key, key.size());
    kvo.key[key.size()] = '\0';

    const auto colon = val_str.find(':');
    if (colon == std::string::npos)
        throw std::runtime_error("Invalid kv_override value, expected TYPE:value: " + val_str);
    const std::string type_str = val_str.substr(0, colon);
    const std::string value_str = val_str.substr(colon + 1);

    if (type_str == "int")
    {
        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
        size_t consumed = 0;
        try
        {
            kvo.val_i64 = std::stoll(value_str, &consumed);
        }
        catch (const std::logic_error &) // std::invalid_argument / std::out_of_range
        {
            consumed = std::string::npos;
        }
        if (consumed != value_str.size())
            throw std::runtime_error("Invalid int value for kv_override: " + value_str);
    }
    else if (type_str == "float")
    {
        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
        size_t consumed = 0;
        try
        {
            kvo.val_f64 = std::stod(value_str, &consumed);
        }
        catch (const std::logic_error &) // std::invalid_argument / std::out_of_range
        {
            consumed = std::string::npos;
        }
        if (consumed != value_str.size())
            throw std::runtime_error("Invalid float value for kv_override: " + value_str);
    }
    else if (type_str == "bool")
    {
        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
        if (value_str == "true" || value_str == "1")
        {
            kvo.val_bool = true;
        }
        else if (value_str == "false" || value_str == "0")
        {
            kvo.val_bool = false;
        }
        else
        {
            throw std::runtime_error("Invalid bool value for kv_override: " + value_str);
        }
    }
    else if (type_str == "str")
    {
        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
        if (value_str.size() >= sizeof(kvo.val_str))
            throw std::runtime_error("kv_override string value is too long: " + value_str);
        value_str.copy(kvo.val_str, value_str.size());
        kvo.val_str[value_str.size()] = '\0';
    }
    else
    {
        throw std::runtime_error("Invalid kv_override type: " + type_str);
    }
    return kvo;
}
|
|
217
|
+
|
|
218
|
+
// Exception raised by wllama for invalid requests (e.g. missing fields or
// mismatched array lengths). what() returns the message given at construction.
// The constructor is not marked non-throwing because copying the message may
// allocate. what() and the destructor never throw.
class app_exception : public std::exception
{
public:
    explicit app_exception(const std::string &msg) : message(msg) {}
    ~app_exception() override = default;
    const char *what() const noexcept override { return message.c_str(); }

private:
    std::string message;
};
|
|
228
|
+
|
|
229
|
+
// Key/value listing kept as two parallel arrays: keys[i] belongs to vals[i].
// Filled by wllama_context::dump_metadata() with the model's metadata as strings.
struct kv_dump
{
    std::vector<std::string> keys{}; // entry names
    std::vector<std::string> vals{}; // entry values, same order as keys
};
|
|
234
|
+
|
|
235
|
+
//////////////////////////////////////////
|
|
236
|
+
//////////////////////////////////////////
|
|
237
|
+
//////////////////////////////////////////
|
|
238
|
+
|
|
239
|
+
// Display modes accepted by the console stub inside wllama_context, whose
// set_display() ignores its argument.
enum display_type
{
    DISPLAY_TYPE_RESET = 0,
    DISPLAY_TYPE_INFO = 1,
    DISPLAY_TYPE_PROMPT = 2,
    DISPLAY_TYPE_REASONING = 3,
    DISPLAY_TYPE_USER_INPUT = 4,
    DISPLAY_TYPE_ERROR = 5
};
|
|
248
|
+
|
|
249
|
+
static bool has_more_tasks = false;
|
|
250
|
+
static ggml_log_level log_level = GGML_LOG_LEVEL_INFO;
|
|
251
|
+
|
|
252
|
+
struct wllama_context
|
|
253
|
+
{
|
|
254
|
+
// llama.cpp server engine behind every action: the model is loaded with
// load_model(), one loop iteration runs per start_loop() call (see run_loop),
// and tasks/results go through readers from get_response_reader().
server_context ctx_server;
// Handles taken from ctx_server after a successful action_load; nullptr before.
// NOTE(review): presumably owned by ctx_server, not by this struct.
llama_context *ctx = nullptr;
const llama_model *model = nullptr;
const llama_vocab *vocab = nullptr;
// Load parameters, filled in by action_load and passed to ctx_server.load_model().
common_params params;

// Stop predicate handed to rd->next(); the default never asks to stop.
std::function<bool()> should_stop = []()
{ return false; };
// Cleared at the start of each completion / embedding / rerank action.
std::string last_error;
// using unique_ptr to allow late initialization
// rd: response reader for the current request, recreated by each action.
std::unique_ptr<server_response_reader> rd;
// meta: snapshot of ctx_server.get_meta(), taken once the model has loaded.
std::unique_ptr<const server_context_meta> meta;
|
|
266
|
+
|
|
267
|
+
struct console
|
|
268
|
+
{
|
|
269
|
+
struct spinner
|
|
270
|
+
{
|
|
271
|
+
static void start() {}
|
|
272
|
+
static void stop() {}
|
|
273
|
+
} spinner;
|
|
274
|
+
static void set_display(display_type display) {}
|
|
275
|
+
static void flush() {}
|
|
276
|
+
} console;
|
|
277
|
+
|
|
278
|
+
// Default constructor; members are set up by their default member initializers.
explicit wllama_context()
{
}
|
|
279
|
+
|
|
280
|
+
void create_completion_task(std::string &req_raw, std::vector<raw_buffer> &files, bool is_chat)
|
|
281
|
+
{
|
|
282
|
+
json body = json::parse(req_raw);
|
|
283
|
+
task_response_type res_type = TASK_RESPONSE_TYPE_OAI_CMPL;
|
|
284
|
+
|
|
285
|
+
if (is_chat)
|
|
286
|
+
{
|
|
287
|
+
std::vector<raw_buffer> dummy_files; // unused
|
|
288
|
+
json body_parsed = oaicompat_chat_params_parse(
|
|
289
|
+
body,
|
|
290
|
+
meta->chat_params,
|
|
291
|
+
dummy_files);
|
|
292
|
+
body = std::move(body_parsed);
|
|
293
|
+
res_type = TASK_RESPONSE_TYPE_OAI_CHAT;
|
|
294
|
+
}
|
|
295
|
+
|
|
296
|
+
{
|
|
297
|
+
const auto &prompt = body.at("prompt");
|
|
298
|
+
|
|
299
|
+
// TODO: reduce some copies here in the future
|
|
300
|
+
server_task task = server_task(SERVER_TASK_TYPE_COMPLETION);
|
|
301
|
+
task.id = rd->get_new_id();
|
|
302
|
+
task.index = 0;
|
|
303
|
+
task.params = server_task::params_from_json_cmpl(
|
|
304
|
+
vocab,
|
|
305
|
+
params,
|
|
306
|
+
meta->slot_n_ctx,
|
|
307
|
+
meta->logit_bias_eog,
|
|
308
|
+
body);
|
|
309
|
+
task.params.res_type = res_type;
|
|
310
|
+
task.cli_prompt = prompt;
|
|
311
|
+
task.cli_files = files;
|
|
312
|
+
task.cli = true;
|
|
313
|
+
|
|
314
|
+
rd->post_task({std::move(task)});
|
|
315
|
+
}
|
|
316
|
+
}
|
|
317
|
+
|
|
318
|
+
std::pair<server_task_result_ptr, bool> get_next_result()
|
|
319
|
+
{
|
|
320
|
+
server_task_result_ptr result = rd->next(should_stop);
|
|
321
|
+
if (result)
|
|
322
|
+
{
|
|
323
|
+
const bool is_error = result->is_error();
|
|
324
|
+
return {std::move(result), is_error};
|
|
325
|
+
}
|
|
326
|
+
else
|
|
327
|
+
{
|
|
328
|
+
return {nullptr, false};
|
|
329
|
+
}
|
|
330
|
+
}
|
|
331
|
+
|
|
332
|
+
// Reads every metadata entry of the loaded model as a (key, value) string pair.
// An entry is skipped entirely if either its value or its key cannot be read
// (the getter returns a negative value).
kv_dump dump_metadata()
{
    kv_dump output;
    std::vector<char> buf(1024);

    // Calls one of the llama_model_meta_*_by_index getters for entry `index`,
    // growing `buf` when the output did not fit. The getters return the full string
    // length excluding the NUL terminator, so a return value >= buf.size() means
    // the text was truncated and the call must be repeated with length + 1 bytes.
    // Returns false if the getter reports an error.
    auto read_entry = [&](auto getter, int index, std::string &out) -> bool
    {
        int len = getter(model, index, buf.data(), buf.size());
        if (len < 0)
            return false;
        if (static_cast<size_t>(len) >= buf.size())
        {
            buf.resize(static_cast<size_t>(len) + 1);
            len = getter(model, index, buf.data(), buf.size());
            if (len < 0)
                return false;
        }
        out.assign(buf.data(), static_cast<size_t>(len));
        return true;
    };

    const int count = llama_model_meta_count(model);
    for (int i = 0; i < count; i++)
    {
        std::string val;
        std::string key;
        if (!read_entry(llama_model_meta_val_str_by_index, i, val))
            continue;
        if (!read_entry(llama_model_meta_key_by_index, i, key))
            continue;
        output.keys.push_back(std::move(key));
        output.vals.push_back(std::move(val));
    }
    return output;
}
|
|
365
|
+
|
|
366
|
+
// returns true if there are more tasks in the queue after this one
|
|
367
|
+
int run_loop()
|
|
368
|
+
{
|
|
369
|
+
ctx_server.start_loop(); // only run one iteration of the generation loop (i.e. generating one token)
|
|
370
|
+
return has_more_tasks;
|
|
371
|
+
}
|
|
372
|
+
|
|
373
|
+
//////////////////////////////////////////
|
|
374
|
+
//////////////////////////////////////////
|
|
375
|
+
//////////////////////////////////////////
|
|
376
|
+
|
|
377
|
+
// Loads a model as described by a glue_msg_load_req.
//
// Every non-null request field is copied into `params`. The llama backend is then
// initialised and the model is loaded through ctx_server.
// - On success: ctx/model/vocab/meta are cached and the response carries model,
//   vocab and multimodal information.
// - If ctx_server.load_model() fails: a response with success=false is returned.
//
// Throws:
// - app_exception when no model path is given, or when paired arrays (template
//   kwargs, lora paths/scales, kv override keys/vals) differ in length;
// - std::runtime_error from the *_from_str / parse_kv_override helpers on
//   unknown names or malformed overrides.
glue_msg_load_res action_load(const char *req_raw)
{
    PARSE_REQ(glue_msg_load_req);

    llama_log_set(common_log_default_callback, nullptr);

    assert(ctx == nullptr);
    std::vector<std::string> &model_paths = req.model_paths.arr;
    // NOTE(review): req.n_ctx_auto is received but not applied here; n_ctx is
    // taken verbatim from req.n_ctx below.

    // assert() is compiled out in release builds, so guard the [0] access explicitly.
    if (model_paths.empty())
        throw app_exception("model_paths must contain at least one path");
    params.model.path = model_paths[0];

    if (req.log_level.not_null())
        log_level = static_cast<ggml_log_level>(req.log_level.value);

    // mmproj params
    if (req.mmproj_path.not_null())
        params.mmproj.path = req.mmproj_path.value;
    if (req.image_min_tokens.not_null())
        params.image_min_tokens = req.image_min_tokens.value;
    if (req.image_max_tokens.not_null())
        params.image_max_tokens = req.image_max_tokens.value;

    // model params
    if (req.use_mmap.not_null())
        params.use_mmap = req.use_mmap.value;
    if (req.use_mlock.not_null())
        params.use_mlock = req.use_mlock.value;
    if (req.n_gpu_layers.not_null())
        params.n_gpu_layers = req.n_gpu_layers.value;
    if (req.model_alias.not_null())
        params.model_alias.insert(req.model_alias.value);

    params.n_ctx = req.n_ctx.value;
    params.cpuparams.n_threads = req.n_threads.value;
    params.cpuparams_batch.n_threads = req.n_threads.value;
    if (req.embeddings.not_null())
        params.embedding = req.embeddings.value;
    if (req.n_batch.not_null())
        params.n_batch = req.n_batch.value;
    if (req.n_parallel.not_null())
        params.n_parallel = req.n_parallel.value;
    if (req.pooling_type.not_null())
        params.pooling_type = pooling_type_from_str(req.pooling_type.value);
    // context extending: https://github.com/ggerganov/llama.cpp/pull/2054
    if (req.rope_scaling_type.not_null())
        params.rope_scaling_type = rope_scaling_type_from_str(req.rope_scaling_type.value);
    if (req.rope_freq_base.not_null())
        params.rope_freq_base = req.rope_freq_base.value;
    if (req.rope_freq_scale.not_null())
        params.rope_freq_scale = req.rope_freq_scale.value;
    if (req.yarn_ext_factor.not_null())
        params.yarn_ext_factor = req.yarn_ext_factor.value;
    if (req.yarn_attn_factor.not_null())
        params.yarn_attn_factor = req.yarn_attn_factor.value;
    if (req.yarn_beta_fast.not_null())
        params.yarn_beta_fast = req.yarn_beta_fast.value;
    if (req.yarn_beta_slow.not_null())
        params.yarn_beta_slow = req.yarn_beta_slow.value;
    if (req.yarn_orig_ctx.not_null())
        params.yarn_orig_ctx = req.yarn_orig_ctx.value;
    if (req.warmup.not_null())
        params.warmup = req.warmup.value;

    // optimizations
    if (req.cache_type_k.not_null())
        params.cache_type_k = kv_cache_type_from_str(req.cache_type_k.value);
    if (req.cache_type_v.not_null())
        params.cache_type_v = kv_cache_type_from_str(req.cache_type_v.value);
    if (req.flash_attn.not_null())
        params.flash_attn_type = req.flash_attn.value ? LLAMA_FLASH_ATTN_TYPE_AUTO : LLAMA_FLASH_ATTN_TYPE_DISABLED;
    if (req.swa_full.not_null())
        params.swa_full = req.swa_full.value;
    if (req.n_ctx_checkpoints.not_null())
        params.n_ctx_checkpoints = req.n_ctx_checkpoints.value;
    if (req.checkpoint_min_step.not_null())
        params.checkpoint_min_step = req.checkpoint_min_step.value;

    // template params
    if (req.chat_template.not_null())
        params.chat_template = req.chat_template.value;
    if (req.jinja.not_null())
        params.use_jinja = req.jinja.value;
    if (req.reasoning.not_null())
    {
        if (req.reasoning.value)
        {
            params.enable_reasoning = 1;
            params.default_template_kwargs["enable_thinking"] = "true";
        }
        else
        {
            params.enable_reasoning = 0;
            params.default_template_kwargs["enable_thinking"] = "false";
        }
    }
    if (req.default_template_kwargs_keys.not_null() && req.default_template_kwargs_vals.not_null())
    {
        auto &keys = req.default_template_kwargs_keys.arr;
        auto &vals = req.default_template_kwargs_vals.arr;
        if (keys.size() != vals.size())
        {
            throw app_exception("default_template_kwargs_keys and default_template_kwargs_vals must have the same length");
        }
        for (size_t i = 0; i < keys.size(); i++)
        {
            params.default_template_kwargs[keys[i]] = vals[i];
        }
    }

    // GPU
    if (req.no_kv_offload.not_null())
        params.no_kv_offload = req.no_kv_offload.value;
    if (req.mmproj_offload.not_null())
        params.mmproj_use_gpu = req.mmproj_offload.value;

    // batch / context scheduling
    if (req.cont_batching.not_null())
        params.cont_batching = req.cont_batching.value;
    if (req.n_keep.not_null())
        params.n_keep = req.n_keep.value;
    if (req.ctx_shift.not_null())
        params.ctx_shift = req.ctx_shift.value;
    if (req.cache_idle_slots.not_null())
        params.cache_idle_slots = req.cache_idle_slots.value;
    if (req.n_cache_reuse.not_null())
        params.n_cache_reuse = req.n_cache_reuse.value;

    // lora (rebuilt from scratch so a retried load does not duplicate adapters)
    params.lora_adapters.clear();
    if (req.lora_paths.not_null())
    {
        const auto &paths = req.lora_paths.arr;
        const auto &scales = req.lora_scales.arr;
        if (!scales.empty() && scales.size() != paths.size())
            throw app_exception("lora_paths and lora_scales must have the same length");
        for (size_t i = 0; i < paths.size(); i++)
        {
            common_adapter_lora_info info;
            info.path = paths[i];
            info.scale = scales.empty() ? 1.0f : scales[i];
            params.lora_adapters.push_back(std::move(info));
        }
    }
    if (req.lora_init_without_apply.not_null())
        params.lora_init_without_apply = req.lora_init_without_apply.value;

    // speculative decoding
    if (req.spec_draft_model.not_null())
        params.speculative.draft.mparams.path = req.spec_draft_model.value;
    if (req.spec_draft_ngl.not_null())
        params.speculative.draft.n_gpu_layers = req.spec_draft_ngl.value;
    if (req.spec_draft_n_max.not_null())
        params.speculative.draft.n_max = req.spec_draft_n_max.value;
    if (req.spec_draft_n_min.not_null())
        params.speculative.draft.n_min = req.spec_draft_n_min.value;
    if (req.spec_draft_p_min.not_null())
        params.speculative.draft.p_min = req.spec_draft_p_min.value;
    if (req.spec_draft_threads.not_null())
        params.speculative.draft.cpuparams.n_threads = req.spec_draft_threads.value;
    if (req.spec_draft_threads_batch.not_null())
        params.speculative.draft.cpuparams_batch.n_threads = req.spec_draft_threads_batch.value;

    // kv overrides (rebuilt from scratch so a retried load does not accumulate entries)
    params.kv_overrides.clear();
    if (req.kv_overrides_keys.not_null() && req.kv_overrides_vals.not_null())
    {
        const auto &keys = req.kv_overrides_keys.arr;
        const auto &vals = req.kv_overrides_vals.arr;
        if (keys.size() != vals.size())
            throw app_exception("kv_overrides_keys and kv_overrides_vals must have the same length");
        for (size_t i = 0; i < keys.size(); i++)
            params.kv_overrides.push_back(parse_kv_override(keys[i], vals[i]));
    }
    // llama.cpp reads the override array until it reaches an entry with an empty
    // key, and common_model_params_to_llama() asserts that a non-empty list ends
    // with such a terminator.
    if (!params.kv_overrides.empty())
    {
        params.kv_overrides.emplace_back();
        params.kv_overrides.back().key[0] = '\0';
    }

    // reasoning
    if (req.reasoning_budget_tokens.not_null())
        params.sampling.reasoning_budget_tokens = req.reasoning_budget_tokens.value;
    if (req.reasoning_budget_message.not_null())
        params.sampling.reasoning_budget_message = req.reasoning_budget_message.value;
    if (req.reasoning_format.not_null())
        params.reasoning_format = reasoning_format_from_str(req.reasoning_format.value);

    // other
    if (req.skip_chat_parsing.not_null())
        params.force_pure_content_parser = req.skip_chat_parsing.value;
    if (req.prefill_assistant.not_null())
        params.prefill_assistant = req.prefill_assistant.value;

    // load model
    llama_backend_init();
    llama_numa_init(params.numa);
    if (!ctx_server.load_model(params))
    {
        glue_msg_load_res res;
        res.success.value = false;
        return res;
    }

    LOG_INF("%s", "Model loaded successfully\n");

    ctx = ctx_server.get_llama_context();
    model = llama_get_model(ctx);
    vocab = llama_model_get_vocab(model);
    meta = std::make_unique<server_context_meta>(ctx_server.get_meta());
    auto metadata = dump_metadata();

    // get EOG tokens
    const auto n_vocab = llama_vocab_n_tokens(vocab);
    std::vector<llama_token> list_tokens_eog;
    for (llama_token i = 0; i < n_vocab; i++)
    {
        if (llama_vocab_is_eog(vocab, i))
            list_tokens_eog.push_back(i);
    }

    glue_msg_load_res res;
    res.success.value = true;
    res.n_ctx.value = params.n_ctx;
    res.n_batch.value = llama_n_batch(ctx);
    res.n_ubatch.value = llama_n_ubatch(ctx);
    res.n_vocab.value = n_vocab;
    res.n_ctx_train.value = llama_model_n_ctx_train(model);
    res.n_embd.value = llama_model_n_embd(model);
    res.n_layer.value = llama_model_n_layer(model);
    res.metadata_key.arr = std::move(metadata.keys);
    res.metadata_val.arr = std::move(metadata.vals);
    res.token_bos.value = llama_vocab_bos(vocab);
    res.token_eos.value = llama_vocab_eos(vocab);
    res.token_eot.value = llama_vocab_eot(vocab);
    res.list_tokens_eog.arr = std::move(list_tokens_eog);
    res.add_bos_token.value = llama_vocab_get_add_bos(vocab) == 1;
    res.add_eos_token.value = llama_vocab_get_add_eos(vocab) == 1;
    res.has_encoder.value = llama_model_has_encoder(model);
    res.token_decoder_start.value = llama_model_decoder_start_token(model);
    res.media_marker.value = get_media_marker();
    // multimodal capabilities, read from the server metadata snapshot stored in `meta`
    res.has_image_input.value = meta->chat_params.allow_image;
    res.has_audio_input.value = meta->chat_params.allow_audio;
    return res;
}
|
|
626
|
+
|
|
627
|
+
// Starts a (chat) completion request.
// Resets the response reader and last_error, collects the attached media
// buffers, then posts one completion task. Results are fetched separately
// (see action_get_result).
glue_msg_completion_res action_completion(const char *req_raw)
{
    PARSE_REQ(glue_msg_completion_req);
    glue_msg_completion_res res;

    // fresh response reader for this request
    rd = std::make_unique<server_response_reader>(ctx_server.get_response_reader());
    last_error.clear();

    std::vector<raw_buffer> input_files;
    input_files.reserve(req.files.arr.size());
    for (const auto &attached : req.files.arr)
        input_files.push_back(attached);

    // build the task and put it on the queue
    create_completion_task(req.data_json.value, input_files, req.is_chat.value);

    res.success.value = true;
    return res;
}
|
|
647
|
+
|
|
648
|
+
void create_embedding_tasks(std::string &req_raw)
|
|
649
|
+
{
|
|
650
|
+
json body = json::parse(req_raw);
|
|
651
|
+
|
|
652
|
+
json prompt;
|
|
653
|
+
if (body.count("input") != 0)
|
|
654
|
+
{
|
|
655
|
+
prompt = body.at("input");
|
|
656
|
+
}
|
|
657
|
+
else if (body.contains("content"))
|
|
658
|
+
{
|
|
659
|
+
prompt = body.at("content");
|
|
660
|
+
}
|
|
661
|
+
else
|
|
662
|
+
{
|
|
663
|
+
throw app_exception("\"input\" or \"content\" must be provided");
|
|
664
|
+
}
|
|
665
|
+
|
|
666
|
+
int embd_normalize = 2;
|
|
667
|
+
if (body.count("embd_normalize") != 0)
|
|
668
|
+
{
|
|
669
|
+
embd_normalize = body.at("embd_normalize");
|
|
670
|
+
}
|
|
671
|
+
|
|
672
|
+
auto tokenized_prompts = tokenize_input_prompts(vocab, nullptr, prompt, true, true);
|
|
673
|
+
for (const auto &tokens : tokenized_prompts)
|
|
674
|
+
{
|
|
675
|
+
if (tokens.empty())
|
|
676
|
+
{
|
|
677
|
+
throw app_exception("Input content cannot be empty");
|
|
678
|
+
}
|
|
679
|
+
}
|
|
680
|
+
|
|
681
|
+
std::vector<server_task> tasks;
|
|
682
|
+
for (size_t i = 0; i < tokenized_prompts.size(); i++)
|
|
683
|
+
{
|
|
684
|
+
server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
|
|
685
|
+
task.id = rd->get_new_id();
|
|
686
|
+
task.tokens = std::move(tokenized_prompts[i]);
|
|
687
|
+
task.params.res_type = TASK_RESPONSE_TYPE_OAI_EMBD;
|
|
688
|
+
task.params.embd_normalize = embd_normalize;
|
|
689
|
+
tasks.push_back(std::move(task));
|
|
690
|
+
}
|
|
691
|
+
rd->post_tasks(std::move(tasks));
|
|
692
|
+
}
|
|
693
|
+
|
|
694
|
+
// Starts an embeddings request.
// Resets the response reader and last_error, then posts one embedding task per
// input prompt. Results are fetched separately (see action_get_result).
glue_msg_embedding_res action_embedding(const char *req_raw)
{
    PARSE_REQ(glue_msg_embedding_req);
    glue_msg_embedding_res res;

    // fresh response reader for this request
    rd = std::make_unique<server_response_reader>(ctx_server.get_response_reader());
    last_error.clear();

    create_embedding_tasks(req.data_json.value);

    res.success.value = true;
    return res;
}
|
|
707
|
+
|
|
708
|
+
// Starts a rerank request for one (query, document) pair.
// The JSON body must contain string fields "query" and "document"; otherwise an
// app_exception is thrown before any state is touched. On success the response
// reader and last_error are reset, and a single rerank task is posted.
glue_msg_rerank_res action_rerank(const char *req_raw)
{
    PARSE_REQ(glue_msg_rerank_req);
    glue_msg_rerank_res res;

    json body = json::parse(req.data_json.value);
    auto has_string_field = [&body](const char *field)
    {
        return body.contains(field) && body.at(field).is_string();
    };
    if (!has_string_field("query"))
        throw app_exception("\"query\" must be a string");
    if (!has_string_field("document"))
        throw app_exception("\"document\" must be a string");

    std::string query = body.at("query");
    std::string document = body.at("document");

    rd = std::make_unique<server_response_reader>(ctx_server.get_response_reader());
    last_error.clear();

    auto rerank_tokens = format_prompt_rerank(model, vocab, nullptr, query, document);
    server_task task(SERVER_TASK_TYPE_RERANK);
    task.id = rd->get_new_id();
    task.index = 0;
    task.tokens = std::move(rerank_tokens);
    rd->post_task(std::move(task));

    res.success.value = true;
    return res;
}
|
|
738
|
+
|
|
739
|
+
// Run one iteration of the task loop, then take the next available result and
// serialize it to JSON:
//  - embedding results are wrapped in an OpenAI-compatible embeddings response,
//  - rerank results become {"score", "tokens_evaluated"},
//  - anything else is forwarded via its own to_json().
// data_json is left empty when no result is available.
glue_msg_get_result_res action_get_result(const char *req_raw)
{
    PARSE_REQ(glue_msg_get_result_req);
    glue_msg_get_result_res res;

    const bool has_more = run_loop();
    auto [result, is_error] = get_next_result();

    // serialized payload; stays "" when there is no result yet
    std::string payload;
    if (result)
    {
        json data_json;
        if (dynamic_cast<server_task_result_embd *>(result.get()) != nullptr)
        {
            // special handling for embeddings OAI-compat
            json body = {{"model", meta->model_name}};
            json responses = json::array();
            responses.push_back(result->to_json());
            // TODO: support base64 output
            data_json = format_embeddings_response_oaicompat(body, meta->model_name, responses, false);
        }
        else if (const auto *rerank = dynamic_cast<server_task_result_rerank *>(result.get()))
        {
            data_json = json{
                {"score", rerank->score},
                {"tokens_evaluated", rerank->n_tokens},
            };
        }
        else
        {
            // completion (or any other) result: use its own JSON form
            data_json = result->to_json();
        }
        payload = data_json.dump();
    }

    res.success.value = true;
    res.has_more.value = has_more;
    res.data_json.value = payload;
    res.is_error.value = is_error;
    return res;
}
|
|
780
|
+
|
|
781
|
+
// Run ggml's test-backend-ops entry point with the request's argument list.
// Logging is forced to DEBUG and execution to a single thread for the duration
// of the run; both settings are put back to their previous values afterwards,
// including when the test throws.
// Returns retcode as reported by main_test_backend_ops; success is retcode == 0.
glue_msg_test_backend_ops_res action_test_backend_ops(const char *req_raw)
{
    PARSE_REQ(glue_msg_test_backend_ops_req);
    glue_msg_test_backend_ops_res res;

    auto &args = req.args.arr;

    // C-style argv pointing into the request's strings (owned by `req`, which
    // outlives the call). By the C `main` convention argv[argc] is a null
    // pointer, so append a terminator that is not counted in argc.
    std::vector<char *> argv;
    argv.reserve(args.size() + 1);
    for (auto &s : args)
    {
        argv.push_back(const_cast<char *>(s.c_str()));
    }
    argv.push_back(nullptr);
    const int argc = static_cast<int>(argv.size() - 1);

    // Save the current settings so they can be restored exactly (the previous
    // code reset force_single_thread to a hard-coded false).
    const auto saved_log_level = log_level;
    const bool saved_single_thread = force_single_thread;
    const auto restore_settings = [&]()
    {
        log_level = saved_log_level;
        force_single_thread = saved_single_thread;
    };

    log_level = GGML_LOG_LEVEL_DEBUG;
    force_single_thread = true;

    int retcode = 0;
    try
    {
        retcode = main_test_backend_ops(argc, argv.data());
    }
    catch (...)
    {
        // do not leave DEBUG logging / single-thread mode stuck on
        restore_settings();
        throw;
    }
    restore_settings();

    res.retcode.value = retcode;
    res.success.value = retcode == 0;

    return res;
}
|
|
807
|
+
};
|
|
808
|
+
|
|
809
|
+
////////////////////////////
|
|
810
|
+
// server_queue
|
|
811
|
+
|
|
812
|
+
int server_queue::get_new_id()
|
|
813
|
+
{
|
|
814
|
+
return id++;
|
|
815
|
+
}
|
|
816
|
+
|
|
817
|
+
// Enqueue a single task that already has an id (asserted). A cancel task
// first purges any still-queued task it targets. `front` puts the task at
// the head of the queue instead of the tail. Returns the task's id.
int server_queue::post(server_task &&task, bool front)
{
    GGML_ASSERT(task.id != -1);

    if (task.type == SERVER_TASK_TYPE_CANCEL)
    {
        cleanup_pending_task(task.id_target);
    }

    // capture the id now: `task` is moved into the queue below
    const int posted_id = task.id;
    if (!front)
    {
        queue_tasks.push_back(std::move(task));
    }
    else
    {
        queue_tasks.push_front(std::move(task));
    }

    time_last_task = ggml_time_ms();
    return posted_id;
}
|
|
837
|
+
|
|
838
|
+
int server_queue::post(std::vector<server_task> &&tasks, bool front)
|
|
839
|
+
{
|
|
840
|
+
for (auto &task : tasks)
|
|
841
|
+
{
|
|
842
|
+
if (task.id == -1)
|
|
843
|
+
{
|
|
844
|
+
task.id = id++;
|
|
845
|
+
}
|
|
846
|
+
// if this is cancel task make sure to clean up pending tasks
|
|
847
|
+
if (task.type == SERVER_TASK_TYPE_CANCEL)
|
|
848
|
+
{
|
|
849
|
+
cleanup_pending_task(task.id_target);
|
|
850
|
+
}
|
|
851
|
+
if (front)
|
|
852
|
+
{
|
|
853
|
+
queue_tasks.push_front(std::move(task));
|
|
854
|
+
}
|
|
855
|
+
else
|
|
856
|
+
{
|
|
857
|
+
queue_tasks.push_back(std::move(task));
|
|
858
|
+
}
|
|
859
|
+
}
|
|
860
|
+
time_last_task = ggml_time_ms();
|
|
861
|
+
return 0;
|
|
862
|
+
}
|
|
863
|
+
|
|
864
|
+
void server_queue::cleanup_pending_task(int id_target)
|
|
865
|
+
{
|
|
866
|
+
auto rm_func = [id_target](const server_task &task)
|
|
867
|
+
{
|
|
868
|
+
return task.id == id_target;
|
|
869
|
+
};
|
|
870
|
+
queue_tasks.erase(
|
|
871
|
+
std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
|
|
872
|
+
queue_tasks.end());
|
|
873
|
+
queue_tasks_deferred.erase(
|
|
874
|
+
std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
|
|
875
|
+
queue_tasks_deferred.end());
|
|
876
|
+
}
|
|
877
|
+
|
|
878
|
+
// Deferring is not supported in wllama; reaching this is a programming error.
// NOTE(review): assert() compiles out under NDEBUG, in which case the task is
// silently discarded here -- consider GGML_ASSERT if that must never happen.
void server_queue::defer(server_task &&task)
{
    (void)task;
    assert(false && "should not be called in wllama");
}
|
|
882
|
+
|
|
883
|
+
// Intentionally empty: wllama never defers tasks, so there is nothing to pop.
void server_queue::pop_deferred_task(int id_slot)
{
    (void)id_slot;
}
|
|
887
|
+
|
|
888
|
+
// Queue a result for recv(). Before doing so, two test hooks selected by
// `test_stack_trace` can crash the process on purpose (abort, or a write to
// an arbitrary address) so crash / stack-trace handling can be exercised.
void server_response::send(server_task_result_ptr &&result)
{
    if (test_stack_trace == TEST_STACK_TRACE_ABORT)
    {
        LOG_DBG("%s: force abort for testing\n", __func__);
        abort(); // does not return
    }
    if (test_stack_trace == TEST_STACK_TRACE_OOB)
    {
        LOG_DBG("%s: force out-of-bounds for testing\n", __func__);
        // deliberately write through a bogus pointer (address 1GB)
        *reinterpret_cast<int *>(0x40000000) = 0;
    }

    LOG_DBG("%s\n", __func__);
    queue_results.push_back(std::move(result));
}
|
|
904
|
+
|
|
905
|
+
// Intentionally a no-op in wllama: waiting task ids are not tracked.
void server_response::add_waiting_task_id(int id)
{
    (void)id;
}
|
|
909
|
+
|
|
910
|
+
// Intentionally a no-op in wllama: waiting task ids are not tracked.
void server_response::add_waiting_task_ids(const std::unordered_set<int> &id_tasks)
{
    (void)id_tasks;
}
|
|
914
|
+
|
|
915
|
+
// Intentionally a no-op in wllama: waiting task ids are not tracked.
void server_response::remove_waiting_task_ids(const std::unordered_set<int> &id_tasks)
{
    (void)id_tasks;
}
|
|
919
|
+
|
|
920
|
+
// Non-blocking receive: pops and returns the oldest queued result, or nullptr
// when the queue is empty. The requested id set is ignored -- whatever result
// is at the head of the queue is returned.
server_task_result_ptr server_response::recv(const std::unordered_set<int> &)
{
    if (queue_results.empty())
    {
        return nullptr;
    }
    server_task_result_ptr oldest = std::move(queue_results.front());
    queue_results.erase(queue_results.begin());
    return oldest;
}
|
|
930
|
+
|
|
931
|
+
void server_queue::start_loop(int64_t idle_sleep_ms)
|
|
932
|
+
{
|
|
933
|
+
while (true)
|
|
934
|
+
{
|
|
935
|
+
if (queue_tasks.empty())
|
|
936
|
+
{
|
|
937
|
+
break;
|
|
938
|
+
}
|
|
939
|
+
server_task task = std::move(queue_tasks.front());
|
|
940
|
+
queue_tasks.pop_front();
|
|
941
|
+
|
|
942
|
+
LOG_DBG("processing task, id = %d\n", task.id);
|
|
943
|
+
callback_new_task(std::move(task));
|
|
944
|
+
}
|
|
945
|
+
// all tasks in the current loop is processed, slots data is now ready
|
|
946
|
+
LOG_DBG("%s", "update slots\n");
|
|
947
|
+
|
|
948
|
+
// this will run the main inference process for all slots
|
|
949
|
+
callback_update_slots();
|
|
950
|
+
|
|
951
|
+
has_more_tasks = !queue_tasks.empty();
|
|
952
|
+
}
|
|
953
|
+
|
|
954
|
+
// Build identification string reported by this wllama build.
const char *llama_build_info()
{
    static const char k_build_info[] = "wllama";
    return k_build_info;
}
|
|
958
|
+
|
|
959
|
+
////////////////////////////
|
|
960
|
+
// server_response_reader
|
|
961
|
+
|
|
962
|
+
// Post the reader's single (non-parent) task. May only be called once per
// reader. The task always gets index 0 and a matching entry in `states`.
void server_response_reader::post_task(server_task &&task, bool front)
{
    LOG_DBG("%s\n", __func__);
    GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
    GGML_ASSERT(!task.is_parent() && "not supported, use post_tasks() instead");

    task.index = 0;
    const int posted_id = task.id;

    id_tasks.insert(posted_id);
    states.push_back(task.create_state());
    queue_results.add_waiting_task_id(posted_id);

    queue_tasks.post(std::move(task), front);
}
|
|
973
|
+
|
|
974
|
+
void server_response_reader::post_tasks(std::vector<server_task> &&tasks, bool front)
|
|
975
|
+
{
|
|
976
|
+
LOG_DBG("%s\n", __func__);
|
|
977
|
+
GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
|
|
978
|
+
id_tasks = server_task::get_list_id(tasks);
|
|
979
|
+
states.reserve(tasks.size());
|
|
980
|
+
size_t index = 0;
|
|
981
|
+
for (auto &task : tasks)
|
|
982
|
+
{
|
|
983
|
+
task.index = index++;
|
|
984
|
+
states.push_back(task.create_state());
|
|
985
|
+
// for child tasks
|
|
986
|
+
for (auto &child_task : task.child_tasks)
|
|
987
|
+
{
|
|
988
|
+
child_task.index = index++;
|
|
989
|
+
states.push_back(child_task.create_state());
|
|
990
|
+
}
|
|
991
|
+
}
|
|
992
|
+
GGML_ASSERT(states.size() == id_tasks.size());
|
|
993
|
+
queue_results.add_waiting_task_ids(id_tasks);
|
|
994
|
+
queue_tasks.post(std::move(tasks), front);
|
|
995
|
+
}
|
|
996
|
+
|
|
997
|
+
bool server_response_reader::has_next() const
|
|
998
|
+
{
|
|
999
|
+
return !cancelled && received_count < id_tasks.size();
|
|
1000
|
+
}
|
|
1001
|
+
|
|
1002
|
+
// return nullptr if should_stop() is true before receiving a result
|
|
1003
|
+
// note: if one error is received, it will stop further processing and return error result
|
|
1004
|
+
server_task_result_ptr server_response_reader::next(const std::function<bool()> &should_stop)
|
|
1005
|
+
{
|
|
1006
|
+
LOG_DBG("%s\n", __func__);
|
|
1007
|
+
auto result = queue_results.recv(id_tasks);
|
|
1008
|
+
if (result && !states.empty())
|
|
1009
|
+
{
|
|
1010
|
+
// update the generation state if needed
|
|
1011
|
+
LOG_DBG("%s: update result\n", __func__);
|
|
1012
|
+
const size_t idx = result->index;
|
|
1013
|
+
GGML_ASSERT(idx < states.size());
|
|
1014
|
+
result->update(states[idx]);
|
|
1015
|
+
}
|
|
1016
|
+
if (result && result->is_error())
|
|
1017
|
+
{
|
|
1018
|
+
LOG_DBG("%s: received error result, stop further processing\n", __func__);
|
|
1019
|
+
stop();
|
|
1020
|
+
}
|
|
1021
|
+
return result;
|
|
1022
|
+
}
|
|
1023
|
+
|
|
1024
|
+
void server_response_reader::stop()
|
|
1025
|
+
{
|
|
1026
|
+
queue_results.remove_waiting_task_ids(id_tasks);
|
|
1027
|
+
cancelled = true;
|
|
1028
|
+
std::vector<server_task> cancel_tasks;
|
|
1029
|
+
cancel_tasks.reserve(id_tasks.size());
|
|
1030
|
+
for (const auto &id_task : id_tasks)
|
|
1031
|
+
{
|
|
1032
|
+
LOG_DBG("cancel task, id_task = %d\n", id_task);
|
|
1033
|
+
server_task task(SERVER_TASK_TYPE_CANCEL);
|
|
1034
|
+
task.id_target = id_task;
|
|
1035
|
+
cancel_tasks.push_back(std::move(task));
|
|
1036
|
+
}
|
|
1037
|
+
// push to beginning of the queue, so it has highest priority
|
|
1038
|
+
queue_tasks.post(std::move(cancel_tasks), true);
|
|
1039
|
+
}
|
|
1040
|
+
|
|
1041
|
+
////////////////////////////
|
|
1042
|
+
// common_log
|
|
1043
|
+
|
|
1044
|
+
int common_log_get_verbosity_thold(void)
|
|
1045
|
+
{
|
|
1046
|
+
return log_level;
|
|
1047
|
+
}
|
|
1048
|
+
|
|
1049
|
+
void common_log_set_verbosity_thold(int verbosity)
|
|
1050
|
+
{
|
|
1051
|
+
log_level = static_cast<ggml_log_level>(verbosity);
|
|
1052
|
+
}
|
|
1053
|
+
|
|
1054
|
+
struct common_log
|
|
1055
|
+
{
|
|
1056
|
+
void add(enum ggml_log_level level, const char *fmt, va_list args)
|
|
1057
|
+
{
|
|
1058
|
+
static std::vector<char> msg;
|
|
1059
|
+
|
|
1060
|
+
const size_t n = vsnprintf(msg.data(), msg.size(), fmt, args);
|
|
1061
|
+
if (n >= msg.size())
|
|
1062
|
+
{
|
|
1063
|
+
msg.resize(n + 1);
|
|
1064
|
+
// cannot use args twice, so make a copy in case we need to expand the buffer
|
|
1065
|
+
va_list args_copy;
|
|
1066
|
+
va_copy(args_copy, args);
|
|
1067
|
+
vsnprintf(msg.data(), msg.size(), fmt, args_copy);
|
|
1068
|
+
}
|
|
1069
|
+
|
|
1070
|
+
const char *lvl = "@@DEBUG";
|
|
1071
|
+
const char *text = msg.data();
|
|
1072
|
+
size_t len = strlen(text);
|
|
1073
|
+
if (len == 0 || text[len - 1] != '\n')
|
|
1074
|
+
{
|
|
1075
|
+
// do not print if the line does not terminate with \n
|
|
1076
|
+
return;
|
|
1077
|
+
}
|
|
1078
|
+
if (level == GGML_LOG_LEVEL_ERROR)
|
|
1079
|
+
{
|
|
1080
|
+
lvl = "@@ERROR";
|
|
1081
|
+
}
|
|
1082
|
+
else if (level == GGML_LOG_LEVEL_WARN)
|
|
1083
|
+
{
|
|
1084
|
+
lvl = "@@WARN";
|
|
1085
|
+
}
|
|
1086
|
+
else if (level == GGML_LOG_LEVEL_INFO)
|
|
1087
|
+
{
|
|
1088
|
+
lvl = "@@INFO";
|
|
1089
|
+
}
|
|
1090
|
+
fprintf(stderr, "%s@@%s", lvl, text);
|
|
1091
|
+
}
|
|
1092
|
+
};
|
|
1093
|
+
|
|
1094
|
+
struct common_log *common_log_init()
|
|
1095
|
+
{
|
|
1096
|
+
return new common_log;
|
|
1097
|
+
}
|
|
1098
|
+
|
|
1099
|
+
struct common_log *common_log_main()
|
|
1100
|
+
{
|
|
1101
|
+
static struct common_log log;
|
|
1102
|
+
return &log;
|
|
1103
|
+
}
|
|
1104
|
+
|
|
1105
|
+
void common_log_add(struct common_log *log, enum ggml_log_level level, const char *fmt, ...)
|
|
1106
|
+
{
|
|
1107
|
+
va_list args;
|
|
1108
|
+
va_start(args, fmt);
|
|
1109
|
+
log->add(level, fmt, args);
|
|
1110
|
+
va_end(args);
|
|
1111
|
+
}
|
|
1112
|
+
|
|
1113
|
+
// Verbosity of a message is simply the numeric value of its ggml level.
static int common_get_verbosity(enum ggml_log_level level)
{
    const int as_int = static_cast<int>(level);
    return as_int;
}
|
|
1117
|
+
|
|
1118
|
+
void common_log_default_callback(enum ggml_log_level level, const char *text, void * /*user_data*/)
|
|
1119
|
+
{
|
|
1120
|
+
auto verbosity = common_get_verbosity(level);
|
|
1121
|
+
if (verbosity >= static_cast<int>(log_level))
|
|
1122
|
+
{
|
|
1123
|
+
common_log_add(common_log_main(), level, "%s", text);
|
|
1124
|
+
}
|
|
1125
|
+
}
|
|
1126
|
+
|
|
1127
|
+
enum common_params_fit_status common_fit_params(
|
|
1128
|
+
const char *path_model,
|
|
1129
|
+
llama_model_params *mparams,
|
|
1130
|
+
llama_context_params *cparams,
|
|
1131
|
+
float *tensor_split,
|
|
1132
|
+
llama_model_tensor_buft_override *tensor_buft_overrides,
|
|
1133
|
+
size_t *margins,
|
|
1134
|
+
uint32_t n_ctx_min,
|
|
1135
|
+
ggml_log_level log_level)
|
|
1136
|
+
{
|
|
1137
|
+
return COMMON_PARAMS_FIT_STATUS_FAILURE;
|
|
1138
|
+
}
|
|
1139
|
+
|
|
1140
|
+
std::pair<long, std::vector<char>> common_remote_get_content(const std::string &url,
|
|
1141
|
+
const common_remote_params ¶ms)
|
|
1142
|
+
{
|
|
1143
|
+
throw std::runtime_error("common_remote_get_content is not implemented in wllama");
|
|
1144
|
+
}
|
|
1145
|
+
|
|
1146
|
+
#if 1
|
|
1147
|
+
common_device_memory_data_vec common_get_device_memory_data(
|
|
1148
|
+
const char *path_model,
|
|
1149
|
+
const llama_model_params *mparams,
|
|
1150
|
+
const llama_context_params *cparams,
|
|
1151
|
+
std::vector<ggml_backend_dev_t> &devs,
|
|
1152
|
+
uint32_t &hp_ngl,
|
|
1153
|
+
uint32_t &hp_n_ctx_train,
|
|
1154
|
+
uint32_t &hp_n_expert,
|
|
1155
|
+
ggml_log_level log_level)
|
|
1156
|
+
{
|
|
1157
|
+
throw std::runtime_error("common_get_device_memory_data is not implemented in wllama");
|
|
1158
|
+
}
|
|
1159
|
+
#else
|
|
1160
|
+
std::vector<llama_device_memory_data> common_get_device_memory_data(
|
|
1161
|
+
const char *path_model,
|
|
1162
|
+
const struct llama_model_params *mparams,
|
|
1163
|
+
const struct llama_context_params *cparams,
|
|
1164
|
+
std::vector<ggml_backend_dev_t> &devs,
|
|
1165
|
+
uint32_t &hp_ngl,
|
|
1166
|
+
uint32_t &hp_n_ctx_train,
|
|
1167
|
+
uint32_t &hp_n_expert,
|
|
1168
|
+
enum ggml_log_level log_level)
|
|
1169
|
+
{
|
|
1170
|
+
throw std::runtime_error("common_get_device_memory_data is not implemented in wllama");
|
|
1171
|
+
}
|
|
1172
|
+
#endif
|