node-llama-cpp 2.7.5 → 3.0.0-beta.1

This diff shows the changes between two publicly released package versions, as their contents appear in their respective public registries. It is provided for informational purposes only.
Files changed (101)
  1. package/README.md +1 -1
  2. package/dist/chatWrappers/generateContextTextFromConversationHistory.d.ts +0 -8
  3. package/dist/chatWrappers/generateContextTextFromConversationHistory.js +0 -8
  4. package/dist/chatWrappers/generateContextTextFromConversationHistory.js.map +1 -1
  5. package/dist/chatWrappers/resolveChatWrapperBasedOnModel.d.ts +13 -0
  6. package/dist/chatWrappers/resolveChatWrapperBasedOnModel.js +49 -0
  7. package/dist/chatWrappers/resolveChatWrapperBasedOnModel.js.map +1 -0
  8. package/dist/cli/cli.js +1 -1
  9. package/dist/cli/cli.js.map +1 -1
  10. package/dist/cli/commands/BuildCommand.d.ts +2 -2
  11. package/dist/cli/commands/BuildCommand.js +1 -1
  12. package/dist/cli/commands/BuildCommand.js.map +1 -1
  13. package/dist/cli/commands/ChatCommand.js +20 -10
  14. package/dist/cli/commands/ChatCommand.js.map +1 -1
  15. package/dist/cli/commands/ClearCommand.js +2 -1
  16. package/dist/cli/commands/ClearCommand.js.map +1 -1
  17. package/dist/cli/commands/DownloadCommand.d.ts +4 -5
  18. package/dist/cli/commands/DownloadCommand.js +3 -2
  19. package/dist/cli/commands/DownloadCommand.js.map +1 -1
  20. package/dist/commands.d.ts +2 -1
  21. package/dist/commands.js +2 -1
  22. package/dist/commands.js.map +1 -1
  23. package/dist/config.d.ts +1 -0
  24. package/dist/config.js +1 -0
  25. package/dist/config.js.map +1 -1
  26. package/dist/index.d.ts +7 -4
  27. package/dist/index.js +6 -4
  28. package/dist/index.js.map +1 -1
  29. package/dist/llamaEvaluator/LlamaBins.d.ts +19 -4
  30. package/dist/llamaEvaluator/LlamaBins.js +3 -3
  31. package/dist/llamaEvaluator/LlamaChatSession.d.ts +24 -23
  32. package/dist/llamaEvaluator/LlamaChatSession.js +90 -36
  33. package/dist/llamaEvaluator/LlamaChatSession.js.map +1 -1
  34. package/dist/llamaEvaluator/LlamaContext/LlamaContext.d.ts +112 -0
  35. package/dist/llamaEvaluator/LlamaContext/LlamaContext.js +640 -0
  36. package/dist/llamaEvaluator/LlamaContext/LlamaContext.js.map +1 -0
  37. package/dist/llamaEvaluator/LlamaContext/types.d.ts +90 -0
  38. package/dist/llamaEvaluator/LlamaContext/types.js +2 -0
  39. package/dist/llamaEvaluator/LlamaContext/types.js.map +1 -0
  40. package/dist/llamaEvaluator/LlamaContext/utils/batchItemsPrioritizingStrategies/firstInFirstOutStrategy.d.ts +5 -0
  41. package/dist/llamaEvaluator/LlamaContext/utils/batchItemsPrioritizingStrategies/firstInFirstOutStrategy.js +16 -0
  42. package/dist/llamaEvaluator/LlamaContext/utils/batchItemsPrioritizingStrategies/firstInFirstOutStrategy.js.map +1 -0
  43. package/dist/llamaEvaluator/LlamaContext/utils/batchItemsPrioritizingStrategies/maximumParallelismStrategy.d.ts +5 -0
  44. package/dist/llamaEvaluator/LlamaContext/utils/batchItemsPrioritizingStrategies/maximumParallelismStrategy.js +42 -0
  45. package/dist/llamaEvaluator/LlamaContext/utils/batchItemsPrioritizingStrategies/maximumParallelismStrategy.js.map +1 -0
  46. package/dist/llamaEvaluator/LlamaContext/utils/resolveBatchItemsPrioritizingStrategy.d.ts +2 -0
  47. package/dist/llamaEvaluator/LlamaContext/utils/resolveBatchItemsPrioritizingStrategy.js +13 -0
  48. package/dist/llamaEvaluator/LlamaContext/utils/resolveBatchItemsPrioritizingStrategy.js.map +1 -0
  49. package/dist/llamaEvaluator/LlamaGrammar.d.ts +5 -5
  50. package/dist/llamaEvaluator/LlamaGrammar.js +7 -7
  51. package/dist/llamaEvaluator/LlamaGrammarEvaluationState.d.ts +6 -5
  52. package/dist/llamaEvaluator/LlamaGrammarEvaluationState.js +8 -7
  53. package/dist/llamaEvaluator/LlamaGrammarEvaluationState.js.map +1 -1
  54. package/dist/llamaEvaluator/LlamaModel.d.ts +93 -112
  55. package/dist/llamaEvaluator/LlamaModel.js +294 -59
  56. package/dist/llamaEvaluator/LlamaModel.js.map +1 -1
  57. package/dist/types.d.ts +3 -1
  58. package/dist/utils/ReplHistory.js +1 -1
  59. package/dist/utils/ReplHistory.js.map +1 -1
  60. package/dist/utils/cloneLlamaCppRepo.d.ts +1 -0
  61. package/dist/utils/cloneLlamaCppRepo.js +26 -1
  62. package/dist/utils/cloneLlamaCppRepo.js.map +1 -1
  63. package/dist/utils/getBin.d.ts +71 -39
  64. package/dist/utils/getBin.js.map +1 -1
  65. package/dist/utils/getBuildDefaults.d.ts +6 -0
  66. package/dist/utils/getBuildDefaults.js +10 -0
  67. package/dist/utils/getBuildDefaults.js.map +1 -0
  68. package/dist/utils/getReleaseInfo.d.ts +7 -0
  69. package/dist/utils/getReleaseInfo.js +30 -0
  70. package/dist/utils/getReleaseInfo.js.map +1 -0
  71. package/dist/utils/parseModelFileName.d.ts +9 -0
  72. package/dist/utils/parseModelFileName.js +68 -0
  73. package/dist/utils/parseModelFileName.js.map +1 -0
  74. package/dist/utils/parseModelTypeDescription.d.ts +6 -0
  75. package/dist/utils/parseModelTypeDescription.js +9 -0
  76. package/dist/utils/parseModelTypeDescription.js.map +1 -0
  77. package/llama/.clang-format +10 -9
  78. package/llama/addon.cpp +689 -356
  79. package/llama/binariesGithubRelease.json +1 -1
  80. package/llama/gitRelease.bundle +0 -0
  81. package/llama/grammars/README.md +2 -2
  82. package/llamaBins/linux-arm64/llama-addon.node +0 -0
  83. package/llamaBins/linux-armv7l/llama-addon.node +0 -0
  84. package/llamaBins/linux-x64/llama-addon.node +0 -0
  85. package/llamaBins/mac-arm64/ggml-metal.metal +107 -1
  86. package/llamaBins/mac-arm64/llama-addon.node +0 -0
  87. package/llamaBins/mac-x64/ggml-metal.metal +107 -1
  88. package/llamaBins/mac-x64/llama-addon.node +0 -0
  89. package/llamaBins/win-x64/llama-addon.exp +0 -0
  90. package/llamaBins/win-x64/llama-addon.lib +0 -0
  91. package/llamaBins/win-x64/llama-addon.node +0 -0
  92. package/package.json +13 -7
  93. package/dist/chatWrappers/createChatWrapperByBos.d.ts +0 -2
  94. package/dist/chatWrappers/createChatWrapperByBos.js +0 -14
  95. package/dist/chatWrappers/createChatWrapperByBos.js.map +0 -1
  96. package/dist/llamaEvaluator/LlamaContext.d.ts +0 -100
  97. package/dist/llamaEvaluator/LlamaContext.js +0 -141
  98. package/dist/llamaEvaluator/LlamaContext.js.map +0 -1
  99. package/dist/utils/withLock.d.ts +0 -1
  100. package/dist/utils/withLock.js +0 -19
  101. package/dist/utils/withLock.js.map +0 -1
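The most significant structural change in this file list is the new LlamaContext directory with batch-item prioritizing strategies (items 40–48 above). As a rough illustration of the idea, a strategy of this kind decides how a single decode's token budget is split among queued sequences. The sketch below shows what a first-in-first-out policy could look like; the actual firstInFirstOutStrategy is internal to the package and not shown in this diff, so the type names and signature here are assumptions, not the package's API:

    // Hypothetical sketch - the real strategy in
    // dist/llamaEvaluator/LlamaContext/utils/batchItemsPrioritizingStrategies/firstInFirstOutStrategy.js
    // is internal; these types and this signature are assumed for illustration only.
    type BatchItem = {
        enqueueTime: number,      // when the sequence asked to be evaluated
        remainingTokens: number   // tokens it still wants to feed into the batch
    };

    // Grant the whole token budget of one decode to the oldest waiters first.
    function firstInFirstOutStrategy(items: BatchItem[], tokenBudget: number): Map<BatchItem, number> {
        const grants = new Map<BatchItem, number>();
        let remainingBudget = tokenBudget;

        for (const item of [...items].sort((a, b) => a.enqueueTime - b.enqueueTime)) {
            if (remainingBudget === 0)
                break;

            const grant = Math.min(item.remainingTokens, remainingBudget);
            grants.set(item, grant);
            remainingBudget -= grant;
        }

        return grants;
    }

By contrast, a maximum-parallelism policy (item 44) would presumably split the same budget across all waiting sequences so that each one advances on every decode.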
package/llama/addon.cpp CHANGED
@@ -1,446 +1,779 @@
 #include <stddef.h>
+
 #include <algorithm>
 #include <sstream>
 #include <vector>

 #include "common.h"
-#include "llama.h"
 #include "common/grammar-parser.h"
+#include "llama.h"
 #include "napi.h"

-class LLAMAModel : public Napi::ObjectWrap<LLAMAModel> {
-  public:
-  llama_model_params model_params;
-  llama_model* model;
+std::string addon_model_token_to_piece(const struct llama_model * model, llama_token token) {
+    std::vector<char> result(8, 0);
+    const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size());
+    if (n_tokens < 0) {
+        result.resize(-n_tokens);
+        int check = llama_token_to_piece(model, token, result.data(), result.size());
+        GGML_ASSERT(check == -n_tokens);
+    }
+    else {
+        result.resize(n_tokens);
+    }
+
+    return std::string(result.data(), result.size());
+}
+
+class AddonModel : public Napi::ObjectWrap<AddonModel> {
+    public:
+        llama_model_params model_params;
+        llama_model* model;
+        bool disposed = false;
+
+        AddonModel(const Napi::CallbackInfo& info) : Napi::ObjectWrap<AddonModel>(info) {
+            model_params = llama_model_default_params();
+
+            // Get the model path
+            std::string modelPath = info[0].As<Napi::String>().Utf8Value();
+
+            if (info.Length() > 1 && info[1].IsObject()) {
+                Napi::Object options = info[1].As<Napi::Object>();

-  LLAMAModel(const Napi::CallbackInfo& info) : Napi::ObjectWrap<LLAMAModel>(info) {
-    model_params = llama_model_default_params();
+                if (options.Has("gpuLayers")) {
+                    model_params.n_gpu_layers = options.Get("gpuLayers").As<Napi::Number>().Int32Value();
+                }

-    // Get the model path
-    std::string modelPath = info[0].As<Napi::String>().Utf8Value();
+                if (options.Has("vocabOnly")) {
+                    model_params.vocab_only = options.Get("vocabOnly").As<Napi::Boolean>().Value();
+                }

-    if (info.Length() > 1 && info[1].IsObject()) {
-      Napi::Object options = info[1].As<Napi::Object>();
+                if (options.Has("useMmap")) {
+                    model_params.use_mmap = options.Get("useMmap").As<Napi::Boolean>().Value();
+                }

-      if (options.Has("gpuLayers")) {
-        model_params.n_gpu_layers = options.Get("gpuLayers").As<Napi::Number>().Int32Value();
+                if (options.Has("useMlock")) {
+                    model_params.use_mlock = options.Get("useMlock").As<Napi::Boolean>().Value();
+                }
             }

-      if (options.Has("vocabOnly")) {
-        model_params.vocab_only = options.Get("vocabOnly").As<Napi::Boolean>().Value();
+            llama_backend_init(false);
+            model = llama_load_model_from_file(modelPath.c_str(), model_params);
+
+            if (model == NULL) {
+                Napi::Error::New(info.Env(), "Failed to load model").ThrowAsJavaScriptException();
+                return;
             }
+        }

-      if (options.Has("useMmap")) {
-        model_params.use_mmap = options.Get("useMmap").As<Napi::Boolean>().Value();
+        ~AddonModel() {
+            dispose();
+        }
+
+        void dispose() {
+            if (disposed) {
+                return;
             }

-      if (options.Has("useMlock")) {
-        model_params.use_mlock = options.Get("useMlock").As<Napi::Boolean>().Value();
+            llama_free_model(model);
+            disposed = true;
+        }
+
+        Napi::Value Dispose(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                return info.Env().Undefined();
             }
+
+            dispose();
+
+            return info.Env().Undefined();
         }

-    llama_backend_init(false);
-    model = llama_load_model_from_file(modelPath.c_str(), model_params);
+        Napi::Value Tokenize(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }
+
+            std::string text = info[0].As<Napi::String>().Utf8Value();

-    if (model == NULL) {
-      Napi::Error::New(info.Env(), "Failed to load model").ThrowAsJavaScriptException();
-      return;
+            std::vector<llama_token> tokens = llama_tokenize(model, text, true, true);
+
+            Napi::Uint32Array result = Napi::Uint32Array::New(info.Env(), tokens.size());
+            for (size_t i = 0; i < tokens.size(); ++i) {
+                result[i] = static_cast<uint32_t>(tokens[i]);
+            }
+
+            return result;
         }
-  }
+        Napi::Value Detokenize(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }

-  ~LLAMAModel() {
-    llama_free_model(model);
-  }
+            Napi::Uint32Array tokens = info[0].As<Napi::Uint32Array>();

-  static void init(Napi::Object exports) {
-    exports.Set("LLAMAModel", DefineClass(exports.Env(), "LLAMAModel", {}));
-  }
-};
+            // Create a stringstream for accumulating the decoded string.
+            std::stringstream ss;

-class LLAMAGrammar : public Napi::ObjectWrap<LLAMAGrammar> {
-  public:
-  grammar_parser::parse_state parsed_grammar;
+            // Decode each token and accumulate the result.
+            for (size_t i = 0; i < tokens.ElementLength(); i++) {
+                const std::string piece = addon_model_token_to_piece(model, (llama_token)tokens[i]);

-  LLAMAGrammar(const Napi::CallbackInfo& info) : Napi::ObjectWrap<LLAMAGrammar>(info) {
-    // Get the model path
-    std::string grammarCode = info[0].As<Napi::String>().Utf8Value();
-    bool should_print_grammar = false;
+                if (piece.empty()) {
+                    continue;
+                }

-    if (info.Length() > 1 && info[1].IsObject()) {
-      Napi::Object options = info[1].As<Napi::Object>();
+                ss << piece;
+            }
+
+            return Napi::String::New(info.Env(), ss.str());
+        }

-      if (options.Has("printGrammar")) {
-        should_print_grammar = options.Get("printGrammar").As<Napi::Boolean>().Value();
+        Napi::Value GetTrainContextSize(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
             }
+
+            return Napi::Number::From(info.Env(), llama_n_ctx_train(model));
         }

-    parsed_grammar = grammar_parser::parse(grammarCode.c_str());
-    // will be empty (default) if there are parse errors
-    if (parsed_grammar.rules.empty()) {
-      Napi::Error::New(info.Env(), "Failed to parse grammar").ThrowAsJavaScriptException();
-      return;
+        Napi::Value GetTotalSize(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }
+
+            return Napi::Number::From(info.Env(), llama_model_size(model));
         }

-    if (should_print_grammar) {
-      grammar_parser::print_grammar(stderr, parsed_grammar);
+        Napi::Value GetTotalParameters(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }
+
+            return Napi::Number::From(info.Env(), llama_model_n_params(model));
         }
-  }

-  static void init(Napi::Object exports) {
-    exports.Set("LLAMAGrammar", DefineClass(exports.Env(), "LLAMAGrammar", {}));
-  }
-};
+        Napi::Value GetModelDescription(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }
+
+            char model_desc[128];
+            int actual_length = llama_model_desc(model, model_desc, sizeof(model_desc));
+
+            return Napi::String::New(info.Env(), model_desc, actual_length);
+        }

-class LLAMAGrammarEvaluationState : public Napi::ObjectWrap<LLAMAGrammarEvaluationState> {
-  public:
-  LLAMAGrammar* grammarDef;
-  llama_grammar *grammar = nullptr;
+        Napi::Value TokenBos(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }

-  LLAMAGrammarEvaluationState(const Napi::CallbackInfo& info) : Napi::ObjectWrap<LLAMAGrammarEvaluationState>(info) {
-    grammarDef = Napi::ObjectWrap<LLAMAGrammar>::Unwrap(info[0].As<Napi::Object>());
-    grammarDef->Ref();
+            return Napi::Number::From(info.Env(), llama_token_bos(model));
+        }
+        Napi::Value TokenEos(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }

-    std::vector<const llama_grammar_element *> grammar_rules(grammarDef->parsed_grammar.c_rules());
-    grammar = llama_grammar_init(
-      grammar_rules.data(), grammar_rules.size(), grammarDef->parsed_grammar.symbol_ids.at("root")
-    );
-  }
+            return Napi::Number::From(info.Env(), llama_token_eos(model));
+        }
+        Napi::Value TokenNl(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }

-  ~LLAMAGrammarEvaluationState() {
-    grammarDef->Unref();
+            return Napi::Number::From(info.Env(), llama_token_nl(model));
+        }
+        Napi::Value PrefixToken(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }

-    if (grammar != nullptr) {
-      llama_grammar_free(grammar);
-      grammar = nullptr;
+            return Napi::Number::From(info.Env(), llama_token_prefix(model));
         }
-  }
+        Napi::Value MiddleToken(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }

-  static void init(Napi::Object exports) {
-    exports.Set("LLAMAGrammarEvaluationState", DefineClass(exports.Env(), "LLAMAGrammarEvaluationState", {}));
-  }
+            return Napi::Number::From(info.Env(), llama_token_middle(model));
+        }
+        Napi::Value SuffixToken(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }
+
+            return Napi::Number::From(info.Env(), llama_token_suffix(model));
+        }
+        Napi::Value EotToken(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }
+
+            return Napi::Number::From(info.Env(), llama_token_eot(model));
+        }
+        Napi::Value GetTokenString(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }
+
+            int token = info[0].As<Napi::Number>().Int32Value();
+            std::stringstream ss;
+
+            const char* str = llama_token_get_text(model, token);
+            if (str == nullptr) {
+                return info.Env().Undefined();
+            }
+
+            ss << str;
+
+            return Napi::String::New(info.Env(), ss.str());
+        }
+
+        static void init(Napi::Object exports) {
+            exports.Set(
+                "AddonModel",
+                DefineClass(
+                    exports.Env(),
+                    "AddonModel",
+                    {
+                        InstanceMethod("tokenize", &AddonModel::Tokenize),
+                        InstanceMethod("detokenize", &AddonModel::Detokenize),
+                        InstanceMethod("getTrainContextSize", &AddonModel::GetTrainContextSize),
+                        InstanceMethod("getTotalSize", &AddonModel::GetTotalSize),
+                        InstanceMethod("getTotalParameters", &AddonModel::GetTotalParameters),
+                        InstanceMethod("getModelDescription", &AddonModel::GetModelDescription),
+                        InstanceMethod("tokenBos", &AddonModel::TokenBos),
+                        InstanceMethod("tokenEos", &AddonModel::TokenEos),
+                        InstanceMethod("tokenNl", &AddonModel::TokenNl),
+                        InstanceMethod("prefixToken", &AddonModel::PrefixToken),
+                        InstanceMethod("middleToken", &AddonModel::MiddleToken),
+                        InstanceMethod("suffixToken", &AddonModel::SuffixToken),
+                        InstanceMethod("eotToken", &AddonModel::EotToken),
+                        InstanceMethod("getTokenString", &AddonModel::GetTokenString),
+                        InstanceMethod("dispose", &AddonModel::Dispose)
+                    }
+                )
+            );
+        }
 };

-class LLAMAContext : public Napi::ObjectWrap<LLAMAContext> {
-  public:
-  LLAMAModel* model;
-  llama_context_params context_params;
-  llama_context* ctx;
-  int n_cur = 0;
-
-  LLAMAContext(const Napi::CallbackInfo& info) : Napi::ObjectWrap<LLAMAContext>(info) {
-    model = Napi::ObjectWrap<LLAMAModel>::Unwrap(info[0].As<Napi::Object>());
-    model->Ref();
-
-    context_params = llama_context_default_params();
-    context_params.seed = -1;
-    context_params.n_ctx = 4096;
-    context_params.n_threads = 6;
-    context_params.n_threads_batch == -1 ? context_params.n_threads : context_params.n_threads_batch;
-
-    if (info.Length() > 1 && info[1].IsObject()) {
-      Napi::Object options = info[1].As<Napi::Object>();
-
-      if (options.Has("seed")) {
-        context_params.seed = options.Get("seed").As<Napi::Number>().Int32Value();
-      }
-
-      if (options.Has("contextSize")) {
-        context_params.n_ctx = options.Get("contextSize").As<Napi::Number>().Int32Value();
-      }
-
-      if (options.Has("batchSize")) {
-        context_params.n_batch = options.Get("batchSize").As<Napi::Number>().Int32Value();
-      }
-
-      if (options.Has("f16Kv")) {
-        context_params.f16_kv = options.Get("f16Kv").As<Napi::Boolean>().Value();
-      }
-
-      if (options.Has("logitsAll")) {
-        context_params.logits_all = options.Get("logitsAll").As<Napi::Boolean>().Value();
-      }
-
-      if (options.Has("embedding")) {
-        context_params.embedding = options.Get("embedding").As<Napi::Boolean>().Value();
-      }
-
-      if (options.Has("threads")) {
-        context_params.n_threads = options.Get("threads").As<Napi::Number>().Int32Value();
-        context_params.n_threads_batch == -1 ? context_params.n_threads : context_params.n_threads_batch;
-      }
-    }
+class AddonGrammar : public Napi::ObjectWrap<AddonGrammar> {
+    public:
+        grammar_parser::parse_state parsed_grammar;

-    ctx = llama_new_context_with_model(model->model, context_params);
-    Napi::MemoryManagement::AdjustExternalMemory(Env(), llama_get_state_size(ctx));
-  }
-  ~LLAMAContext() {
-    Napi::MemoryManagement::AdjustExternalMemory(Env(), -(int64_t)llama_get_state_size(ctx));
-    llama_free(ctx);
-    model->Unref();
-  }
-  Napi::Value Encode(const Napi::CallbackInfo& info) {
-    std::string text = info[0].As<Napi::String>().Utf8Value();
+        AddonGrammar(const Napi::CallbackInfo& info) : Napi::ObjectWrap<AddonGrammar>(info) {
+            // Get the model path
+            std::string grammarCode = info[0].As<Napi::String>().Utf8Value();
+            bool should_print_grammar = false;

-    std::vector<llama_token> tokens = llama_tokenize(ctx, text, false);
+            if (info.Length() > 1 && info[1].IsObject()) {
+                Napi::Object options = info[1].As<Napi::Object>();

-    Napi::Uint32Array result = Napi::Uint32Array::New(info.Env(), tokens.size());
-    for (size_t i = 0; i < tokens.size(); ++i) { result[i] = static_cast<uint32_t>(tokens[i]); }
+                if (options.Has("printGrammar")) {
+                    should_print_grammar = options.Get("printGrammar").As<Napi::Boolean>().Value();
+                }
+            }

-    return result;
-  }
-  Napi::Value Decode(const Napi::CallbackInfo& info) {
-    Napi::Uint32Array tokens = info[0].As<Napi::Uint32Array>();
+            parsed_grammar = grammar_parser::parse(grammarCode.c_str());
+            // will be empty (default) if there are parse errors
+            if (parsed_grammar.rules.empty()) {
+                Napi::Error::New(info.Env(), "Failed to parse grammar").ThrowAsJavaScriptException();
+                return;
+            }

-    // Create a stringstream for accumulating the decoded string.
-    std::stringstream ss;
+            if (should_print_grammar) {
+                grammar_parser::print_grammar(stderr, parsed_grammar);
+            }
+        }
+
+        static void init(Napi::Object exports) {
+            exports.Set("AddonGrammar", DefineClass(exports.Env(), "AddonGrammar", {}));
+        }
+};
+
+class AddonGrammarEvaluationState : public Napi::ObjectWrap<AddonGrammarEvaluationState> {
+    public:
+        AddonGrammar* grammarDef;
+        llama_grammar* grammar = nullptr;

-    // Decode each token and accumulate the result.
-    for (size_t i = 0; i < tokens.ElementLength(); i++) {
-      const std::string piece = llama_token_to_piece(ctx, (llama_token)tokens[i]);
+        AddonGrammarEvaluationState(const Napi::CallbackInfo& info) : Napi::ObjectWrap<AddonGrammarEvaluationState>(info) {
+            grammarDef = Napi::ObjectWrap<AddonGrammar>::Unwrap(info[0].As<Napi::Object>());
+            grammarDef->Ref();

-      if (piece.empty()) {
-        continue;
+            std::vector<const llama_grammar_element*> grammar_rules(grammarDef->parsed_grammar.c_rules());
+            grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), grammarDef->parsed_grammar.symbol_ids.at("root"));
         }

-      ss << piece;
-    }
+        ~AddonGrammarEvaluationState() {
+            grammarDef->Unref();

-    return Napi::String::New(info.Env(), ss.str());
-  }
-  Napi::Value TokenBos(const Napi::CallbackInfo& info) {
-    return Napi::Number::From(info.Env(), llama_token_bos(model->model)); // TODO: move this to the model
-  }
-  Napi::Value TokenEos(const Napi::CallbackInfo& info) {
-    return Napi::Number::From(info.Env(), llama_token_eos(model->model)); // TODO: move this to the model
-  }
-  Napi::Value TokenNl(const Napi::CallbackInfo& info) {
-    return Napi::Number::From(info.Env(), llama_token_nl(model->model)); // TODO: move this to the model
-  }
-  Napi::Value GetContextSize(const Napi::CallbackInfo& info) {
-    return Napi::Number::From(info.Env(), llama_n_ctx(ctx));
-  }
-  Napi::Value GetTokenString(const Napi::CallbackInfo& info) {
-    int token = info[0].As<Napi::Number>().Int32Value();
-    std::stringstream ss;
-
-    const char* str = llama_token_get_text(model->model, token); // TODO: move this to the model
-    if (str == nullptr) {
-      return info.Env().Undefined();
-    }
+            if (grammar != nullptr) {
+                llama_grammar_free(grammar);
+                grammar = nullptr;
+            }
+        }

-    ss << str;
-
-    return Napi::String::New(info.Env(), ss.str());
-  }
-  Napi::Value Eval(const Napi::CallbackInfo& info);
-  static void init(Napi::Object exports) {
-    exports.Set("LLAMAContext",
-        DefineClass(exports.Env(),
-            "LLAMAContext",
-            {
-                InstanceMethod("encode", &LLAMAContext::Encode),
-                InstanceMethod("decode", &LLAMAContext::Decode),
-                InstanceMethod("tokenBos", &LLAMAContext::TokenBos),
-                InstanceMethod("tokenEos", &LLAMAContext::TokenEos),
-                InstanceMethod("tokenNl", &LLAMAContext::TokenNl),
-                InstanceMethod("getContextSize", &LLAMAContext::GetContextSize),
-                InstanceMethod("getTokenString", &LLAMAContext::GetTokenString),
-                InstanceMethod("eval", &LLAMAContext::Eval),
-            }));
-  }
+        static void init(Napi::Object exports) {
+            exports.Set("AddonGrammarEvaluationState", DefineClass(exports.Env(), "AddonGrammarEvaluationState", {}));
+        }
 };

+class AddonContext : public Napi::ObjectWrap<AddonContext> {
+    public:
+        AddonModel* model;
+        llama_context_params context_params;
+        llama_context* ctx;
+        llama_batch batch;
+        bool has_batch = false;
+        int32_t batch_n_tokens = 0;
+        int n_cur = 0;
+        bool disposed = false;
+
+        AddonContext(const Napi::CallbackInfo& info) : Napi::ObjectWrap<AddonContext>(info) {
+            model = Napi::ObjectWrap<AddonModel>::Unwrap(info[0].As<Napi::Object>());
+            model->Ref();
+
+            context_params = llama_context_default_params();
+            context_params.seed = -1;
+            context_params.n_ctx = 4096;
+            context_params.n_threads = 6;
+            context_params.n_threads_batch == -1 ? context_params.n_threads : context_params.n_threads_batch;
+
+            if (info.Length() > 1 && info[1].IsObject()) {
+                Napi::Object options = info[1].As<Napi::Object>();
+
+                if (options.Has("seed")) {
+                    context_params.seed = options.Get("seed").As<Napi::Number>().Int32Value();
+                }
+
+                if (options.Has("contextSize")) {
+                    context_params.n_ctx = options.Get("contextSize").As<Napi::Number>().Int32Value();
+                }
+
+                if (options.Has("batchSize")) {
+                    context_params.n_batch = options.Get("batchSize").As<Napi::Number>().Int32Value();
+                }
+
+                if (options.Has("f16Kv")) {
+                    context_params.f16_kv = options.Get("f16Kv").As<Napi::Boolean>().Value();
+                }
+
+                if (options.Has("logitsAll")) {
+                    context_params.logits_all = options.Get("logitsAll").As<Napi::Boolean>().Value();
+                }
+
+                if (options.Has("embedding")) {
+                    context_params.embedding = options.Get("embedding").As<Napi::Boolean>().Value();
+                }
+
+                if (options.Has("threads")) {
+                    context_params.n_threads = options.Get("threads").As<Napi::Number>().Int32Value();
+                    context_params.n_threads_batch == -1 ? context_params.n_threads : context_params.n_threads_batch;
+                }
+            }

-class LLAMAContextEvalWorker : Napi::AsyncWorker, Napi::Promise::Deferred {
-  LLAMAContext* ctx;
-  LLAMAGrammarEvaluationState* grammar_evaluation_state;
-  bool use_grammar = false;
-  std::vector<llama_token> tokens;
-  llama_token result;
-  float temperature;
-  int32_t top_k;
-  float top_p;
-  float repeat_penalty = 1.10f; // 1.0 = disabled
-  float repeat_penalty_presence_penalty = 0.00f; // 0.0 = disabled
-  float repeat_penalty_frequency_penalty = 0.00f; // 0.0 = disabled
-  std::vector<llama_token> repeat_penalty_tokens;
-  bool use_repeat_penalty = false;
-
-  public:
-  LLAMAContextEvalWorker(const Napi::CallbackInfo& info, LLAMAContext* ctx) : Napi::AsyncWorker(info.Env(), "LLAMAContextEvalWorker"), ctx(ctx), Napi::Promise::Deferred(info.Env()) {
-    ctx->Ref();
-    Napi::Uint32Array tokens = info[0].As<Napi::Uint32Array>();
-
-    temperature = 0.0f;
-    top_k = 40;
-    top_p = 0.95f;
-
-    if (info.Length() > 1 && info[1].IsObject()) {
-      Napi::Object options = info[1].As<Napi::Object>();
-
-      if (options.Has("temperature")) {
-        temperature = options.Get("temperature").As<Napi::Number>().FloatValue();
-      }
-
-      if (options.Has("topK")) {
-        top_k = options.Get("topK").As<Napi::Number>().Int32Value();
-      }
-
-      if (options.Has("topP")) {
-        top_p = options.Get("topP").As<Napi::Number>().FloatValue();
-      }
-
-      if (options.Has("repeatPenalty")) {
-        repeat_penalty = options.Get("repeatPenalty").As<Napi::Number>().FloatValue();
-      }
-
-      if (options.Has("repeatPenaltyTokens")) {
-        Napi::Uint32Array repeat_penalty_tokens_uint32_array = options.Get("repeatPenaltyTokens").As<Napi::Uint32Array>();
-
-        repeat_penalty_tokens.reserve(repeat_penalty_tokens_uint32_array.ElementLength());
-        for (size_t i = 0; i < repeat_penalty_tokens_uint32_array.ElementLength(); i++) {
-          repeat_penalty_tokens.push_back(static_cast<llama_token>(repeat_penalty_tokens_uint32_array[i]));
-        }
-
-        use_repeat_penalty = true;
-      }
-
-      if (options.Has("repeatPenaltyPresencePenalty")) {
-        repeat_penalty_presence_penalty = options.Get("repeatPenaltyPresencePenalty").As<Napi::Number>().FloatValue();
-      }
-
-      if (options.Has("repeatPenaltyFrequencyPenalty")) {
-        repeat_penalty_frequency_penalty = options.Get("repeatPenaltyFrequencyPenalty").As<Napi::Number>().FloatValue();
-      }
-
-      if (options.Has("grammarEvaluationState")) {
-        grammar_evaluation_state = Napi::ObjectWrap<LLAMAGrammarEvaluationState>::Unwrap(options.Get("grammarEvaluationState").As<Napi::Object>());
-        grammar_evaluation_state->Ref();
-        use_grammar = true;
-      }
-    }
+            ctx = llama_new_context_with_model(model->model, context_params);
+            Napi::MemoryManagement::AdjustExternalMemory(Env(), llama_get_state_size(ctx));
+        }
+        ~AddonContext() {
+            dispose();
+        }

-    this->tokens.reserve(tokens.ElementLength());
-    for (size_t i = 0; i < tokens.ElementLength(); i++) { this->tokens.push_back(static_cast<llama_token>(tokens[i])); }
-  }
-  ~LLAMAContextEvalWorker() {
-    ctx->Unref();
+        void dispose() {
+            if (disposed) {
+                return;
+            }

-    if (use_grammar) {
-      grammar_evaluation_state->Unref();
-      use_grammar = false;
-    }
-  }
-  using Napi::AsyncWorker::Queue;
-  using Napi::Promise::Deferred::Promise;
+            Napi::MemoryManagement::AdjustExternalMemory(Env(), -(int64_t)llama_get_state_size(ctx));
+            llama_free(ctx);
+            model->Unref();

-  protected:
-  void Execute() {
-    llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
+            disposeBatch();

-    for (size_t i = 0; i < tokens.size(); i++) {
-      llama_batch_add(batch, tokens[i], ctx->n_cur, { 0 }, false);
+            disposed = true;
+        }
+        void disposeBatch() {
+            if (!has_batch) {
+                return;
+            }

-      ctx->n_cur++;
-    }
-    GGML_ASSERT(batch.n_tokens == (int) tokens.size());
+            llama_batch_free(batch);
+            has_batch = false;
+            batch_n_tokens = 0;
+        }
+        Napi::Value Dispose(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                return info.Env().Undefined();
+            }

-    batch.logits[batch.n_tokens - 1] = true;
+            dispose();

-    // Perform the evaluation using llama_decode.
-    int r = llama_decode(ctx->ctx, batch);
+            return info.Env().Undefined();
+        }
+        Napi::Value GetContextSize(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }

-    llama_batch_free(batch);
+            return Napi::Number::From(info.Env(), llama_n_ctx(ctx));
+        }
+        Napi::Value InitBatch(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }

-    if (r != 0) {
-      if (r == 1) {
-        SetError("could not find a KV slot for the batch (try reducing the size of the batch or increase the context)");
-      } else {
-        SetError("Eval has failed");
-      }
+            if (has_batch) {
+                llama_batch_free(batch);
+            }

-      return;
-    }
+            int32_t n_tokens = info[0].As<Napi::Number>().Int32Value();

-    llama_token new_token_id = 0;
+            batch = llama_batch_init(n_tokens, 0, 1);
+            has_batch = true;
+            batch_n_tokens = n_tokens;

-    // Select the best prediction.
-    auto logits = llama_get_logits_ith(ctx->ctx, batch.n_tokens - 1);
-    auto n_vocab = llama_n_vocab(ctx->model->model);
+            return info.Env().Undefined();
+        }
+        Napi::Value DisposeBatch(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }

-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
+            disposeBatch();

-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-      candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-    }
+            return info.Env().Undefined();
+        }
+        Napi::Value AddToBatch(const Napi::CallbackInfo& info) {
+            if (!has_batch) {
+                Napi::Error::New(info.Env(), "No batch is initialized").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }

-    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+            int32_t sequenceId = info[0].As<Napi::Number>().Int32Value();
+            int32_t firstTokenContextIndex = info[1].As<Napi::Number>().Int32Value();
+            Napi::Uint32Array tokens = info[2].As<Napi::Uint32Array>();
+            bool generateLogitAtTheEnd = info[3].As<Napi::Boolean>().Value();

-    auto eos_token = llama_token_eos(ctx->model->model);
+            auto tokensLength = tokens.ElementLength();
+            GGML_ASSERT(batch.n_tokens + tokensLength <= batch_n_tokens);

-    if (use_repeat_penalty && !repeat_penalty_tokens.empty()) {
-      llama_sample_repetition_penalties(
-        ctx->ctx, &candidates_p, repeat_penalty_tokens.data(), repeat_penalty_tokens.size(), repeat_penalty,
-        repeat_penalty_frequency_penalty, repeat_penalty_presence_penalty
-      );
-    }
+            for (size_t i = 0; i < tokensLength; i++) {
+                llama_batch_add(batch, static_cast<llama_token>(tokens[i]), firstTokenContextIndex + i, { sequenceId }, false);
+            }

-    if (use_grammar && (grammar_evaluation_state)->grammar != nullptr) {
-      llama_sample_grammar(ctx->ctx, &candidates_p, (grammar_evaluation_state)->grammar);
-    }
+            if (generateLogitAtTheEnd) {
+                batch.logits[batch.n_tokens - 1] = true;

-    if (temperature <= 0) {
-      new_token_id = llama_sample_token_greedy(ctx->ctx , &candidates_p);
-    } else {
-      const int32_t resolved_top_k = top_k <= 0 ? llama_n_vocab(ctx->model->model) : std::min(top_k, llama_n_vocab(ctx->model->model));
-      const int32_t n_probs = 0; // Number of probabilities to keep - 0 = disabled
-      const float tfs_z = 1.00f; // Tail free sampling - 1.0 = disabled
-      const float typical_p = 1.00f; // Typical probability - 1.0 = disabled
-      const float resolved_top_p = top_p; // Top p sampling - 1.0 = disabled
-
-      // Temperature sampling
-      size_t min_keep = std::max(1, n_probs);
-      llama_sample_top_k(ctx->ctx, &candidates_p, resolved_top_k, min_keep);
-      llama_sample_tail_free(ctx->ctx, &candidates_p, tfs_z, min_keep);
-      llama_sample_typical(ctx->ctx, &candidates_p, typical_p, min_keep);
-      llama_sample_top_p(ctx->ctx, &candidates_p, resolved_top_p, min_keep);
-      llama_sample_temperature(ctx->ctx, &candidates_p, temperature);
-      new_token_id = llama_sample_token(ctx->ctx, &candidates_p);
-    }
+                auto logit_index = batch.n_tokens - 1;

-    if (new_token_id != eos_token && use_grammar && (grammar_evaluation_state)->grammar != nullptr) {
-      llama_grammar_accept_token(ctx->ctx, (grammar_evaluation_state)->grammar, new_token_id);
-    }
+                return Napi::Number::From(info.Env(), logit_index);
+            }
+
+            return info.Env().Undefined();
+        }
+        Napi::Value DisposeSequence(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }
+
+            int32_t sequenceId = info[0].As<Napi::Number>().Int32Value();
+
+            llama_kv_cache_seq_rm(ctx, sequenceId, -1, -1);
+
+            return info.Env().Undefined();
+        }
+        Napi::Value RemoveTokenCellsFromSequence(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }
+
+            int32_t sequenceId = info[0].As<Napi::Number>().Int32Value();
+            int32_t startPos = info[1].As<Napi::Number>().Int32Value();
+            int32_t endPos = info[2].As<Napi::Number>().Int32Value();
+
+            llama_kv_cache_seq_rm(ctx, sequenceId, startPos, endPos);
+
+            return info.Env().Undefined();
+        }
+        Napi::Value ShiftSequenceTokenCells(const Napi::CallbackInfo& info) {
+            if (disposed) {
+                Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+                return info.Env().Undefined();
+            }
+
+            int32_t sequenceId = info[0].As<Napi::Number>().Int32Value();
+            int32_t startPos = info[1].As<Napi::Number>().Int32Value();
+            int32_t endPos = info[2].As<Napi::Number>().Int32Value();
+            int32_t shiftDelta = info[3].As<Napi::Number>().Int32Value();

-    result = new_token_id;
-  }
-  void OnOK() {
-    Napi::Env env = Napi::AsyncWorker::Env();
-    Napi::Number resultValue = Napi::Number::New(env, static_cast<uint32_t>(result));
-    Napi::Promise::Deferred::Resolve(resultValue);
-  }
-  void OnError(const Napi::Error& err) { Napi::Promise::Deferred::Reject(err.Value()); }
+            llama_kv_cache_seq_shift(ctx, sequenceId, startPos, endPos, shiftDelta);
+
+            return info.Env().Undefined();
+        }
+        Napi::Value DecodeBatch(const Napi::CallbackInfo& info);
+        Napi::Value SampleToken(const Napi::CallbackInfo& info);
+
+        static void init(Napi::Object exports) {
+            exports.Set(
+                "AddonContext",
+                DefineClass(
+                    exports.Env(),
+                    "AddonContext",
+                    {
+                        InstanceMethod("getContextSize", &AddonContext::GetContextSize),
+                        InstanceMethod("initBatch", &AddonContext::InitBatch),
+                        InstanceMethod("addToBatch", &AddonContext::AddToBatch),
+                        InstanceMethod("disposeSequence", &AddonContext::DisposeSequence),
+                        InstanceMethod("removeTokenCellsFromSequence", &AddonContext::RemoveTokenCellsFromSequence),
+                        InstanceMethod("shiftSequenceTokenCells", &AddonContext::ShiftSequenceTokenCells),
+                        InstanceMethod("decodeBatch", &AddonContext::DecodeBatch),
+                        InstanceMethod("sampleToken", &AddonContext::SampleToken),
+                        InstanceMethod("dispose", &AddonContext::Dispose)
+                    }
+                )
+            );
+        }
+};
+
+
+class AddonContextDecodeBatchWorker : Napi::AsyncWorker, Napi::Promise::Deferred {
+    public:
+        AddonContext* ctx;
+
+        AddonContextDecodeBatchWorker(const Napi::CallbackInfo& info, AddonContext* ctx)
+            : Napi::AsyncWorker(info.Env(), "AddonContextDecodeBatchWorker"),
+              ctx(ctx),
+              Napi::Promise::Deferred(info.Env()) {
+            ctx->Ref();
+        }
+        ~AddonContextDecodeBatchWorker() {
+            ctx->Unref();
+        }
+        using Napi::AsyncWorker::Queue;
+        using Napi::Promise::Deferred::Promise;
+
+    protected:
+        void Execute() {
+            // Perform the evaluation using llama_decode.
+            int r = llama_decode(ctx->ctx, ctx->batch);
+
+            if (r != 0) {
+                if (r == 1) {
+                    SetError("could not find a KV slot for the batch (try reducing the size of the batch or increase the context)");
+                } else {
+                    SetError("Eval has failed");
+                }
+
+                return;
+            }
+        }
+        void OnOK() {
+            Napi::Env env = Napi::AsyncWorker::Env();
+            Napi::Promise::Deferred::Resolve(env.Undefined());
+        }
+        void OnError(const Napi::Error& err) {
+            Napi::Promise::Deferred::Reject(err.Value());
+        }
+};
+
+Napi::Value AddonContext::DecodeBatch(const Napi::CallbackInfo& info) {
+    AddonContextDecodeBatchWorker* worker = new AddonContextDecodeBatchWorker(info, this);
+    worker->Queue();
+    return worker->Promise();
+}
+
+class AddonContextSampleTokenWorker : Napi::AsyncWorker, Napi::Promise::Deferred {
+    public:
+        AddonContext* ctx;
+        AddonGrammarEvaluationState* grammar_evaluation_state;
+        int32_t batchLogitIndex;
+        bool use_grammar = false;
+        llama_token result;
+        float temperature = 0.0f;
+        int32_t top_k = 40;
+        float top_p = 0.95f;
+        float repeat_penalty = 1.10f; // 1.0 = disabled
+        float repeat_penalty_presence_penalty = 0.00f; // 0.0 = disabled
+        float repeat_penalty_frequency_penalty = 0.00f; // 0.0 = disabled
+        std::vector<llama_token> repeat_penalty_tokens;
+        bool use_repeat_penalty = false;
+
+        AddonContextSampleTokenWorker(const Napi::CallbackInfo& info, AddonContext* ctx)
+            : Napi::AsyncWorker(info.Env(), "AddonContextSampleTokenWorker"),
+              ctx(ctx),
+              Napi::Promise::Deferred(info.Env()) {
+            ctx->Ref();
+
+            batchLogitIndex = info[0].As<Napi::Number>().Int32Value();
+
+            if (info.Length() > 1 && info[1].IsObject()) {
+                Napi::Object options = info[1].As<Napi::Object>();
+
+                if (options.Has("temperature")) {
+                    temperature = options.Get("temperature").As<Napi::Number>().FloatValue();
+                }
+
+                if (options.Has("topK")) {
+                    top_k = options.Get("topK").As<Napi::Number>().Int32Value();
+                }
+
+                if (options.Has("topP")) {
+                    top_p = options.Get("topP").As<Napi::Number>().FloatValue();
+                }
+
+                if (options.Has("repeatPenalty")) {
+                    repeat_penalty = options.Get("repeatPenalty").As<Napi::Number>().FloatValue();
+                }
+
+                if (options.Has("repeatPenaltyTokens")) {
+                    Napi::Uint32Array repeat_penalty_tokens_uint32_array = options.Get("repeatPenaltyTokens").As<Napi::Uint32Array>();
+
+                    repeat_penalty_tokens.reserve(repeat_penalty_tokens_uint32_array.ElementLength());
+                    for (size_t i = 0; i < repeat_penalty_tokens_uint32_array.ElementLength(); i++) {
+                        repeat_penalty_tokens.push_back(static_cast<llama_token>(repeat_penalty_tokens_uint32_array[i]));
+                    }
+
+                    use_repeat_penalty = true;
+                }
+
+                if (options.Has("repeatPenaltyPresencePenalty")) {
+                    repeat_penalty_presence_penalty = options.Get("repeatPenaltyPresencePenalty").As<Napi::Number>().FloatValue();
+                }
+
+                if (options.Has("repeatPenaltyFrequencyPenalty")) {
+                    repeat_penalty_frequency_penalty = options.Get("repeatPenaltyFrequencyPenalty").As<Napi::Number>().FloatValue();
+                }
+
+                if (options.Has("grammarEvaluationState")) {
+                    grammar_evaluation_state =
+                        Napi::ObjectWrap<AddonGrammarEvaluationState>::Unwrap(options.Get("grammarEvaluationState").As<Napi::Object>());
+                    grammar_evaluation_state->Ref();
+                    use_grammar = true;
+                }
+            }
+        }
+        ~AddonContextSampleTokenWorker() {
+            ctx->Unref();
+
+            if (use_grammar) {
+                grammar_evaluation_state->Unref();
+                use_grammar = false;
+            }
+        }
+        using Napi::AsyncWorker::Queue;
+        using Napi::Promise::Deferred::Promise;
+
+    protected:
+        void Execute() {
+            llama_token new_token_id = 0;
+
+            // Select the best prediction.
+            auto logits = llama_get_logits_ith(ctx->ctx, batchLogitIndex);
+            auto n_vocab = llama_n_vocab(ctx->model->model);
+
+            std::vector<llama_token_data> candidates;
+            candidates.reserve(n_vocab);
+
+            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                candidates.emplace_back(llama_token_data { token_id, logits[token_id], 0.0f });
+            }
+
+            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+            auto eos_token = llama_token_eos(ctx->model->model);
+
+            if (use_repeat_penalty && !repeat_penalty_tokens.empty()) {
+                llama_sample_repetition_penalties(
+                    ctx->ctx,
+                    &candidates_p,
+                    repeat_penalty_tokens.data(),
+                    repeat_penalty_tokens.size(),
+                    repeat_penalty,
+                    repeat_penalty_frequency_penalty,
+                    repeat_penalty_presence_penalty
+                );
+            }
+
+            if (use_grammar && (grammar_evaluation_state)->grammar != nullptr) {
+                llama_sample_grammar(ctx->ctx, &candidates_p, (grammar_evaluation_state)->grammar);
+            }
+
+            if (temperature <= 0) {
+                new_token_id = llama_sample_token_greedy(ctx->ctx, &candidates_p);
+            } else {
+                const int32_t resolved_top_k =
+                    top_k <= 0 ? llama_n_vocab(ctx->model->model) : std::min(top_k, llama_n_vocab(ctx->model->model));
+                const int32_t n_probs = 0; // Number of probabilities to keep - 0 = disabled
+                const float tfs_z = 1.00f; // Tail free sampling - 1.0 = disabled
+                const float typical_p = 1.00f; // Typical probability - 1.0 = disabled
+                const float resolved_top_p = top_p; // Top p sampling - 1.0 = disabled
+
+                // Temperature sampling
+                size_t min_keep = std::max(1, n_probs);
+                llama_sample_top_k(ctx->ctx, &candidates_p, resolved_top_k, min_keep);
+                llama_sample_tail_free(ctx->ctx, &candidates_p, tfs_z, min_keep);
+                llama_sample_typical(ctx->ctx, &candidates_p, typical_p, min_keep);
+                llama_sample_top_p(ctx->ctx, &candidates_p, resolved_top_p, min_keep);
+                llama_sample_temp(ctx->ctx, &candidates_p, temperature);
+                new_token_id = llama_sample_token(ctx->ctx, &candidates_p);
+            }
+
+            if (new_token_id != eos_token && use_grammar && (grammar_evaluation_state)->grammar != nullptr) {
+                llama_grammar_accept_token(ctx->ctx, (grammar_evaluation_state)->grammar, new_token_id);
+            }
+
+            result = new_token_id;
+        }
+        void OnOK() {
+            Napi::Env env = Napi::AsyncWorker::Env();
+            Napi::Number resultValue = Napi::Number::New(env, static_cast<uint32_t>(result));
+            Napi::Promise::Deferred::Resolve(resultValue);
+        }
+        void OnError(const Napi::Error& err) {
+            Napi::Promise::Deferred::Reject(err.Value());
+        }
 };

-Napi::Value LLAMAContext::Eval(const Napi::CallbackInfo& info) {
-  LLAMAContextEvalWorker* worker = new LLAMAContextEvalWorker(info, this);
-  worker->Queue();
-  return worker->Promise();
+Napi::Value AddonContext::SampleToken(const Napi::CallbackInfo& info) {
+    AddonContextSampleTokenWorker* worker = new AddonContextSampleTokenWorker(info, this);
+    worker->Queue();
+    return worker->Promise();
 }

-Napi::Value systemInfo(const Napi::CallbackInfo& info) { return Napi::String::From(info.Env(), llama_print_system_info()); }
+Napi::Value systemInfo(const Napi::CallbackInfo& info) {
+    return Napi::String::From(info.Env(), llama_print_system_info());
+}

 Napi::Object registerCallback(Napi::Env env, Napi::Object exports) {
-  llama_backend_init(false);
-  exports.DefineProperties({
-    Napi::PropertyDescriptor::Function("systemInfo", systemInfo),
-  });
-  LLAMAModel::init(exports);
-  LLAMAGrammar::init(exports);
-  LLAMAGrammarEvaluationState::init(exports);
-  LLAMAContext::init(exports);
-  return exports;
+    llama_backend_init(false);
+    exports.DefineProperties({
+        Napi::PropertyDescriptor::Function("systemInfo", systemInfo),
+    });
+    AddonModel::init(exports);
+    AddonGrammar::init(exports);
+    AddonGrammarEvaluationState::init(exports);
+    AddonContext::init(exports);
+    return exports;
 }

 NODE_API_MODULE(NODE_GYP_MODULE_NAME, registerCallback)
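Taken together, the addon.cpp changes replace the old monolithic eval call with a batch-oriented API: initBatch/addToBatch fill a llama_batch across sequences, decodeBatch runs llama_decode on a worker thread, and sampleToken samples from a specific logit index. Below is a minimal sketch of driving these bindings directly from JavaScript, based only on the methods registered in this file; the require path and the model file are placeholder assumptions, since in practice the package's own loader resolves and loads the binary:

    // Sketch only: assumes the compiled llama-addon.node is loaded directly.
    const {AddonModel, AddonContext} = require("./llamaBins/linux-x64/llama-addon.node");

    const model = new AddonModel("./model.gguf", {gpuLayers: 16});
    const context = new AddonContext(model, {contextSize: 4096, batchSize: 512});

    const promptTokens = model.tokenize("Hi there!"); // Uint32Array of token ids

    // Feed the prompt into sequence 0 and request a logit after the last token.
    context.initBatch(promptTokens.length);
    const logitIndex = context.addToBatch(0, 0, promptTokens, true);
    await context.decodeBatch(); // llama_decode runs on an AsyncWorker thread

    // Sample the next token from that logit index, then map it back to text.
    const nextToken = await context.sampleToken(logitIndex, {temperature: 0.8, topK: 40, topP: 0.95});
    console.log(model.detokenize(new Uint32Array([nextToken])));

    context.disposeSequence(0);
    context.dispose();
    model.dispose();

Because decoding and sampling are now separate calls keyed by sequence id and logit index, several sequences can share one context and one decode pass, which is what the new batch-item prioritizing strategies coordinate on the TypeScript side.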