node-llama-cpp 2.8.4 → 3.0.0-beta.2

This diff compares the contents of publicly available package versions as published to their public registries, and is provided for informational purposes only.
Files changed (185)
  1. package/README.md +2 -2
  2. package/dist/ChatWrapper.d.ts +49 -0
  3. package/dist/ChatWrapper.js +120 -0
  4. package/dist/ChatWrapper.js.map +1 -0
  5. package/dist/chatWrappers/AlpacaChatWrapper.d.ts +12 -0
  6. package/dist/chatWrappers/AlpacaChatWrapper.js +21 -0
  7. package/dist/chatWrappers/AlpacaChatWrapper.js.map +1 -0
  8. package/dist/chatWrappers/ChatMLChatWrapper.d.ts +13 -0
  9. package/dist/chatWrappers/ChatMLChatWrapper.js +83 -0
  10. package/dist/chatWrappers/ChatMLChatWrapper.js.map +1 -0
  11. package/dist/chatWrappers/EmptyChatWrapper.d.ts +4 -0
  12. package/dist/chatWrappers/EmptyChatWrapper.js +5 -0
  13. package/dist/chatWrappers/EmptyChatWrapper.js.map +1 -0
  14. package/dist/chatWrappers/FalconChatWrapper.d.ts +21 -0
  15. package/dist/chatWrappers/FalconChatWrapper.js +104 -0
  16. package/dist/chatWrappers/FalconChatWrapper.js.map +1 -0
  17. package/dist/chatWrappers/FunctionaryChatWrapper.d.ts +41 -0
  18. package/dist/chatWrappers/FunctionaryChatWrapper.js +200 -0
  19. package/dist/chatWrappers/FunctionaryChatWrapper.js.map +1 -0
  20. package/dist/chatWrappers/GeneralChatWrapper.d.ts +21 -0
  21. package/dist/chatWrappers/GeneralChatWrapper.js +112 -0
  22. package/dist/chatWrappers/GeneralChatWrapper.js.map +1 -0
  23. package/dist/chatWrappers/LlamaChatWrapper.d.ts +13 -0
  24. package/dist/chatWrappers/LlamaChatWrapper.js +78 -0
  25. package/dist/chatWrappers/LlamaChatWrapper.js.map +1 -0
  26. package/dist/chatWrappers/resolveChatWrapperBasedOnModel.d.ts +13 -0
  27. package/dist/chatWrappers/resolveChatWrapperBasedOnModel.js +55 -0
  28. package/dist/chatWrappers/resolveChatWrapperBasedOnModel.js.map +1 -0
  29. package/dist/cli/cli.js +1 -1
  30. package/dist/cli/cli.js.map +1 -1
  31. package/dist/cli/commands/ChatCommand.d.ts +2 -1
  32. package/dist/cli/commands/ChatCommand.js +90 -42
  33. package/dist/cli/commands/ChatCommand.js.map +1 -1
  34. package/dist/config.js +1 -1
  35. package/dist/config.js.map +1 -1
  36. package/dist/index.d.ts +20 -12
  37. package/dist/index.js +19 -11
  38. package/dist/index.js.map +1 -1
  39. package/dist/llamaEvaluator/LlamaBins.d.ts +18 -4
  40. package/dist/llamaEvaluator/LlamaBins.js +3 -3
  41. package/dist/llamaEvaluator/LlamaChat/LlamaChat.d.ts +175 -0
  42. package/dist/llamaEvaluator/LlamaChat/LlamaChat.js +704 -0
  43. package/dist/llamaEvaluator/LlamaChat/LlamaChat.js.map +1 -0
  44. package/dist/llamaEvaluator/LlamaChat/utils/FunctionCallGrammar.d.ts +21 -0
  45. package/dist/llamaEvaluator/LlamaChat/utils/FunctionCallGrammar.js +120 -0
  46. package/dist/llamaEvaluator/LlamaChat/utils/FunctionCallGrammar.js.map +1 -0
  47. package/dist/llamaEvaluator/LlamaChat/utils/contextShiftStrategies/eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy.d.ts +16 -0
  48. package/dist/llamaEvaluator/LlamaChat/utils/contextShiftStrategies/eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy.js +117 -0
  49. package/dist/llamaEvaluator/LlamaChat/utils/contextShiftStrategies/eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy.js.map +1 -0
  50. package/dist/llamaEvaluator/LlamaChatSession/LlamaChatSession.d.ts +146 -0
  51. package/dist/llamaEvaluator/LlamaChatSession/LlamaChatSession.js +211 -0
  52. package/dist/llamaEvaluator/LlamaChatSession/LlamaChatSession.js.map +1 -0
  53. package/dist/llamaEvaluator/LlamaChatSession/utils/defineChatSessionFunction.d.ts +7 -0
  54. package/dist/llamaEvaluator/LlamaChatSession/utils/defineChatSessionFunction.js +8 -0
  55. package/dist/llamaEvaluator/LlamaChatSession/utils/defineChatSessionFunction.js.map +1 -0
  56. package/dist/llamaEvaluator/LlamaContext/LlamaContext.d.ts +107 -0
  57. package/dist/llamaEvaluator/LlamaContext/LlamaContext.js +597 -0
  58. package/dist/llamaEvaluator/LlamaContext/LlamaContext.js.map +1 -0
  59. package/dist/llamaEvaluator/LlamaContext/types.d.ts +86 -0
  60. package/dist/llamaEvaluator/LlamaContext/types.js +2 -0
  61. package/dist/llamaEvaluator/LlamaContext/types.js.map +1 -0
  62. package/dist/llamaEvaluator/LlamaContext/utils/batchItemsPrioritizingStrategies/firstInFirstOutStrategy.d.ts +5 -0
  63. package/dist/llamaEvaluator/LlamaContext/utils/batchItemsPrioritizingStrategies/firstInFirstOutStrategy.js +16 -0
  64. package/dist/llamaEvaluator/LlamaContext/utils/batchItemsPrioritizingStrategies/firstInFirstOutStrategy.js.map +1 -0
  65. package/dist/llamaEvaluator/LlamaContext/utils/batchItemsPrioritizingStrategies/maximumParallelismStrategy.d.ts +5 -0
  66. package/dist/llamaEvaluator/LlamaContext/utils/batchItemsPrioritizingStrategies/maximumParallelismStrategy.js +42 -0
  67. package/dist/llamaEvaluator/LlamaContext/utils/batchItemsPrioritizingStrategies/maximumParallelismStrategy.js.map +1 -0
  68. package/dist/llamaEvaluator/LlamaContext/utils/resolveBatchItemsPrioritizingStrategy.d.ts +2 -0
  69. package/dist/llamaEvaluator/LlamaContext/utils/resolveBatchItemsPrioritizingStrategy.js +13 -0
  70. package/dist/llamaEvaluator/LlamaContext/utils/resolveBatchItemsPrioritizingStrategy.js.map +1 -0
  71. package/dist/llamaEvaluator/LlamaGrammar.d.ts +9 -13
  72. package/dist/llamaEvaluator/LlamaGrammar.js +10 -15
  73. package/dist/llamaEvaluator/LlamaGrammar.js.map +1 -1
  74. package/dist/llamaEvaluator/LlamaGrammarEvaluationState.d.ts +6 -5
  75. package/dist/llamaEvaluator/LlamaGrammarEvaluationState.js +8 -7
  76. package/dist/llamaEvaluator/LlamaGrammarEvaluationState.js.map +1 -1
  77. package/dist/llamaEvaluator/LlamaJsonSchemaGrammar.js +2 -1
  78. package/dist/llamaEvaluator/LlamaJsonSchemaGrammar.js.map +1 -1
  79. package/dist/llamaEvaluator/LlamaModel.d.ts +101 -105
  80. package/dist/llamaEvaluator/LlamaModel.js +305 -57
  81. package/dist/llamaEvaluator/LlamaModel.js.map +1 -1
  82. package/dist/types.d.ts +44 -4
  83. package/dist/types.js +5 -1
  84. package/dist/types.js.map +1 -1
  85. package/dist/utils/LlamaText.d.ts +42 -0
  86. package/dist/utils/LlamaText.js +207 -0
  87. package/dist/utils/LlamaText.js.map +1 -0
  88. package/dist/utils/ReplHistory.js +1 -1
  89. package/dist/utils/ReplHistory.js.map +1 -1
  90. package/dist/utils/StopGenerationDetector.d.ts +28 -0
  91. package/dist/utils/StopGenerationDetector.js +205 -0
  92. package/dist/utils/StopGenerationDetector.js.map +1 -0
  93. package/dist/utils/TokenStreamRegulator.d.ts +30 -0
  94. package/dist/utils/TokenStreamRegulator.js +96 -0
  95. package/dist/utils/TokenStreamRegulator.js.map +1 -0
  96. package/dist/utils/appendUserMessageToChatHistory.d.ts +2 -0
  97. package/dist/utils/appendUserMessageToChatHistory.js +18 -0
  98. package/dist/utils/appendUserMessageToChatHistory.js.map +1 -0
  99. package/dist/utils/compareTokens.d.ts +2 -0
  100. package/dist/utils/compareTokens.js +4 -0
  101. package/dist/utils/compareTokens.js.map +1 -0
  102. package/dist/utils/findCharacterRemovalCountToFitChatHistoryInContext.d.ts +18 -0
  103. package/dist/utils/findCharacterRemovalCountToFitChatHistoryInContext.js +61 -0
  104. package/dist/utils/findCharacterRemovalCountToFitChatHistoryInContext.js.map +1 -0
  105. package/dist/utils/gbnfJson/GbnfGrammarGenerator.d.ts +1 -0
  106. package/dist/utils/gbnfJson/GbnfGrammarGenerator.js +17 -0
  107. package/dist/utils/gbnfJson/GbnfGrammarGenerator.js.map +1 -1
  108. package/dist/utils/gbnfJson/GbnfTerminal.d.ts +1 -1
  109. package/dist/utils/gbnfJson/GbnfTerminal.js.map +1 -1
  110. package/dist/utils/gbnfJson/terminals/GbnfVerbatimText.d.ts +6 -0
  111. package/dist/utils/gbnfJson/terminals/GbnfVerbatimText.js +21 -0
  112. package/dist/utils/gbnfJson/terminals/GbnfVerbatimText.js.map +1 -0
  113. package/dist/utils/gbnfJson/types.d.ts +1 -1
  114. package/dist/utils/gbnfJson/types.js.map +1 -1
  115. package/dist/utils/gbnfJson/utils/validateObjectAgainstGbnfSchema.d.ts +1 -0
  116. package/dist/utils/gbnfJson/utils/validateObjectAgainstGbnfSchema.js.map +1 -1
  117. package/dist/utils/getBin.d.ts +71 -38
  118. package/dist/utils/getBin.js.map +1 -1
  119. package/dist/utils/getGbnfGrammarForGbnfJsonSchema.js +1 -15
  120. package/dist/utils/getGbnfGrammarForGbnfJsonSchema.js.map +1 -1
  121. package/dist/utils/getReleaseInfo.d.ts +1 -1
  122. package/dist/utils/getReleaseInfo.js.map +1 -1
  123. package/dist/utils/getTypeScriptTypeStringForGbnfJsonSchema.d.ts +2 -0
  124. package/dist/utils/getTypeScriptTypeStringForGbnfJsonSchema.js +49 -0
  125. package/dist/utils/getTypeScriptTypeStringForGbnfJsonSchema.js.map +1 -0
  126. package/dist/utils/parseModelFileName.d.ts +9 -0
  127. package/dist/utils/parseModelFileName.js +68 -0
  128. package/dist/utils/parseModelFileName.js.map +1 -0
  129. package/dist/utils/parseModelTypeDescription.d.ts +6 -0
  130. package/dist/utils/parseModelTypeDescription.js +9 -0
  131. package/dist/utils/parseModelTypeDescription.js.map +1 -0
  132. package/dist/utils/resolveChatWrapper.d.ts +4 -0
  133. package/dist/utils/resolveChatWrapper.js +16 -0
  134. package/dist/utils/resolveChatWrapper.js.map +1 -0
  135. package/dist/utils/truncateTextAndRoundToWords.d.ts +8 -0
  136. package/dist/utils/truncateTextAndRoundToWords.js +27 -0
  137. package/dist/utils/truncateTextAndRoundToWords.js.map +1 -0
  138. package/llama/.clang-format +10 -9
  139. package/llama/addon.cpp +701 -352
  140. package/llama/gitRelease.bundle +0 -0
  141. package/llamaBins/linux-arm64/llama-addon.node +0 -0
  142. package/llamaBins/linux-armv7l/llama-addon.node +0 -0
  143. package/llamaBins/linux-x64/llama-addon.node +0 -0
  144. package/llamaBins/mac-arm64/llama-addon.node +0 -0
  145. package/llamaBins/mac-x64/llama-addon.node +0 -0
  146. package/llamaBins/win-x64/llama-addon.node +0 -0
  147. package/package.json +24 -14
  148. package/dist/ChatPromptWrapper.d.ts +0 -11
  149. package/dist/ChatPromptWrapper.js +0 -20
  150. package/dist/ChatPromptWrapper.js.map +0 -1
  151. package/dist/chatWrappers/ChatMLChatPromptWrapper.d.ts +0 -12
  152. package/dist/chatWrappers/ChatMLChatPromptWrapper.js +0 -22
  153. package/dist/chatWrappers/ChatMLChatPromptWrapper.js.map +0 -1
  154. package/dist/chatWrappers/EmptyChatPromptWrapper.d.ts +0 -4
  155. package/dist/chatWrappers/EmptyChatPromptWrapper.js +0 -5
  156. package/dist/chatWrappers/EmptyChatPromptWrapper.js.map +0 -1
  157. package/dist/chatWrappers/FalconChatPromptWrapper.d.ts +0 -19
  158. package/dist/chatWrappers/FalconChatPromptWrapper.js +0 -33
  159. package/dist/chatWrappers/FalconChatPromptWrapper.js.map +0 -1
  160. package/dist/chatWrappers/GeneralChatPromptWrapper.d.ts +0 -19
  161. package/dist/chatWrappers/GeneralChatPromptWrapper.js +0 -38
  162. package/dist/chatWrappers/GeneralChatPromptWrapper.js.map +0 -1
  163. package/dist/chatWrappers/LlamaChatPromptWrapper.d.ts +0 -12
  164. package/dist/chatWrappers/LlamaChatPromptWrapper.js +0 -23
  165. package/dist/chatWrappers/LlamaChatPromptWrapper.js.map +0 -1
  166. package/dist/chatWrappers/createChatWrapperByBos.d.ts +0 -2
  167. package/dist/chatWrappers/createChatWrapperByBos.js +0 -14
  168. package/dist/chatWrappers/createChatWrapperByBos.js.map +0 -1
  169. package/dist/chatWrappers/generateContextTextFromConversationHistory.d.ts +0 -23
  170. package/dist/chatWrappers/generateContextTextFromConversationHistory.js +0 -47
  171. package/dist/chatWrappers/generateContextTextFromConversationHistory.js.map +0 -1
  172. package/dist/llamaEvaluator/LlamaChatSession.d.ts +0 -122
  173. package/dist/llamaEvaluator/LlamaChatSession.js +0 -236
  174. package/dist/llamaEvaluator/LlamaChatSession.js.map +0 -1
  175. package/dist/llamaEvaluator/LlamaContext.d.ts +0 -98
  176. package/dist/llamaEvaluator/LlamaContext.js +0 -140
  177. package/dist/llamaEvaluator/LlamaContext.js.map +0 -1
  178. package/dist/utils/getTextCompletion.d.ts +0 -3
  179. package/dist/utils/getTextCompletion.js +0 -12
  180. package/dist/utils/getTextCompletion.js.map +0 -1
  181. package/dist/utils/withLock.d.ts +0 -1
  182. package/dist/utils/withLock.js +0 -19
  183. package/dist/utils/withLock.js.map +0 -1
  184. package/llamaBins/mac-arm64/ggml-metal.metal +0 -5820
  185. package/llamaBins/mac-x64/ggml-metal.metal +0 -5820
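The file list above shows the scope of the 3.x rework: the 2.x ChatPromptWrapper hierarchy, getTextCompletion, and the single-sequence LlamaContext are removed, replaced by a ChatWrapper hierarchy, a lower-level LlamaChat, a sequence-based LlamaContext with batch-prioritizing strategies, and function calling (defineChatSessionFunction, FunctionaryChatWrapper). As rough orientation only, the new session layer plausibly composes as in the TypeScript sketch below; the class and function names come from the dist/ paths above, but the constructor options and the getSequence() accessor are assumptions, not taken from this diff.

// Hypothetical sketch of the 3.0.0-beta session API; the names come from the
// dist/ file paths above, but the option shapes and getSequence() are assumptions.
import {LlamaModel, LlamaContext, LlamaChatSession, defineChatSessionFunction} from "node-llama-cpp";

const model = new LlamaModel({modelPath: "model.gguf"});
const context = new LlamaContext({model});
const session = new LlamaChatSession({contextSequence: context.getSequence()});

const answer = await session.prompt("What time is it?", {
    functions: {
        getTime: defineChatSessionFunction({
            description: "Get the current time",
            handler: () => new Date().toISOString()
        })
    }
});
console.log(answer);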
package/llama/addon.cpp CHANGED
@@ -1,442 +1,791 @@
 #include <stddef.h>
+
 #include <algorithm>
 #include <sstream>
 #include <vector>
 
 #include "common.h"
-#include "llama.h"
 #include "common/grammar-parser.h"
+#include "llama.h"
 #include "napi.h"
 
-class LLAMAModel : public Napi::ObjectWrap<LLAMAModel> {
-    public:
-    llama_model_params model_params;
-    llama_model* model;
+std::string addon_model_token_to_piece(const struct llama_model * model, llama_token token) {
+    std::vector<char> result(8, 0);
+    const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size());
+    if (n_tokens < 0) {
+        result.resize(-n_tokens);
+        int check = llama_token_to_piece(model, token, result.data(), result.size());
+        GGML_ASSERT(check == -n_tokens);
+    }
+    else {
+        result.resize(n_tokens);
+    }
+
+    return std::string(result.data(), result.size());
+}
+
+class AddonModel : public Napi::ObjectWrap<AddonModel> {
+    public:
+    llama_model_params model_params;
+    llama_model* model;
+    bool disposed = false;
+
+    AddonModel(const Napi::CallbackInfo& info) : Napi::ObjectWrap<AddonModel>(info) {
+        model_params = llama_model_default_params();
+
+        // Get the model path
+        std::string modelPath = info[0].As<Napi::String>().Utf8Value();
+
+        if (info.Length() > 1 && info[1].IsObject()) {
+            Napi::Object options = info[1].As<Napi::Object>();
 
-    LLAMAModel(const Napi::CallbackInfo& info) : Napi::ObjectWrap<LLAMAModel>(info) {
-        model_params = llama_model_default_params();
+            if (options.Has("gpuLayers")) {
+                model_params.n_gpu_layers = options.Get("gpuLayers").As<Napi::Number>().Int32Value();
+            }
 
-        // Get the model path
-        std::string modelPath = info[0].As<Napi::String>().Utf8Value();
+            if (options.Has("vocabOnly")) {
+                model_params.vocab_only = options.Get("vocabOnly").As<Napi::Boolean>().Value();
+            }
 
-        if (info.Length() > 1 && info[1].IsObject()) {
-            Napi::Object options = info[1].As<Napi::Object>();
+            if (options.Has("useMmap")) {
+                model_params.use_mmap = options.Get("useMmap").As<Napi::Boolean>().Value();
+            }
 
-            if (options.Has("gpuLayers")) {
-                model_params.n_gpu_layers = options.Get("gpuLayers").As<Napi::Number>().Int32Value();
+            if (options.Has("useMlock")) {
+                model_params.use_mlock = options.Get("useMlock").As<Napi::Boolean>().Value();
+            }
         }
 
-            if (options.Has("vocabOnly")) {
-                model_params.vocab_only = options.Get("vocabOnly").As<Napi::Boolean>().Value();
+        llama_backend_init(false);
+        model = llama_load_model_from_file(modelPath.c_str(), model_params);
+
+        if (model == NULL) {
+            Napi::Error::New(info.Env(), "Failed to load model").ThrowAsJavaScriptException();
+            return;
         }
+    }
+
+    ~AddonModel() {
+        dispose();
+    }
 
-            if (options.Has("useMmap")) {
-                model_params.use_mmap = options.Get("useMmap").As<Napi::Boolean>().Value();
+    void dispose() {
+        if (disposed) {
+            return;
         }
 
-            if (options.Has("useMlock")) {
-                model_params.use_mlock = options.Get("useMlock").As<Napi::Boolean>().Value();
+        llama_free_model(model);
+        disposed = true;
+    }
+
+    Napi::Value Dispose(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            return info.Env().Undefined();
         }
+
+        dispose();
+
+        return info.Env().Undefined();
     }
 
-        llama_backend_init(false);
-        model = llama_load_model_from_file(modelPath.c_str(), model_params);
+    Napi::Value Tokenize(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
+
+        std::string text = info[0].As<Napi::String>().Utf8Value();
+        bool specialTokens = info[1].As<Napi::Boolean>().Value();
+
+        std::vector<llama_token> tokens = llama_tokenize(model, text, false, specialTokens);
 
-        if (model == NULL) {
-            Napi::Error::New(info.Env(), "Failed to load model").ThrowAsJavaScriptException();
-            return;
+        Napi::Uint32Array result = Napi::Uint32Array::New(info.Env(), tokens.size());
+        for (size_t i = 0; i < tokens.size(); ++i) {
+            result[i] = static_cast<uint32_t>(tokens[i]);
+        }
+
+        return result;
     }
-    }
+    Napi::Value Detokenize(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
 
-    ~LLAMAModel() {
-        llama_free_model(model);
-    }
+        Napi::Uint32Array tokens = info[0].As<Napi::Uint32Array>();
 
-    static void init(Napi::Object exports) {
-        exports.Set("LLAMAModel", DefineClass(exports.Env(), "LLAMAModel", {}));
-    }
-};
+        // Create a stringstream for accumulating the decoded string.
+        std::stringstream ss;
 
-class LLAMAGrammar : public Napi::ObjectWrap<LLAMAGrammar> {
-    public:
-    grammar_parser::parse_state parsed_grammar;
+        // Decode each token and accumulate the result.
+        for (size_t i = 0; i < tokens.ElementLength(); i++) {
+            const std::string piece = addon_model_token_to_piece(model, (llama_token)tokens[i]);
 
-    LLAMAGrammar(const Napi::CallbackInfo& info) : Napi::ObjectWrap<LLAMAGrammar>(info) {
-        // Get the model path
-        std::string grammarCode = info[0].As<Napi::String>().Utf8Value();
-        bool should_print_grammar = false;
+            if (piece.empty()) {
+                continue;
+            }
 
-        if (info.Length() > 1 && info[1].IsObject()) {
-            Napi::Object options = info[1].As<Napi::Object>();
+            ss << piece;
+        }
 
-            if (options.Has("printGrammar")) {
-                should_print_grammar = options.Get("printGrammar").As<Napi::Boolean>().Value();
+        return Napi::String::New(info.Env(), ss.str());
+    }
+
+    Napi::Value GetTrainContextSize(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
         }
+
+        return Napi::Number::From(info.Env(), llama_n_ctx_train(model));
     }
 
-        parsed_grammar = grammar_parser::parse(grammarCode.c_str());
-        // will be empty (default) if there are parse errors
-        if (parsed_grammar.rules.empty()) {
-            Napi::Error::New(info.Env(), "Failed to parse grammar").ThrowAsJavaScriptException();
-            return;
+    Napi::Value GetTotalSize(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
+
+        return Napi::Number::From(info.Env(), llama_model_size(model));
     }
 
-        if (should_print_grammar) {
-            grammar_parser::print_grammar(stderr, parsed_grammar);
+    Napi::Value GetTotalParameters(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
+
+        return Napi::Number::From(info.Env(), llama_model_n_params(model));
     }
-    }
 
-    static void init(Napi::Object exports) {
-        exports.Set("LLAMAGrammar", DefineClass(exports.Env(), "LLAMAGrammar", {}));
-    }
-};
+    Napi::Value GetModelDescription(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
 
-class LLAMAGrammarEvaluationState : public Napi::ObjectWrap<LLAMAGrammarEvaluationState> {
-    public:
-    LLAMAGrammar* grammarDef;
-    llama_grammar *grammar = nullptr;
+        char model_desc[128];
+        int actual_length = llama_model_desc(model, model_desc, sizeof(model_desc));
 
-    LLAMAGrammarEvaluationState(const Napi::CallbackInfo& info) : Napi::ObjectWrap<LLAMAGrammarEvaluationState>(info) {
-        grammarDef = Napi::ObjectWrap<LLAMAGrammar>::Unwrap(info[0].As<Napi::Object>());
-        grammarDef->Ref();
+        return Napi::String::New(info.Env(), model_desc, actual_length);
+    }
 
-        std::vector<const llama_grammar_element *> grammar_rules(grammarDef->parsed_grammar.c_rules());
-        grammar = llama_grammar_init(
-            grammar_rules.data(), grammar_rules.size(), grammarDef->parsed_grammar.symbol_ids.at("root")
-        );
-    }
+    Napi::Value TokenBos(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
 
-    ~LLAMAGrammarEvaluationState() {
-        grammarDef->Unref();
+        return Napi::Number::From(info.Env(), llama_token_bos(model));
+    }
+    Napi::Value TokenEos(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
 
-        if (grammar != nullptr) {
-            llama_grammar_free(grammar);
-            grammar = nullptr;
+        return Napi::Number::From(info.Env(), llama_token_eos(model));
     }
-    }
+    Napi::Value TokenNl(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
 
-    static void init(Napi::Object exports) {
-        exports.Set("LLAMAGrammarEvaluationState", DefineClass(exports.Env(), "LLAMAGrammarEvaluationState", {}));
-    }
+        return Napi::Number::From(info.Env(), llama_token_nl(model));
+    }
+    Napi::Value PrefixToken(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
+
+        return Napi::Number::From(info.Env(), llama_token_prefix(model));
+    }
+    Napi::Value MiddleToken(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
+
+        return Napi::Number::From(info.Env(), llama_token_middle(model));
+    }
+    Napi::Value SuffixToken(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
+
+        return Napi::Number::From(info.Env(), llama_token_suffix(model));
+    }
+    Napi::Value EotToken(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
+
+        return Napi::Number::From(info.Env(), llama_token_eot(model));
+    }
+    Napi::Value GetTokenString(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
+
+        int token = info[0].As<Napi::Number>().Int32Value();
+        std::stringstream ss;
+
+        const char* str = llama_token_get_text(model, token);
+        if (str == nullptr) {
+            return info.Env().Undefined();
+        }
+
+        ss << str;
+
+        return Napi::String::New(info.Env(), ss.str());
+    }
+
+    static void init(Napi::Object exports) {
+        exports.Set(
+            "AddonModel",
+            DefineClass(
+                exports.Env(),
+                "AddonModel",
+                {
+                    InstanceMethod("tokenize", &AddonModel::Tokenize),
+                    InstanceMethod("detokenize", &AddonModel::Detokenize),
+                    InstanceMethod("getTrainContextSize", &AddonModel::GetTrainContextSize),
+                    InstanceMethod("getTotalSize", &AddonModel::GetTotalSize),
+                    InstanceMethod("getTotalParameters", &AddonModel::GetTotalParameters),
+                    InstanceMethod("getModelDescription", &AddonModel::GetModelDescription),
+                    InstanceMethod("tokenBos", &AddonModel::TokenBos),
+                    InstanceMethod("tokenEos", &AddonModel::TokenEos),
+                    InstanceMethod("tokenNl", &AddonModel::TokenNl),
+                    InstanceMethod("prefixToken", &AddonModel::PrefixToken),
+                    InstanceMethod("middleToken", &AddonModel::MiddleToken),
+                    InstanceMethod("suffixToken", &AddonModel::SuffixToken),
+                    InstanceMethod("eotToken", &AddonModel::EotToken),
+                    InstanceMethod("getTokenString", &AddonModel::GetTokenString),
+                    InstanceMethod("dispose", &AddonModel::Dispose)
+                }
+            )
+        );
+    }
 };
 
-class LLAMAContext : public Napi::ObjectWrap<LLAMAContext> {
-    public:
-    LLAMAModel* model;
-    llama_context_params context_params;
-    llama_context* ctx;
-    int n_cur = 0;
-
-    LLAMAContext(const Napi::CallbackInfo& info) : Napi::ObjectWrap<LLAMAContext>(info) {
-        model = Napi::ObjectWrap<LLAMAModel>::Unwrap(info[0].As<Napi::Object>());
-        model->Ref();
-
-        context_params = llama_context_default_params();
-        context_params.seed = -1;
-        context_params.n_ctx = 4096;
-        context_params.n_threads = 6;
-        context_params.n_threads_batch == -1 ? context_params.n_threads : context_params.n_threads_batch;
-
-        if (info.Length() > 1 && info[1].IsObject()) {
-            Napi::Object options = info[1].As<Napi::Object>();
-
-            if (options.Has("seed")) {
-                context_params.seed = options.Get("seed").As<Napi::Number>().Int32Value();
-            }
-
-            if (options.Has("contextSize")) {
-                context_params.n_ctx = options.Get("contextSize").As<Napi::Number>().Int32Value();
-            }
-
-            if (options.Has("batchSize")) {
-                context_params.n_batch = options.Get("batchSize").As<Napi::Number>().Int32Value();
-            }
-
-            if (options.Has("logitsAll")) {
-                context_params.logits_all = options.Get("logitsAll").As<Napi::Boolean>().Value();
-            }
-
-            if (options.Has("embedding")) {
-                context_params.embedding = options.Get("embedding").As<Napi::Boolean>().Value();
-            }
-
-            if (options.Has("threads")) {
-                context_params.n_threads = options.Get("threads").As<Napi::Number>().Int32Value();
-                context_params.n_threads_batch == -1 ? context_params.n_threads : context_params.n_threads_batch;
-            }
-        }
+class AddonGrammar : public Napi::ObjectWrap<AddonGrammar> {
+    public:
+    grammar_parser::parse_state parsed_grammar;
 
-        ctx = llama_new_context_with_model(model->model, context_params);
-        Napi::MemoryManagement::AdjustExternalMemory(Env(), llama_get_state_size(ctx));
-    }
-    ~LLAMAContext() {
-        Napi::MemoryManagement::AdjustExternalMemory(Env(), -(int64_t)llama_get_state_size(ctx));
-        llama_free(ctx);
-        model->Unref();
-    }
-    Napi::Value Encode(const Napi::CallbackInfo& info) {
-        std::string text = info[0].As<Napi::String>().Utf8Value();
+    AddonGrammar(const Napi::CallbackInfo& info) : Napi::ObjectWrap<AddonGrammar>(info) {
+        // Get the model path
+        std::string grammarCode = info[0].As<Napi::String>().Utf8Value();
+        bool should_print_grammar = false;
 
-        std::vector<llama_token> tokens = llama_tokenize(ctx, text, false);
+        if (info.Length() > 1 && info[1].IsObject()) {
+            Napi::Object options = info[1].As<Napi::Object>();
 
-        Napi::Uint32Array result = Napi::Uint32Array::New(info.Env(), tokens.size());
-        for (size_t i = 0; i < tokens.size(); ++i) { result[i] = static_cast<uint32_t>(tokens[i]); }
+            if (options.Has("printGrammar")) {
+                should_print_grammar = options.Get("printGrammar").As<Napi::Boolean>().Value();
+            }
+        }
+
+        parsed_grammar = grammar_parser::parse(grammarCode.c_str());
+        // will be empty (default) if there are parse errors
+        if (parsed_grammar.rules.empty()) {
+            Napi::Error::New(info.Env(), "Failed to parse grammar").ThrowAsJavaScriptException();
+            return;
+        }
+
+        if (should_print_grammar) {
+            grammar_parser::print_grammar(stderr, parsed_grammar);
+        }
+    }
 
-        return result;
-    }
-    Napi::Value Decode(const Napi::CallbackInfo& info) {
-        Napi::Uint32Array tokens = info[0].As<Napi::Uint32Array>();
+    static void init(Napi::Object exports) {
+        exports.Set("AddonGrammar", DefineClass(exports.Env(), "AddonGrammar", {}));
+    }
+};
 
-        // Create a stringstream for accumulating the decoded string.
-        std::stringstream ss;
+class AddonGrammarEvaluationState : public Napi::ObjectWrap<AddonGrammarEvaluationState> {
+    public:
+    AddonGrammar* grammarDef;
+    llama_grammar* grammar = nullptr;
 
-        // Decode each token and accumulate the result.
-        for (size_t i = 0; i < tokens.ElementLength(); i++) {
-            const std::string piece = llama_token_to_piece(ctx, (llama_token)tokens[i]);
+    AddonGrammarEvaluationState(const Napi::CallbackInfo& info) : Napi::ObjectWrap<AddonGrammarEvaluationState>(info) {
+        grammarDef = Napi::ObjectWrap<AddonGrammar>::Unwrap(info[0].As<Napi::Object>());
+        grammarDef->Ref();
 
-            if (piece.empty()) {
-                continue;
+        std::vector<const llama_grammar_element*> grammar_rules(grammarDef->parsed_grammar.c_rules());
+        grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), grammarDef->parsed_grammar.symbol_ids.at("root"));
     }
 
-            ss << piece;
-        }
+    ~AddonGrammarEvaluationState() {
+        grammarDef->Unref();
 
-        return Napi::String::New(info.Env(), ss.str());
-    }
-    Napi::Value TokenBos(const Napi::CallbackInfo& info) {
-        return Napi::Number::From(info.Env(), llama_token_bos(model->model)); // TODO: move this to the model
-    }
-    Napi::Value TokenEos(const Napi::CallbackInfo& info) {
-        return Napi::Number::From(info.Env(), llama_token_eos(model->model)); // TODO: move this to the model
-    }
-    Napi::Value TokenNl(const Napi::CallbackInfo& info) {
-        return Napi::Number::From(info.Env(), llama_token_nl(model->model)); // TODO: move this to the model
-    }
-    Napi::Value GetContextSize(const Napi::CallbackInfo& info) {
-        return Napi::Number::From(info.Env(), llama_n_ctx(ctx));
-    }
-    Napi::Value GetTokenString(const Napi::CallbackInfo& info) {
-        int token = info[0].As<Napi::Number>().Int32Value();
-        std::stringstream ss;
-
-        const char* str = llama_token_get_text(model->model, token); // TODO: move this to the model
-        if (str == nullptr) {
-            return info.Env().Undefined();
-        }
+        if (grammar != nullptr) {
+            llama_grammar_free(grammar);
+            grammar = nullptr;
+        }
+    }
 
-        ss << str;
-
-        return Napi::String::New(info.Env(), ss.str());
-    }
-    Napi::Value Eval(const Napi::CallbackInfo& info);
-    static void init(Napi::Object exports) {
-        exports.Set("LLAMAContext",
-            DefineClass(exports.Env(),
-                "LLAMAContext",
-                {
-                    InstanceMethod("encode", &LLAMAContext::Encode),
-                    InstanceMethod("decode", &LLAMAContext::Decode),
-                    InstanceMethod("tokenBos", &LLAMAContext::TokenBos),
-                    InstanceMethod("tokenEos", &LLAMAContext::TokenEos),
-                    InstanceMethod("tokenNl", &LLAMAContext::TokenNl),
-                    InstanceMethod("getContextSize", &LLAMAContext::GetContextSize),
-                    InstanceMethod("getTokenString", &LLAMAContext::GetTokenString),
-                    InstanceMethod("eval", &LLAMAContext::Eval),
-                }));
-    }
+    static void init(Napi::Object exports) {
+        exports.Set("AddonGrammarEvaluationState", DefineClass(exports.Env(), "AddonGrammarEvaluationState", {}));
+    }
 };
 
+class AddonContext : public Napi::ObjectWrap<AddonContext> {
+    public:
+    AddonModel* model;
+    llama_context_params context_params;
+    llama_context* ctx;
+    llama_batch batch;
+    bool has_batch = false;
+    int32_t batch_n_tokens = 0;
+    int n_cur = 0;
+    bool disposed = false;
+
+    AddonContext(const Napi::CallbackInfo& info) : Napi::ObjectWrap<AddonContext>(info) {
+        model = Napi::ObjectWrap<AddonModel>::Unwrap(info[0].As<Napi::Object>());
+        model->Ref();
+
+        context_params = llama_context_default_params();
+        context_params.seed = -1;
+        context_params.n_ctx = 4096;
+        context_params.n_threads = 6;
+        context_params.n_threads_batch = context_params.n_threads;
+
+        if (info.Length() > 1 && info[1].IsObject()) {
+            Napi::Object options = info[1].As<Napi::Object>();
+
+            if (options.Has("seed")) {
+                context_params.seed = options.Get("seed").As<Napi::Number>().Uint32Value();
+            }
+
+            if (options.Has("contextSize")) {
+                context_params.n_ctx = options.Get("contextSize").As<Napi::Number>().Uint32Value();
+            }
+
+            if (options.Has("batchSize")) {
+                context_params.n_batch = options.Get("batchSize").As<Napi::Number>().Uint32Value();
+            }
+
+            if (options.Has("logitsAll")) {
+                context_params.logits_all = options.Get("logitsAll").As<Napi::Boolean>().Value();
+            }
+
+            if (options.Has("embedding")) {
+                context_params.embedding = options.Get("embedding").As<Napi::Boolean>().Value();
+            }
+
+            if (options.Has("threads")) {
+                const auto n_threads = options.Get("threads").As<Napi::Number>().Uint32Value();
+                const auto resolved_n_threads = n_threads == 0 ? std::thread::hardware_concurrency() : n_threads;
+
+                context_params.n_threads = resolved_n_threads;
+                context_params.n_threads_batch = resolved_n_threads;
+            }
+        }
 
-class LLAMAContextEvalWorker : Napi::AsyncWorker, Napi::Promise::Deferred {
-    LLAMAContext* ctx;
-    LLAMAGrammarEvaluationState* grammar_evaluation_state;
-    bool use_grammar = false;
-    std::vector<llama_token> tokens;
-    llama_token result;
-    float temperature;
-    int32_t top_k;
-    float top_p;
-    float repeat_penalty = 1.10f; // 1.0 = disabled
-    float repeat_penalty_presence_penalty = 0.00f; // 0.0 = disabled
-    float repeat_penalty_frequency_penalty = 0.00f; // 0.0 = disabled
-    std::vector<llama_token> repeat_penalty_tokens;
-    bool use_repeat_penalty = false;
-
-    public:
-    LLAMAContextEvalWorker(const Napi::CallbackInfo& info, LLAMAContext* ctx) : Napi::AsyncWorker(info.Env(), "LLAMAContextEvalWorker"), ctx(ctx), Napi::Promise::Deferred(info.Env()) {
-        ctx->Ref();
-        Napi::Uint32Array tokens = info[0].As<Napi::Uint32Array>();
-
-        temperature = 0.0f;
-        top_k = 40;
-        top_p = 0.95f;
-
-        if (info.Length() > 1 && info[1].IsObject()) {
-            Napi::Object options = info[1].As<Napi::Object>();
-
-            if (options.Has("temperature")) {
-                temperature = options.Get("temperature").As<Napi::Number>().FloatValue();
-            }
-
-            if (options.Has("topK")) {
-                top_k = options.Get("topK").As<Napi::Number>().Int32Value();
-            }
-
-            if (options.Has("topP")) {
-                top_p = options.Get("topP").As<Napi::Number>().FloatValue();
-            }
-
-            if (options.Has("repeatPenalty")) {
-                repeat_penalty = options.Get("repeatPenalty").As<Napi::Number>().FloatValue();
-            }
-
-            if (options.Has("repeatPenaltyTokens")) {
-                Napi::Uint32Array repeat_penalty_tokens_uint32_array = options.Get("repeatPenaltyTokens").As<Napi::Uint32Array>();
-
-                repeat_penalty_tokens.reserve(repeat_penalty_tokens_uint32_array.ElementLength());
-                for (size_t i = 0; i < repeat_penalty_tokens_uint32_array.ElementLength(); i++) {
-                    repeat_penalty_tokens.push_back(static_cast<llama_token>(repeat_penalty_tokens_uint32_array[i]));
-                }
-
-                use_repeat_penalty = true;
-            }
-
-            if (options.Has("repeatPenaltyPresencePenalty")) {
-                repeat_penalty_presence_penalty = options.Get("repeatPenaltyPresencePenalty").As<Napi::Number>().FloatValue();
-            }
-
-            if (options.Has("repeatPenaltyFrequencyPenalty")) {
-                repeat_penalty_frequency_penalty = options.Get("repeatPenaltyFrequencyPenalty").As<Napi::Number>().FloatValue();
-            }
-
-            if (options.Has("grammarEvaluationState")) {
-                grammar_evaluation_state = Napi::ObjectWrap<LLAMAGrammarEvaluationState>::Unwrap(options.Get("grammarEvaluationState").As<Napi::Object>());
-                grammar_evaluation_state->Ref();
-                use_grammar = true;
-            }
-        }
+        ctx = llama_new_context_with_model(model->model, context_params);
+        Napi::MemoryManagement::AdjustExternalMemory(Env(), llama_get_state_size(ctx));
+    }
+    ~AddonContext() {
+        dispose();
+    }
 
-        this->tokens.reserve(tokens.ElementLength());
-        for (size_t i = 0; i < tokens.ElementLength(); i++) { this->tokens.push_back(static_cast<llama_token>(tokens[i])); }
-    }
-    ~LLAMAContextEvalWorker() {
-        ctx->Unref();
+    void dispose() {
+        if (disposed) {
+            return;
+        }
 
-        if (use_grammar) {
-            grammar_evaluation_state->Unref();
-            use_grammar = false;
-        }
-    }
-    using Napi::AsyncWorker::Queue;
-    using Napi::Promise::Deferred::Promise;
+        Napi::MemoryManagement::AdjustExternalMemory(Env(), -(int64_t)llama_get_state_size(ctx));
+        llama_free(ctx);
+        model->Unref();
 
-    protected:
-    void Execute() {
-        llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
+        disposeBatch();
 
-        for (size_t i = 0; i < tokens.size(); i++) {
-            llama_batch_add(batch, tokens[i], ctx->n_cur, { 0 }, false);
+        disposed = true;
+    }
+    void disposeBatch() {
+        if (!has_batch) {
+            return;
+        }
 
-            ctx->n_cur++;
-        }
-        GGML_ASSERT(batch.n_tokens == (int) tokens.size());
+        llama_batch_free(batch);
+        has_batch = false;
+        batch_n_tokens = 0;
+    }
+    Napi::Value Dispose(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            return info.Env().Undefined();
+        }
 
-        batch.logits[batch.n_tokens - 1] = true;
+        dispose();
 
-        // Perform the evaluation using llama_decode.
-        int r = llama_decode(ctx->ctx, batch);
+        return info.Env().Undefined();
+    }
+    Napi::Value GetContextSize(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
 
-        llama_batch_free(batch);
+        return Napi::Number::From(info.Env(), llama_n_ctx(ctx));
+    }
+    Napi::Value InitBatch(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
 
-        if (r != 0) {
-            if (r == 1) {
-                SetError("could not find a KV slot for the batch (try reducing the size of the batch or increase the context)");
-            } else {
-                SetError("Eval has failed");
-            }
+        if (has_batch) {
+            llama_batch_free(batch);
+        }
 
-            return;
-        }
+        int32_t n_tokens = info[0].As<Napi::Number>().Int32Value();
 
-        llama_token new_token_id = 0;
+        batch = llama_batch_init(n_tokens, 0, 1);
+        has_batch = true;
+        batch_n_tokens = n_tokens;
 
-        // Select the best prediction.
-        auto logits = llama_get_logits_ith(ctx->ctx, batch.n_tokens - 1);
-        auto n_vocab = llama_n_vocab(ctx->model->model);
+        return info.Env().Undefined();
+    }
+    Napi::Value DisposeBatch(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
 
-        std::vector<llama_token_data> candidates;
-        candidates.reserve(n_vocab);
+        disposeBatch();
 
-        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-            candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-        }
+        return info.Env().Undefined();
+    }
+    Napi::Value AddToBatch(const Napi::CallbackInfo& info) {
+        if (!has_batch) {
+            Napi::Error::New(info.Env(), "No batch is initialized").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
 
-        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+        int32_t sequenceId = info[0].As<Napi::Number>().Int32Value();
+        int32_t firstTokenContextIndex = info[1].As<Napi::Number>().Int32Value();
+        Napi::Uint32Array tokens = info[2].As<Napi::Uint32Array>();
+        bool generateLogitAtTheEnd = info[3].As<Napi::Boolean>().Value();
 
-        auto eos_token = llama_token_eos(ctx->model->model);
+        auto tokensLength = tokens.ElementLength();
+        GGML_ASSERT(batch.n_tokens + tokensLength <= batch_n_tokens);
 
-        if (use_repeat_penalty && !repeat_penalty_tokens.empty()) {
-            llama_sample_repetition_penalties(
-                ctx->ctx, &candidates_p, repeat_penalty_tokens.data(), repeat_penalty_tokens.size(), repeat_penalty,
-                repeat_penalty_frequency_penalty, repeat_penalty_presence_penalty
-            );
-        }
+        for (size_t i = 0; i < tokensLength; i++) {
+            llama_batch_add(batch, static_cast<llama_token>(tokens[i]), firstTokenContextIndex + i, { sequenceId }, false);
+        }
 
-        if (use_grammar && (grammar_evaluation_state)->grammar != nullptr) {
-            llama_sample_grammar(ctx->ctx, &candidates_p, (grammar_evaluation_state)->grammar);
-        }
+        if (generateLogitAtTheEnd) {
+            batch.logits[batch.n_tokens - 1] = true;
 
-        if (temperature <= 0) {
-            new_token_id = llama_sample_token_greedy(ctx->ctx , &candidates_p);
-        } else {
-            const int32_t resolved_top_k = top_k <= 0 ? llama_n_vocab(ctx->model->model) : std::min(top_k, llama_n_vocab(ctx->model->model));
-            const int32_t n_probs = 0; // Number of probabilities to keep - 0 = disabled
-            const float tfs_z = 1.00f; // Tail free sampling - 1.0 = disabled
-            const float typical_p = 1.00f; // Typical probability - 1.0 = disabled
-            const float resolved_top_p = top_p; // Top p sampling - 1.0 = disabled
-
-            // Temperature sampling
-            size_t min_keep = std::max(1, n_probs);
-            llama_sample_top_k(ctx->ctx, &candidates_p, resolved_top_k, min_keep);
-            llama_sample_tail_free(ctx->ctx, &candidates_p, tfs_z, min_keep);
-            llama_sample_typical(ctx->ctx, &candidates_p, typical_p, min_keep);
-            llama_sample_top_p(ctx->ctx, &candidates_p, resolved_top_p, min_keep);
-            llama_sample_temperature(ctx->ctx, &candidates_p, temperature);
-            new_token_id = llama_sample_token(ctx->ctx, &candidates_p);
-        }
+            auto logit_index = batch.n_tokens - 1;
 
-        if (new_token_id != eos_token && use_grammar && (grammar_evaluation_state)->grammar != nullptr) {
-            llama_grammar_accept_token(ctx->ctx, (grammar_evaluation_state)->grammar, new_token_id);
-        }
+            return Napi::Number::From(info.Env(), logit_index);
+        }
+
+        return info.Env().Undefined();
+    }
+    Napi::Value DisposeSequence(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
+
+        int32_t sequenceId = info[0].As<Napi::Number>().Int32Value();
+
+        llama_kv_cache_seq_rm(ctx, sequenceId, -1, -1);
+
+        return info.Env().Undefined();
+    }
+    Napi::Value RemoveTokenCellsFromSequence(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
 
-        result = new_token_id;
-    }
-    void OnOK() {
-        Napi::Env env = Napi::AsyncWorker::Env();
-        Napi::Number resultValue = Napi::Number::New(env, static_cast<uint32_t>(result));
-        Napi::Promise::Deferred::Resolve(resultValue);
-    }
-    void OnError(const Napi::Error& err) { Napi::Promise::Deferred::Reject(err.Value()); }
+        int32_t sequenceId = info[0].As<Napi::Number>().Int32Value();
+        int32_t startPos = info[1].As<Napi::Number>().Int32Value();
+        int32_t endPos = info[2].As<Napi::Number>().Int32Value();
+
+        llama_kv_cache_seq_rm(ctx, sequenceId, startPos, endPos);
+
+        return info.Env().Undefined();
+    }
+    Napi::Value ShiftSequenceTokenCells(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
+
+        int32_t sequenceId = info[0].As<Napi::Number>().Int32Value();
+        int32_t startPos = info[1].As<Napi::Number>().Int32Value();
+        int32_t endPos = info[2].As<Napi::Number>().Int32Value();
+        int32_t shiftDelta = info[3].As<Napi::Number>().Int32Value();
+
+        llama_kv_cache_seq_shift(ctx, sequenceId, startPos, endPos, shiftDelta);
+
+        return info.Env().Undefined();
+    }
+    Napi::Value DecodeBatch(const Napi::CallbackInfo& info);
+    Napi::Value SampleToken(const Napi::CallbackInfo& info);
+
+    Napi::Value AcceptGrammarEvaluationStateToken(const Napi::CallbackInfo& info) {
+        AddonGrammarEvaluationState* grammar_evaluation_state = Napi::ObjectWrap<AddonGrammarEvaluationState>::Unwrap(info[0].As<Napi::Object>());
+        llama_token tokenId = info[1].As<Napi::Number>().Int32Value();
+
+        if ((grammar_evaluation_state)->grammar != nullptr) {
+            llama_grammar_accept_token(ctx, (grammar_evaluation_state)->grammar, tokenId);
+        }
+
+        return info.Env().Undefined();
+    }
+
+    static void init(Napi::Object exports) {
+        exports.Set(
+            "AddonContext",
+            DefineClass(
+                exports.Env(),
+                "AddonContext",
+                {
+                    InstanceMethod("getContextSize", &AddonContext::GetContextSize),
+                    InstanceMethod("initBatch", &AddonContext::InitBatch),
+                    InstanceMethod("addToBatch", &AddonContext::AddToBatch),
+                    InstanceMethod("disposeSequence", &AddonContext::DisposeSequence),
+                    InstanceMethod("removeTokenCellsFromSequence", &AddonContext::RemoveTokenCellsFromSequence),
+                    InstanceMethod("shiftSequenceTokenCells", &AddonContext::ShiftSequenceTokenCells),
+                    InstanceMethod("decodeBatch", &AddonContext::DecodeBatch),
+                    InstanceMethod("sampleToken", &AddonContext::SampleToken),
+                    InstanceMethod("acceptGrammarEvaluationStateToken", &AddonContext::AcceptGrammarEvaluationStateToken),
+                    InstanceMethod("dispose", &AddonContext::Dispose)
+                }
+            )
+        );
+    }
+};
+
+
+class AddonContextDecodeBatchWorker : Napi::AsyncWorker, Napi::Promise::Deferred {
+    public:
+    AddonContext* ctx;
+
+    AddonContextDecodeBatchWorker(const Napi::CallbackInfo& info, AddonContext* ctx)
+        : Napi::AsyncWorker(info.Env(), "AddonContextDecodeBatchWorker"),
+          ctx(ctx),
+          Napi::Promise::Deferred(info.Env()) {
+        ctx->Ref();
+    }
+    ~AddonContextDecodeBatchWorker() {
+        ctx->Unref();
+    }
+    using Napi::AsyncWorker::Queue;
+    using Napi::Promise::Deferred::Promise;
+
+    protected:
+    void Execute() {
+        // Perform the evaluation using llama_decode.
+        int r = llama_decode(ctx->ctx, ctx->batch);
+
+        if (r != 0) {
+            if (r == 1) {
+                SetError("could not find a KV slot for the batch (try reducing the size of the batch or increase the context)");
+            } else {
+                SetError("Eval has failed");
+            }
+
+            return;
+        }
+    }
+    void OnOK() {
+        Napi::Env env = Napi::AsyncWorker::Env();
+        Napi::Promise::Deferred::Resolve(env.Undefined());
+    }
+    void OnError(const Napi::Error& err) {
+        Napi::Promise::Deferred::Reject(err.Value());
+    }
 };
 
-Napi::Value LLAMAContext::Eval(const Napi::CallbackInfo& info) {
-    LLAMAContextEvalWorker* worker = new LLAMAContextEvalWorker(info, this);
-    worker->Queue();
-    return worker->Promise();
+Napi::Value AddonContext::DecodeBatch(const Napi::CallbackInfo& info) {
+    AddonContextDecodeBatchWorker* worker = new AddonContextDecodeBatchWorker(info, this);
+    worker->Queue();
+    return worker->Promise();
 }
 
-Napi::Value systemInfo(const Napi::CallbackInfo& info) { return Napi::String::From(info.Env(), llama_print_system_info()); }
+class AddonContextSampleTokenWorker : Napi::AsyncWorker, Napi::Promise::Deferred {
+    public:
+    AddonContext* ctx;
+    AddonGrammarEvaluationState* grammar_evaluation_state;
+    int32_t batchLogitIndex;
+    bool use_grammar = false;
+    llama_token result;
+    float temperature = 0.0f;
+    int32_t top_k = 40;
+    float top_p = 0.95f;
+    float repeat_penalty = 1.10f; // 1.0 = disabled
+    float repeat_penalty_presence_penalty = 0.00f; // 0.0 = disabled
+    float repeat_penalty_frequency_penalty = 0.00f; // 0.0 = disabled
+    std::vector<llama_token> repeat_penalty_tokens;
+    bool use_repeat_penalty = false;
+
+    AddonContextSampleTokenWorker(const Napi::CallbackInfo& info, AddonContext* ctx)
+        : Napi::AsyncWorker(info.Env(), "AddonContextSampleTokenWorker"),
+          ctx(ctx),
+          Napi::Promise::Deferred(info.Env()) {
+        ctx->Ref();
+
+        batchLogitIndex = info[0].As<Napi::Number>().Int32Value();
+
+        if (info.Length() > 1 && info[1].IsObject()) {
+            Napi::Object options = info[1].As<Napi::Object>();
+
+            if (options.Has("temperature")) {
+                temperature = options.Get("temperature").As<Napi::Number>().FloatValue();
+            }
+
+            if (options.Has("topK")) {
+                top_k = options.Get("topK").As<Napi::Number>().Int32Value();
+            }
+
+            if (options.Has("topP")) {
+                top_p = options.Get("topP").As<Napi::Number>().FloatValue();
+            }
+
+            if (options.Has("repeatPenalty")) {
+                repeat_penalty = options.Get("repeatPenalty").As<Napi::Number>().FloatValue();
+            }
+
+            if (options.Has("repeatPenaltyTokens")) {
+                Napi::Uint32Array repeat_penalty_tokens_uint32_array = options.Get("repeatPenaltyTokens").As<Napi::Uint32Array>();
+
+                repeat_penalty_tokens.reserve(repeat_penalty_tokens_uint32_array.ElementLength());
+                for (size_t i = 0; i < repeat_penalty_tokens_uint32_array.ElementLength(); i++) {
+                    repeat_penalty_tokens.push_back(static_cast<llama_token>(repeat_penalty_tokens_uint32_array[i]));
+                }
+
+                use_repeat_penalty = true;
+            }
+
+            if (options.Has("repeatPenaltyPresencePenalty")) {
+                repeat_penalty_presence_penalty = options.Get("repeatPenaltyPresencePenalty").As<Napi::Number>().FloatValue();
+            }
+
+            if (options.Has("repeatPenaltyFrequencyPenalty")) {
+                repeat_penalty_frequency_penalty = options.Get("repeatPenaltyFrequencyPenalty").As<Napi::Number>().FloatValue();
+            }
+
+            if (options.Has("grammarEvaluationState")) {
+                grammar_evaluation_state =
+                    Napi::ObjectWrap<AddonGrammarEvaluationState>::Unwrap(options.Get("grammarEvaluationState").As<Napi::Object>());
+                grammar_evaluation_state->Ref();
+                use_grammar = true;
+            }
+        }
+    }
+    ~AddonContextSampleTokenWorker() {
+        ctx->Unref();
+
+        if (use_grammar) {
+            grammar_evaluation_state->Unref();
+            use_grammar = false;
+        }
+    }
+    using Napi::AsyncWorker::Queue;
+    using Napi::Promise::Deferred::Promise;
+
+    protected:
+    void Execute() {
+        llama_token new_token_id = 0;
+
+        // Select the best prediction.
+        auto logits = llama_get_logits_ith(ctx->ctx, batchLogitIndex);
+        auto n_vocab = llama_n_vocab(ctx->model->model);
+
+        std::vector<llama_token_data> candidates;
+        candidates.reserve(n_vocab);
+
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            candidates.emplace_back(llama_token_data { token_id, logits[token_id], 0.0f });
+        }
+
+        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+        auto eos_token = llama_token_eos(ctx->model->model);
+
+        if (use_repeat_penalty && !repeat_penalty_tokens.empty()) {
+            llama_sample_repetition_penalties(
+                ctx->ctx,
+                &candidates_p,
+                repeat_penalty_tokens.data(),
+                repeat_penalty_tokens.size(),
+                repeat_penalty,
+                repeat_penalty_frequency_penalty,
+                repeat_penalty_presence_penalty
+            );
+        }
+
+        if (use_grammar && (grammar_evaluation_state)->grammar != nullptr) {
+            llama_sample_grammar(ctx->ctx, &candidates_p, (grammar_evaluation_state)->grammar);
+        }
+
+        if (temperature <= 0) {
+            new_token_id = llama_sample_token_greedy(ctx->ctx, &candidates_p);
+        } else {
+            const int32_t resolved_top_k =
+                top_k <= 0 ? llama_n_vocab(ctx->model->model) : std::min(top_k, llama_n_vocab(ctx->model->model));
+            const int32_t n_probs = 0; // Number of probabilities to keep - 0 = disabled
+            const float tfs_z = 1.00f; // Tail free sampling - 1.0 = disabled
+            const float typical_p = 1.00f; // Typical probability - 1.0 = disabled
+            const float resolved_top_p = top_p; // Top p sampling - 1.0 = disabled
+
+            // Temperature sampling
+            size_t min_keep = std::max(1, n_probs);
+            llama_sample_top_k(ctx->ctx, &candidates_p, resolved_top_k, min_keep);
+            llama_sample_tail_free(ctx->ctx, &candidates_p, tfs_z, min_keep);
+            llama_sample_typical(ctx->ctx, &candidates_p, typical_p, min_keep);
+            llama_sample_top_p(ctx->ctx, &candidates_p, resolved_top_p, min_keep);
+            llama_sample_temp(ctx->ctx, &candidates_p, temperature);
+            new_token_id = llama_sample_token(ctx->ctx, &candidates_p);
+        }
+
+        if (new_token_id != eos_token && use_grammar && (grammar_evaluation_state)->grammar != nullptr) {
+            llama_grammar_accept_token(ctx->ctx, (grammar_evaluation_state)->grammar, new_token_id);
+        }
+
+        result = new_token_id;
+    }
+    void OnOK() {
+        Napi::Env env = Napi::AsyncWorker::Env();
+        Napi::Number resultValue = Napi::Number::New(env, static_cast<uint32_t>(result));
+        Napi::Promise::Deferred::Resolve(resultValue);
+    }
+    void OnError(const Napi::Error& err) {
+        Napi::Promise::Deferred::Reject(err.Value());
+    }
+};
+
+Napi::Value AddonContext::SampleToken(const Napi::CallbackInfo& info) {
+    AddonContextSampleTokenWorker* worker = new AddonContextSampleTokenWorker(info, this);
+    worker->Queue();
+    return worker->Promise();
+}
+
+Napi::Value systemInfo(const Napi::CallbackInfo& info) {
+    return Napi::String::From(info.Env(), llama_print_system_info());
+}
 
 Napi::Object registerCallback(Napi::Env env, Napi::Object exports) {
-    llama_backend_init(false);
-    exports.DefineProperties({
-        Napi::PropertyDescriptor::Function("systemInfo", systemInfo),
-    });
-    LLAMAModel::init(exports);
-    LLAMAGrammar::init(exports);
-    LLAMAGrammarEvaluationState::init(exports);
-    LLAMAContext::init(exports);
-    return exports;
+    llama_backend_init(false);
+    exports.DefineProperties({
+        Napi::PropertyDescriptor::Function("systemInfo", systemInfo),
+    });
+    AddonModel::init(exports);
+    AddonGrammar::init(exports);
+    AddonGrammarEvaluationState::init(exports);
+    AddonContext::init(exports);
+    return exports;
 }
 
 NODE_API_MODULE(NODE_GYP_MODULE_NAME, registerCallback)
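Taken together, the addon changes above replace the 2.x eval() loop (one implicit batch per call, with the position tracked in n_cur) with an explicit, caller-owned batch lifecycle: initBatch(n) allocates a llama_batch, addToBatch(sequenceId, firstTokenContextIndex, tokens, generateLogitAtTheEnd) queues tokens for one KV-cache sequence and returns the logit index to sample from, decodeBatch() runs llama_decode on an async worker, and sampleToken(logitIndex, options) applies repetition penalties, optional grammar masking, and top-k / tail-free / typical / top-p / temperature sampling. Below is a minimal sketch of driving these bindings directly from JavaScript, with method names taken from the diff; how the prebuilt .node binary is loaded here is an assumption (the package resolves the matching llamaBins binary itself).

// Sketch only: generates a single token with the new batch-based bindings.
import {createRequire} from "module";

const require = createRequire(import.meta.url);
// Assumed load path; node-llama-cpp picks the matching llamaBins/<platform> binary.
const addon = require("./llamaBins/linux-x64/llama-addon.node");

const model = new addon.AddonModel("model.gguf", {gpuLayers: 0});
const context = new addon.AddonContext(model, {contextSize: 4096, batchSize: 512});

const tokens = model.tokenize("The quick brown fox", false); // Uint32Array, no special tokens
context.initBatch(tokens.length);
// Sequence 0, starting at context position 0, requesting a logit for the last token.
const logitIndex = context.addToBatch(0, 0, tokens, true);
await context.decodeBatch(); // llama_decode on a worker thread
const nextToken = await context.sampleToken(logitIndex, {temperature: 0.8, topK: 40, topP: 0.95});
console.log(model.detokenize(new Uint32Array([nextToken])));

context.disposeSequence(0); // frees the KV cells used by sequence 0
context.dispose();
model.dispose();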