inference-server 1.0.0-beta.19

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (227) hide show
  1. package/README.md +216 -0
  2. package/dist/api/openai/enums.d.ts +4 -0
  3. package/dist/api/openai/enums.js +17 -0
  4. package/dist/api/openai/enums.js.map +1 -0
  5. package/dist/api/openai/handlers/chat.d.ts +3 -0
  6. package/dist/api/openai/handlers/chat.js +358 -0
  7. package/dist/api/openai/handlers/chat.js.map +1 -0
  8. package/dist/api/openai/handlers/completions.d.ts +3 -0
  9. package/dist/api/openai/handlers/completions.js +169 -0
  10. package/dist/api/openai/handlers/completions.js.map +1 -0
  11. package/dist/api/openai/handlers/embeddings.d.ts +3 -0
  12. package/dist/api/openai/handlers/embeddings.js +74 -0
  13. package/dist/api/openai/handlers/embeddings.js.map +1 -0
  14. package/dist/api/openai/handlers/images.d.ts +0 -0
  15. package/dist/api/openai/handlers/images.js +4 -0
  16. package/dist/api/openai/handlers/images.js.map +1 -0
  17. package/dist/api/openai/handlers/models.d.ts +3 -0
  18. package/dist/api/openai/handlers/models.js +23 -0
  19. package/dist/api/openai/handlers/models.js.map +1 -0
  20. package/dist/api/openai/handlers/transcription.d.ts +0 -0
  21. package/dist/api/openai/handlers/transcription.js +4 -0
  22. package/dist/api/openai/handlers/transcription.js.map +1 -0
  23. package/dist/api/openai/index.d.ts +7 -0
  24. package/dist/api/openai/index.js +14 -0
  25. package/dist/api/openai/index.js.map +1 -0
  26. package/dist/api/parseJSONRequestBody.d.ts +2 -0
  27. package/dist/api/parseJSONRequestBody.js +24 -0
  28. package/dist/api/parseJSONRequestBody.js.map +1 -0
  29. package/dist/api/v1/index.d.ts +2 -0
  30. package/dist/api/v1/index.js +29 -0
  31. package/dist/api/v1/index.js.map +1 -0
  32. package/dist/cli.d.ts +1 -0
  33. package/dist/cli.js +10 -0
  34. package/dist/cli.js.map +1 -0
  35. package/dist/engines/gpt4all/engine.d.ts +34 -0
  36. package/dist/engines/gpt4all/engine.js +357 -0
  37. package/dist/engines/gpt4all/engine.js.map +1 -0
  38. package/dist/engines/gpt4all/util.d.ts +3 -0
  39. package/dist/engines/gpt4all/util.js +29 -0
  40. package/dist/engines/gpt4all/util.js.map +1 -0
  41. package/dist/engines/index.d.ts +19 -0
  42. package/dist/engines/index.js +21 -0
  43. package/dist/engines/index.js.map +1 -0
  44. package/dist/engines/node-llama-cpp/engine.d.ts +49 -0
  45. package/dist/engines/node-llama-cpp/engine.js +666 -0
  46. package/dist/engines/node-llama-cpp/engine.js.map +1 -0
  47. package/dist/engines/node-llama-cpp/types.d.ts +13 -0
  48. package/dist/engines/node-llama-cpp/types.js +2 -0
  49. package/dist/engines/node-llama-cpp/types.js.map +1 -0
  50. package/dist/engines/node-llama-cpp/util.d.ts +15 -0
  51. package/dist/engines/node-llama-cpp/util.js +84 -0
  52. package/dist/engines/node-llama-cpp/util.js.map +1 -0
  53. package/dist/engines/node-llama-cpp/validateModelFile.d.ts +8 -0
  54. package/dist/engines/node-llama-cpp/validateModelFile.js +36 -0
  55. package/dist/engines/node-llama-cpp/validateModelFile.js.map +1 -0
  56. package/dist/engines/stable-diffusion-cpp/engine.d.ts +90 -0
  57. package/dist/engines/stable-diffusion-cpp/engine.js +294 -0
  58. package/dist/engines/stable-diffusion-cpp/engine.js.map +1 -0
  59. package/dist/engines/stable-diffusion-cpp/types.d.ts +3 -0
  60. package/dist/engines/stable-diffusion-cpp/types.js +2 -0
  61. package/dist/engines/stable-diffusion-cpp/types.js.map +1 -0
  62. package/dist/engines/stable-diffusion-cpp/util.d.ts +4 -0
  63. package/dist/engines/stable-diffusion-cpp/util.js +55 -0
  64. package/dist/engines/stable-diffusion-cpp/util.js.map +1 -0
  65. package/dist/engines/stable-diffusion-cpp/validateModelFiles.d.ts +19 -0
  66. package/dist/engines/stable-diffusion-cpp/validateModelFiles.js +91 -0
  67. package/dist/engines/stable-diffusion-cpp/validateModelFiles.js.map +1 -0
  68. package/dist/engines/transformers-js/engine.d.ts +37 -0
  69. package/dist/engines/transformers-js/engine.js +538 -0
  70. package/dist/engines/transformers-js/engine.js.map +1 -0
  71. package/dist/engines/transformers-js/types.d.ts +7 -0
  72. package/dist/engines/transformers-js/types.js +2 -0
  73. package/dist/engines/transformers-js/types.js.map +1 -0
  74. package/dist/engines/transformers-js/util.d.ts +7 -0
  75. package/dist/engines/transformers-js/util.js +36 -0
  76. package/dist/engines/transformers-js/util.js.map +1 -0
  77. package/dist/engines/transformers-js/validateModelFiles.d.ts +17 -0
  78. package/dist/engines/transformers-js/validateModelFiles.js +133 -0
  79. package/dist/engines/transformers-js/validateModelFiles.js.map +1 -0
  80. package/dist/experiments/ChatWithVision.d.ts +11 -0
  81. package/dist/experiments/ChatWithVision.js +91 -0
  82. package/dist/experiments/ChatWithVision.js.map +1 -0
  83. package/dist/experiments/StableDiffPromptGenerator.d.ts +0 -0
  84. package/dist/experiments/StableDiffPromptGenerator.js +4 -0
  85. package/dist/experiments/StableDiffPromptGenerator.js.map +1 -0
  86. package/dist/experiments/VoiceFunctionCall.d.ts +18 -0
  87. package/dist/experiments/VoiceFunctionCall.js +51 -0
  88. package/dist/experiments/VoiceFunctionCall.js.map +1 -0
  89. package/dist/http.d.ts +19 -0
  90. package/dist/http.js +54 -0
  91. package/dist/http.js.map +1 -0
  92. package/dist/index.d.ts +7 -0
  93. package/dist/index.js +8 -0
  94. package/dist/index.js.map +1 -0
  95. package/dist/instance.d.ts +88 -0
  96. package/dist/instance.js +594 -0
  97. package/dist/instance.js.map +1 -0
  98. package/dist/lib/acquireFileLock.d.ts +7 -0
  99. package/dist/lib/acquireFileLock.js +38 -0
  100. package/dist/lib/acquireFileLock.js.map +1 -0
  101. package/dist/lib/calculateContextIdentity.d.ts +7 -0
  102. package/dist/lib/calculateContextIdentity.js +39 -0
  103. package/dist/lib/calculateContextIdentity.js.map +1 -0
  104. package/dist/lib/calculateFileChecksum.d.ts +1 -0
  105. package/dist/lib/calculateFileChecksum.js +16 -0
  106. package/dist/lib/calculateFileChecksum.js.map +1 -0
  107. package/dist/lib/copyDirectory.d.ts +6 -0
  108. package/dist/lib/copyDirectory.js +27 -0
  109. package/dist/lib/copyDirectory.js.map +1 -0
  110. package/dist/lib/decodeAudio.d.ts +1 -0
  111. package/dist/lib/decodeAudio.js +26 -0
  112. package/dist/lib/decodeAudio.js.map +1 -0
  113. package/dist/lib/downloadModelFile.d.ts +10 -0
  114. package/dist/lib/downloadModelFile.js +58 -0
  115. package/dist/lib/downloadModelFile.js.map +1 -0
  116. package/dist/lib/flattenMessageTextContent.d.ts +2 -0
  117. package/dist/lib/flattenMessageTextContent.js +11 -0
  118. package/dist/lib/flattenMessageTextContent.js.map +1 -0
  119. package/dist/lib/getCacheDirPath.d.ts +12 -0
  120. package/dist/lib/getCacheDirPath.js +31 -0
  121. package/dist/lib/getCacheDirPath.js.map +1 -0
  122. package/dist/lib/loadImage.d.ts +12 -0
  123. package/dist/lib/loadImage.js +30 -0
  124. package/dist/lib/loadImage.js.map +1 -0
  125. package/dist/lib/logger.d.ts +12 -0
  126. package/dist/lib/logger.js +98 -0
  127. package/dist/lib/logger.js.map +1 -0
  128. package/dist/lib/math.d.ts +7 -0
  129. package/dist/lib/math.js +30 -0
  130. package/dist/lib/math.js.map +1 -0
  131. package/dist/lib/resolveModelFileLocation.d.ts +15 -0
  132. package/dist/lib/resolveModelFileLocation.js +41 -0
  133. package/dist/lib/resolveModelFileLocation.js.map +1 -0
  134. package/dist/lib/util.d.ts +7 -0
  135. package/dist/lib/util.js +61 -0
  136. package/dist/lib/util.js.map +1 -0
  137. package/dist/lib/validateModelFile.d.ts +9 -0
  138. package/dist/lib/validateModelFile.js +62 -0
  139. package/dist/lib/validateModelFile.js.map +1 -0
  140. package/dist/lib/validateModelOptions.d.ts +3 -0
  141. package/dist/lib/validateModelOptions.js +23 -0
  142. package/dist/lib/validateModelOptions.js.map +1 -0
  143. package/dist/pool.d.ts +61 -0
  144. package/dist/pool.js +512 -0
  145. package/dist/pool.js.map +1 -0
  146. package/dist/server.d.ts +59 -0
  147. package/dist/server.js +221 -0
  148. package/dist/server.js.map +1 -0
  149. package/dist/standalone.d.ts +1 -0
  150. package/dist/standalone.js +306 -0
  151. package/dist/standalone.js.map +1 -0
  152. package/dist/store.d.ts +60 -0
  153. package/dist/store.js +203 -0
  154. package/dist/store.js.map +1 -0
  155. package/dist/types/completions.d.ts +57 -0
  156. package/dist/types/completions.js +2 -0
  157. package/dist/types/completions.js.map +1 -0
  158. package/dist/types/index.d.ts +326 -0
  159. package/dist/types/index.js +2 -0
  160. package/dist/types/index.js.map +1 -0
  161. package/docs/engines.md +28 -0
  162. package/docs/gpu.md +72 -0
  163. package/docs/http-api.md +147 -0
  164. package/examples/all-options.js +108 -0
  165. package/examples/chat-cli.js +56 -0
  166. package/examples/chat-server.js +65 -0
  167. package/examples/concurrency.js +70 -0
  168. package/examples/express.js +70 -0
  169. package/examples/pool.js +91 -0
  170. package/package.json +113 -0
  171. package/src/api/openai/enums.ts +20 -0
  172. package/src/api/openai/handlers/chat.ts +408 -0
  173. package/src/api/openai/handlers/completions.ts +196 -0
  174. package/src/api/openai/handlers/embeddings.ts +92 -0
  175. package/src/api/openai/handlers/images.ts +3 -0
  176. package/src/api/openai/handlers/models.ts +33 -0
  177. package/src/api/openai/handlers/transcription.ts +2 -0
  178. package/src/api/openai/index.ts +16 -0
  179. package/src/api/parseJSONRequestBody.ts +26 -0
  180. package/src/api/v1/DRAFT.md +16 -0
  181. package/src/api/v1/index.ts +37 -0
  182. package/src/cli.ts +9 -0
  183. package/src/engines/gpt4all/engine.ts +441 -0
  184. package/src/engines/gpt4all/util.ts +31 -0
  185. package/src/engines/index.ts +28 -0
  186. package/src/engines/node-llama-cpp/engine.ts +811 -0
  187. package/src/engines/node-llama-cpp/types.ts +17 -0
  188. package/src/engines/node-llama-cpp/util.ts +126 -0
  189. package/src/engines/node-llama-cpp/validateModelFile.ts +46 -0
  190. package/src/engines/stable-diffusion-cpp/engine.ts +369 -0
  191. package/src/engines/stable-diffusion-cpp/types.ts +54 -0
  192. package/src/engines/stable-diffusion-cpp/util.ts +58 -0
  193. package/src/engines/stable-diffusion-cpp/validateModelFiles.ts +119 -0
  194. package/src/engines/transformers-js/engine.ts +659 -0
  195. package/src/engines/transformers-js/types.ts +25 -0
  196. package/src/engines/transformers-js/util.ts +40 -0
  197. package/src/engines/transformers-js/validateModelFiles.ts +168 -0
  198. package/src/experiments/ChatWithVision.ts +103 -0
  199. package/src/experiments/StableDiffPromptGenerator.ts +2 -0
  200. package/src/experiments/VoiceFunctionCall.ts +71 -0
  201. package/src/http.ts +72 -0
  202. package/src/index.ts +7 -0
  203. package/src/instance.ts +723 -0
  204. package/src/lib/acquireFileLock.ts +38 -0
  205. package/src/lib/calculateContextIdentity.ts +53 -0
  206. package/src/lib/calculateFileChecksum.ts +18 -0
  207. package/src/lib/copyDirectory.ts +29 -0
  208. package/src/lib/decodeAudio.ts +39 -0
  209. package/src/lib/downloadModelFile.ts +70 -0
  210. package/src/lib/flattenMessageTextContent.ts +19 -0
  211. package/src/lib/getCacheDirPath.ts +34 -0
  212. package/src/lib/loadImage.ts +46 -0
  213. package/src/lib/logger.ts +112 -0
  214. package/src/lib/math.ts +31 -0
  215. package/src/lib/resolveModelFileLocation.ts +49 -0
  216. package/src/lib/util.ts +75 -0
  217. package/src/lib/validateModelFile.ts +71 -0
  218. package/src/lib/validateModelOptions.ts +31 -0
  219. package/src/pool.ts +651 -0
  220. package/src/server.ts +270 -0
  221. package/src/standalone.ts +320 -0
  222. package/src/store.ts +278 -0
  223. package/src/types/completions.ts +86 -0
  224. package/src/types/index.ts +488 -0
  225. package/tsconfig.json +29 -0
  226. package/tsconfig.release.json +11 -0
  227. package/vitest.config.ts +18 -0
@@ -0,0 +1,666 @@
1
+ import path from 'node:path';
2
+ import fs from 'node:fs';
3
+ import { nanoid } from 'nanoid';
4
+ import { getLlama, LlamaChat, LlamaCompletion, LlamaLogLevel, TokenBias, LlamaGrammar, defineChatSessionFunction, createModelDownloader, readGgufFileInfo, LlamaJsonSchemaGrammar, } from 'node-llama-cpp';
5
+ import { LogLevels } from '../../lib/logger.js';
6
+ import { flattenMessageTextContent } from '../../lib/flattenMessageTextContent.js';
7
+ import { acquireFileLock } from '../../lib/acquireFileLock.js';
8
+ import { getRandomNumber } from '../../lib/util.js';
9
+ import { validateModelFile } from '../../lib/validateModelFile.js';
10
+ import { createChatMessageArray, addFunctionCallToChatHistory, mapFinishReason } from './util.js';
11
// Flag read by the engine loader: this engine supports automatic GPU backend selection.
export const autoGpu = true;
/**
 * Ensures the model file for a node-llama-cpp instance exists and is valid,
 * downloading it if necessary, and reads its GGUF metadata.
 *
 * A file lock is held on the model location for the whole operation so
 * concurrent prepares of the same file do not race; it is always released
 * via `finally`.
 *
 * @param {object} ctx - engine context
 * @param {object} ctx.config - model config (`id`, `location`, optional `url`)
 * @param {Function} ctx.log - logger callback `(level, message, meta?)`
 * @param {Function} [onProgress] - called with `{ file, loadedBytes, totalBytes }` while downloading
 * @param {AbortSignal} [signal] - aborts preparation; the function returns `undefined` when aborted
 * @returns {Promise<{ gguf: object } | undefined>} GGUF file info, or `undefined` if aborted
 * @throws {Error} when validation fails and no download URL is configured,
 *                 or when the file is still invalid after downloading
 */
export async function prepareModel({ config, log }, onProgress, signal) {
    fs.mkdirSync(path.dirname(config.location), { recursive: true });
    const releaseFileLock = await acquireFileLock(config.location, signal);
    if (signal?.aborted) {
        releaseFileLock();
        return;
    }
    log(LogLevels.info, `Preparing node-llama-cpp model at ${config.location}`, {
        model: config.id,
    });
    // Downloads the model file into config.location, forwarding progress updates.
    const downloadModel = async (url, validationResult) => {
        log(LogLevels.info, `Downloading model files`, {
            model: config.id,
            url: url,
            location: config.location,
            error: validationResult,
        });
        const downloader = await createModelDownloader({
            modelUrl: url,
            dirPath: path.dirname(config.location),
            fileName: path.basename(config.location),
            deleteTempFileOnCancel: false,
            onProgress: (status) => {
                if (onProgress) {
                    onProgress({
                        file: config.location,
                        loadedBytes: status.downloadedSize,
                        totalBytes: status.totalSize,
                    });
                }
            },
        });
        await downloader.download();
    };
    try {
        if (signal?.aborted) {
            return;
        }
        const validationError = await validateModelFile(config);
        if (signal?.aborted) {
            return;
        }
        if (validationError) {
            if (config.url) {
                await downloadModel(config.url, validationError);
            } else {
                throw new Error(`${validationError} - No URL provided`);
            }
        }
        // Re-validate after a (possible) download to catch corrupt transfers.
        const finalValidationError = await validateModelFile(config);
        if (finalValidationError) {
            throw new Error(`Downloaded files are invalid: ${finalValidationError}`);
        }
        // Read GGUF metadata, skipping the large tokenizer arrays we don't need.
        const gguf = await readGgufFileInfo(config.location, {
            signal,
            ignoreKeys: [
                'gguf.tokenizer.ggml.merges',
                'gguf.tokenizer.ggml.tokens',
                'gguf.tokenizer.ggml.scores',
                'gguf.tokenizer.ggml.token_type',
            ],
        });
        return {
            gguf,
        };
    } finally {
        // Note: a redundant `catch (err) { throw err; }` was removed — errors
        // propagate unchanged, and the lock is always released here.
        releaseFileLock();
    }
}
/**
 * Creates a ready-to-use node-llama-cpp instance: loads the llama binding for
 * the configured GPU backend, compiles configured grammars, loads the model
 * and creates a single-sequence context. Optionally pre-ingests initial chat
 * messages and/or a completion prefix so later requests start warm.
 *
 * @param {object} ctx - engine context
 * @param {object} ctx.config - model/device/grammar/tools configuration
 * @param {Function} ctx.log - logger callback `(level, message, meta?)`
 * @param {AbortSignal} [signal] - forwarded to model load / context creation
 * @returns {Promise<object>} mutable instance state used by the task processors
 */
export async function createInstance({ config, log }, signal) {
    log(LogLevels.debug, 'Load Llama model', config.device);
    // takes "auto" | "metal" | "cuda" | "vulkan"
    const gpuSetting = (config.device?.gpu ?? 'auto');
    const llama = await getLlama({
        gpu: gpuSetting,
        // forwarding llama logger: map llama.cpp log levels onto our logger
        logLevel: LlamaLogLevel.debug,
        logger: (level, message) => {
            if (level === LlamaLogLevel.warn) {
                log(LogLevels.warn, message);
            }
            else if (level === LlamaLogLevel.error || level === LlamaLogLevel.fatal) {
                log(LogLevels.error, message);
            }
            else if (level === LlamaLogLevel.info || level === LlamaLogLevel.debug) {
                log(LogLevels.verbose, message);
            }
        },
    });
    // Built-in JSON grammar is always available; user grammars are added below.
    const llamaGrammars = {
        json: await LlamaGrammar.getFor(llama, 'json'),
    };
    if (config.grammars) {
        for (const key in config.grammars) {
            const input = config.grammars[key];
            if (typeof input === 'string') {
                // raw GBNF grammar text
                llamaGrammars[key] = new LlamaGrammar(llama, {
                    grammar: input,
                });
            }
            else {
                // assume input is a JSON schema object
                llamaGrammars[key] = new LlamaJsonSchemaGrammar(llama, input);
            }
        }
    }
    const llamaModel = await llama.loadModel({
        modelPath: config.location, // full model absolute path
        loadSignal: signal,
        useMlock: config.device?.memLock ?? false,
        gpuLayers: config.device?.gpuLayers,
        // onLoadProgress: (percent) => {}
    });
    const context = await llamaModel.createContext({
        sequences: 1,
        lora: config.lora,
        threads: config.device?.cpuThreads,
        batchSize: config.batchSize,
        contextSize: config.contextSize,
        flashAttention: true,
        createSignal: signal,
    });
    // Mutable per-instance state shared by chat/text completion processors.
    const instance = {
        model: llamaModel,
        context,
        grammars: llamaGrammars,
        chat: undefined,
        chatHistory: [],
        pendingFunctionCalls: {},
        lastEvaluation: undefined,
        completion: undefined,
        contextSequence: context.getSequence(),
    };
    if (config.initialMessages) {
        // Pre-ingest the configured conversation so the first request reuses it.
        const initialChatHistory = createChatMessageArray(config.initialMessages);
        const chat = new LlamaChat({
            contextSequence: instance.contextSequence,
            // autoDisposeSequence: true,
        });
        let inputFunctions;
        if (config.tools?.definitions && Object.keys(config.tools.definitions).length > 0) {
            const functionDefs = config.tools.definitions;
            inputFunctions = {};
            for (const functionName in functionDefs) {
                const functionDef = functionDefs[functionName];
                inputFunctions[functionName] = defineChatSessionFunction({
                    description: functionDef.description,
                    params: functionDef.parameters,
                    // handler may be absent for deferred (caller-resolved) tools
                    handler: functionDef.handler || (() => { }),
                });
            }
        }
        // Evaluate the initial history once; an empty user prompt just primes the context.
        const loadMessagesRes = await chat.loadChatAndCompleteUserMessage(initialChatHistory, {
            initialUserPrompt: '',
            functions: inputFunctions,
            documentFunctionParams: config.tools?.includeToolDocumentation,
        });
        instance.chat = chat;
        instance.chatHistory = initialChatHistory;
        instance.lastEvaluation = {
            cleanHistory: initialChatHistory,
            contextWindow: loadMessagesRes.lastEvaluation.contextWindow,
            contextShiftMetadata: loadMessagesRes.lastEvaluation.contextShiftMetadata,
        };
    }
    if (config.prefix) {
        // Pre-ingest a text-completion prefix (maxTokens: 0 evaluates without generating).
        const contextSequence = instance.contextSequence;
        const completion = new LlamaCompletion({
            contextSequence: contextSequence,
        });
        await completion.generateCompletion(config.prefix, {
            maxTokens: 0,
        });
        instance.completion = completion;
        instance.contextSequence = contextSequence;
    }
    return instance;
}
/**
 * Releases the resources held by an engine instance by disposing its
 * underlying Llama model.
 * @param {object} instance - instance created by `createInstance`
 * @returns {Promise<void>} resolves once the model is disposed
 */
export async function disposeInstance(instance) {
    const { model } = instance;
    await model.dispose();
}
/**
 * Runs one chat-completion turn against a node-llama-cpp instance.
 *
 * Overall flow:
 *  1. (Re)create the LlamaChat context if missing or a reset was requested,
 *     re-ingesting the request's message history.
 *  2. Build stop triggers, token bias, and the merged function definitions
 *     (config tools + request tools).
 *  3. Fold any user-supplied tool results for pending function calls into
 *     the chat history, then append the new user message (or use a trailing
 *     assistant message as a response prefill).
 *  4. Generate; immediately invoke tool handlers where defined, looping until
 *     the model produces text or only handler-less (deferred) calls remain.
 *  5. Return the assistant message plus finish reason and token usage.
 *
 * @param {object} task - task context
 * @param {object} task.request - completion request (messages, sampling params, tools, grammar)
 * @param {object} task.config - engine config (completionDefaults, tools, contextShiftStrategy)
 * @param {boolean} task.resetContext - force re-ingestion of the conversation
 * @param {Function} task.log - logger callback
 * @param {Function} [task.onChunk] - streaming callback `({ tokens, text })`
 * @param {object} instance - mutable instance state from `createInstance`
 * @param {AbortSignal} [signal] - aborts generation (resolves with a partial response)
 * @returns {Promise<object>} `{ finishReason, message, promptTokens, completionTokens, contextTokens }`
 * @throws {Error} on missing context sequence, unknown grammar, invalid last
 *                 message, or a call to an undefined function
 */
export async function processChatCompletionTask({ request, config, resetContext, log, onChunk }, instance, signal) {
    if (!instance.chat || resetContext) {
        log(LogLevels.debug, 'Recreating chat context', {
            resetContext,
            willDisposeChat: !!instance.chat,
        });
        // if context reset is requested, dispose the chat instance
        if (instance.chat) {
            await instance.chat.dispose();
        }
        // Reuse the existing sequence when possible; otherwise claim a new one.
        let contextSequence = instance.contextSequence;
        if (!contextSequence || contextSequence.disposed) {
            if (instance.context.sequencesLeft) {
                contextSequence = instance.context.getSequence();
                instance.contextSequence = contextSequence;
            }
            else {
                throw new Error('No context sequence available');
            }
        }
        else {
            contextSequence.clearHistory();
        }
        instance.chat = new LlamaChat({
            contextSequence: contextSequence,
            // autoDisposeSequence: true,
        });
        // reset state and reingest the conversation history
        instance.lastEvaluation = undefined;
        instance.pendingFunctionCalls = {};
        instance.chatHistory = createChatMessageArray(request.messages);
        // drop last user message. its gonna be added later, after resolved function calls
        if (instance.chatHistory[instance.chatHistory.length - 1].type === 'user') {
            instance.chatHistory.pop();
        }
    }
    // set additional stop generation triggers for this completion
    const customStopTriggers = [];
    const stopTrigger = request.stop ?? config.completionDefaults?.stop;
    if (stopTrigger) {
        // each trigger becomes a single-element trigger sequence
        customStopTriggers.push(...stopTrigger.map((t) => [t]));
    }
    // setting up logit/token bias dictionary
    let tokenBias;
    const completionTokenBias = request.tokenBias ?? config.completionDefaults?.tokenBias;
    if (completionTokenBias) {
        tokenBias = new TokenBias(instance.model.tokenizer);
        for (const key in completionTokenBias) {
            // bias values are scaled down by 10 before being applied
            const bias = completionTokenBias[key] / 10;
            // NOTE(review): parseInt without an explicit radix — assumes decimal token ids
            const tokenId = parseInt(key);
            if (!isNaN(tokenId)) {
                // numeric key: bias a specific token id
                tokenBias.set(tokenId, bias);
            }
            else {
                // non-numeric key: bias by token text
                tokenBias.set(key, bias);
            }
        }
    }
    // setting up available function definitions (request tools override config tools)
    const functionDefinitions = {
        ...config.tools?.definitions,
        ...request.tools,
    };
    // see if the user submitted any function call results
    const supportsParallelFunctionCalling = instance.chat.chatWrapper.settings.functions.parallelism != null && !!config.tools?.parallelism;
    const resolvedFunctionCalls = [];
    const functionCallResultMessages = request.messages.filter((m) => m.role === 'tool');
    for (const message of functionCallResultMessages) {
        if (!instance.pendingFunctionCalls[message.callId]) {
            // NOTE(review): the message below appears to be missing a closing quote after the call id
            log(LogLevels.warn, `Received function result for non-existing call id "${message.callId}`);
            continue;
        }
        log(LogLevels.debug, 'Resolving pending function call', {
            id: message.callId,
            result: message.content,
        });
        const functionCall = instance.pendingFunctionCalls[message.callId];
        const functionDef = functionDefinitions[functionCall.functionName];
        resolvedFunctionCalls.push({
            name: functionCall.functionName,
            description: functionDef?.description,
            params: functionCall.params,
            result: message.content,
            rawCall: functionCall.raw,
            startsNewChunk: supportsParallelFunctionCalling,
        });
        // the call is resolved; forget it
        delete instance.pendingFunctionCalls[message.callId];
    }
    // if we resolved any results, add them to history
    if (resolvedFunctionCalls.length) {
        instance.chatHistory.push({
            type: 'model',
            response: resolvedFunctionCalls.map((call) => {
                return {
                    type: 'functionCall',
                    ...call,
                };
            }),
        });
    }
    // add the new user message to the chat history
    let assistantPrefill = '';
    const lastMessage = request.messages[request.messages.length - 1];
    if (lastMessage.role === 'user' && lastMessage.content) {
        const newUserText = flattenMessageTextContent(lastMessage.content);
        if (newUserText) {
            instance.chatHistory.push({
                type: 'user',
                text: newUserText,
            });
        }
    }
    else if (lastMessage.role === 'assistant') {
        // use last message as prefill for response, if its an assistant message
        assistantPrefill = flattenMessageTextContent(lastMessage.content);
    }
    else if (!resolvedFunctionCalls.length) {
        log(LogLevels.warn, 'Last message is not valid for chat completion. This is likely a mistake.', lastMessage);
        throw new Error('Invalid last chat message');
    }
    // only grammar or functions can be used, not both.
    // currently ignoring function definitions if grammar is provided
    let inputGrammar;
    let inputFunctions;
    if (request.grammar) {
        if (!instance.grammars[request.grammar]) {
            throw new Error(`Grammar "${request.grammar}" not found.`);
        }
        inputGrammar = instance.grammars[request.grammar];
    }
    else if (Object.keys(functionDefinitions).length > 0) {
        inputFunctions = {};
        for (const functionName in functionDefinitions) {
            const functionDef = functionDefinitions[functionName];
            inputFunctions[functionName] = defineChatSessionFunction({
                description: functionDef.description,
                params: functionDef.parameters,
                // handler may be absent for deferred (caller-resolved) tools
                handler: functionDef.handler || (() => { }),
            });
        }
    }
    const defaults = config.completionDefaults ?? {};
    // Working copies; committed back to the instance only after generation succeeds.
    let lastEvaluation = instance.lastEvaluation;
    let newChatHistory = instance.chatHistory.slice();
    let newContextWindowChatHistory = !lastEvaluation?.contextWindow ? undefined : instance.chatHistory.slice();
    // Ensure the history ends with a model message (optionally prefilled) to generate into.
    if (instance.chatHistory[instance.chatHistory.length - 1].type !== 'model' || assistantPrefill) {
        const newModelResponse = assistantPrefill ? [assistantPrefill] : [];
        newChatHistory.push({
            type: 'model',
            response: newModelResponse,
        });
        if (newContextWindowChatHistory) {
            newContextWindowChatHistory.push({
                type: 'model',
                response: newModelResponse,
            });
        }
    }
    // mutually exclusive: either function calling or grammar-constrained output
    const functionsOrGrammar = inputFunctions
        ? {
            functions: inputFunctions,
            documentFunctionParams: config.tools?.includeToolDocumentation ?? true,
            maxParallelFunctionCalls: config.tools?.parallelism ?? 1,
            onFunctionCall: (functionCall) => {
                // log(LogLevels.debug, 'Called function', functionCall)
            },
        }
        : {
            grammar: inputGrammar,
        };
    // snapshot token meter to report usage for this turn only
    const initialTokenMeterState = instance.chat.sequence.tokenMeter.getState();
    let completionResult;
    // Generation loop: repeats while handler-backed function calls keep resolving.
    while (true) {
        const { functionCalls, lastEvaluation: currentLastEvaluation, metadata, } = await instance.chat.generateResponse(newChatHistory, {
            signal,
            stopOnAbortSignal: true, // this will make aborted completions resolve (with a partial response)
            maxTokens: request.maxTokens ?? defaults.maxTokens,
            temperature: request.temperature ?? defaults.temperature,
            topP: request.topP ?? defaults.topP,
            topK: request.topK ?? defaults.topK,
            minP: request.minP ?? defaults.minP,
            seed: request.seed ?? config.completionDefaults?.seed ?? getRandomNumber(0, 1000000),
            tokenBias,
            customStopTriggers,
            trimWhitespaceSuffix: false,
            ...functionsOrGrammar,
            repeatPenalty: {
                lastTokens: request.repeatPenaltyNum ?? defaults.repeatPenaltyNum,
                frequencyPenalty: request.frequencyPenalty ?? defaults.frequencyPenalty,
                presencePenalty: request.presencePenalty ?? defaults.presencePenalty,
            },
            contextShift: {
                strategy: config.contextShiftStrategy,
                lastEvaluationMetadata: lastEvaluation?.contextShiftMetadata,
            },
            lastEvaluationContextWindow: {
                history: newContextWindowChatHistory,
                minimumOverlapPercentageToPreventContextShift: 0.5,
            },
            onToken: (tokens) => {
                // stream detokenized text to the caller as it is generated
                const text = instance.model.detokenize(tokens);
                if (onChunk) {
                    onChunk({
                        tokens,
                        text,
                    });
                }
            },
        });
        lastEvaluation = currentLastEvaluation;
        newChatHistory = lastEvaluation.cleanHistory;
        if (functionCalls) {
            // find leading immediately evokable function calls (=have a handler function)
            const evokableFunctionCalls = [];
            for (const functionCall of functionCalls) {
                const functionDef = functionDefinitions[functionCall.functionName];
                if (functionDef.handler) {
                    evokableFunctionCalls.push(functionCall);
                }
                else {
                    // stop at the first deferred call to preserve call order
                    break;
                }
            }
            // resolve their results.
            const results = await Promise.all(evokableFunctionCalls.map(async (functionCall) => {
                const functionDef = functionDefinitions[functionCall.functionName];
                if (!functionDef) {
                    throw new Error(`The model tried to call undefined function "${functionCall.functionName}"`);
                }
                const functionCallResult = await functionDef.handler(functionCall.params);
                log(LogLevels.debug, 'Function handler resolved', {
                    function: functionCall.functionName,
                    args: functionCall.params,
                    result: functionCallResult,
                });
                return {
                    functionDef,
                    functionCall,
                    functionCallResult,
                };
            }));
            newContextWindowChatHistory = lastEvaluation.contextWindow;
            let startNewChunk = true;
            // add results to chat history in the order they were called
            for (const callResult of results) {
                newChatHistory = addFunctionCallToChatHistory({
                    chatHistory: newChatHistory,
                    functionName: callResult.functionCall.functionName,
                    functionDescription: callResult.functionDef.description,
                    callParams: callResult.functionCall.params,
                    callResult: callResult.functionCallResult,
                    rawCall: callResult.functionCall.raw,
                    startsNewChunk: startNewChunk,
                });
                // keep the context-window view of the history in sync
                newContextWindowChatHistory = addFunctionCallToChatHistory({
                    chatHistory: newContextWindowChatHistory,
                    functionName: callResult.functionCall.functionName,
                    functionDescription: callResult.functionDef.description,
                    callParams: callResult.functionCall.params,
                    callResult: callResult.functionCallResult,
                    rawCall: callResult.functionCall.raw,
                    startsNewChunk: startNewChunk,
                });
                startNewChunk = false;
            }
            // check if all function calls were immediately evokable
            const remainingFunctionCalls = functionCalls.slice(evokableFunctionCalls.length);
            if (remainingFunctionCalls.length === 0) {
                // if yes, continue with generation
                lastEvaluation.cleanHistory = newChatHistory;
                lastEvaluation.contextWindow = newContextWindowChatHistory;
                continue;
            }
            else {
                // if no, return the function calls and skip generation
                completionResult = {
                    responseText: null,
                    stopReason: 'functionCalls',
                    functionCalls: remainingFunctionCalls,
                };
                break;
            }
        }
        // no function calls happened, we got a model response.
        instance.lastEvaluation = lastEvaluation;
        instance.chatHistory = newChatHistory;
        const lastMessage = instance.chatHistory[instance.chatHistory.length - 1];
        // concatenate only the plain-text segments of the model response
        const responseText = lastMessage.response.filter((item) => typeof item === 'string').join('');
        completionResult = {
            responseText,
            stopReason: metadata.stopReason,
        };
        break;
    }
    const assistantMessage = {
        role: 'assistant',
        content: completionResult.responseText || '',
    };
    if (completionResult.functionCalls) {
        // TODO its possible that there are trailing immediately-evaluatable function calls.
        // function call results need to be added in the order the functions were called, so
        // we need to wait for the pending calls to complete before we can add the trailing calls.
        // as is, these may never resolve
        const pendingFunctionCalls = completionResult.functionCalls.filter((call) => {
            const functionDef = functionDefinitions[call.functionName];
            return !functionDef.handler;
        });
        // TODO write a test that triggers a parallel call to a deferred function and to an IE function
        const trailingFunctionCalls = completionResult.functionCalls.filter((call) => {
            const functionDef = functionDefinitions[call.functionName];
            return functionDef.handler;
        });
        if (trailingFunctionCalls.length) {
            console.debug(trailingFunctionCalls);
            log(LogLevels.warn, 'Trailing function calls not resolved');
        }
        // Expose deferred calls to the caller as tool calls; remember them so a
        // later request can submit their results via 'tool' messages.
        assistantMessage.toolCalls = pendingFunctionCalls.map((call) => {
            const callId = nanoid();
            instance.pendingFunctionCalls[callId] = call;
            log(LogLevels.debug, 'Saving pending tool call', {
                id: callId,
                function: call.functionName,
                args: call.params,
            });
            return {
                id: callId,
                name: call.functionName,
                parameters: call.params,
            };
        });
    }
    // token usage for this turn = meter delta since the snapshot above
    const tokenDifference = instance.chat.sequence.tokenMeter.diff(initialTokenMeterState);
    return {
        finishReason: mapFinishReason(completionResult.stopReason),
        message: assistantMessage,
        promptTokens: tokenDifference.usedInputTokens,
        completionTokens: tokenDifference.usedOutputTokens,
        contextTokens: instance.chat.sequence.contextTokens.length,
    };
}
/**
 * Generates a plain (non-chat) text completion for `request.prompt`.
 *
 * Reuses the instance's cached LlamaCompletion and context sequence when
 * possible, rebuilding them if either has been disposed. All generation
 * options fall back to `config.completionDefaults` when absent on the request.
 *
 * @param {object} task - Destructured task: { request, config, resetContext, log, onChunk }.
 * @param {object} instance - Model instance holding model/context/completion state (mutated).
 * @param {AbortSignal} signal - Aborts generation when triggered.
 * @returns {Promise<object>} { finishReason, text, promptTokens, completionTokens, contextTokens }
 * @throws {Error} If no prompt is given or no context sequence is available.
 */
export async function processTextCompletionTask({ request, config, resetContext, log, onChunk }, instance, signal) {
    if (!request.prompt) {
        throw new Error('Prompt is required for text completion.');
    }
    // Optionally wipe prior sequence state before generating.
    if (resetContext && instance.contextSequence) {
        instance.contextSequence.clearHistory();
    }
    let completion;
    let contextSequence;
    if (!instance.completion || instance.completion.disposed) {
        if (instance.contextSequence) {
            contextSequence = instance.contextSequence;
        }
        else if (instance.context.sequencesLeft) {
            contextSequence = instance.context.getSequence();
        }
        else {
            throw new Error('No context sequence available');
        }
        completion = attachCompletionToInstance(instance, contextSequence);
    }
    else {
        completion = instance.completion;
        contextSequence = instance.contextSequence;
    }
    // The cached sequence can be disposed independently of the completion;
    // in that case acquire a fresh sequence and rebuild the completion on it.
    if (!contextSequence || contextSequence.disposed) {
        contextSequence = instance.context.getSequence();
        completion = attachCompletionToInstance(instance, contextSequence);
    }
    const defaults = config.completionDefaults ?? {};
    // node-llama-cpp expects each custom stop trigger as its own array.
    const stopGenerationTriggers = [];
    const stopTrigger = request.stop ?? defaults.stop;
    if (stopTrigger) {
        stopGenerationTriggers.push(...stopTrigger.map((t) => [t]));
    }
    // Snapshot the token meter so usage can be reported as a delta afterwards.
    const initialTokenMeterState = contextSequence.tokenMeter.getState();
    const result = await completion.generateCompletionWithMeta(request.prompt, {
        maxTokens: request.maxTokens ?? defaults.maxTokens,
        temperature: request.temperature ?? defaults.temperature,
        topP: request.topP ?? defaults.topP,
        topK: request.topK ?? defaults.topK,
        minP: request.minP ?? defaults.minP,
        repeatPenalty: {
            lastTokens: request.repeatPenaltyNum ?? defaults.repeatPenaltyNum,
            frequencyPenalty: request.frequencyPenalty ?? defaults.frequencyPenalty,
            presencePenalty: request.presencePenalty ?? defaults.presencePenalty,
        },
        signal: signal,
        customStopTriggers: stopGenerationTriggers.length ? stopGenerationTriggers : undefined,
        seed: request.seed ?? defaults.seed ?? getRandomNumber(0, 1000000),
        onToken: (tokens) => {
            // Stream detokenized text to the caller as it is produced.
            const text = instance.model.detokenize(tokens);
            if (onChunk) {
                onChunk({
                    tokens,
                    text,
                });
            }
        },
    });
    const tokenDifference = contextSequence.tokenMeter.diff(initialTokenMeterState);
    return {
        finishReason: mapFinishReason(result.metadata.stopReason),
        text: result.response,
        promptTokens: tokenDifference.usedInputTokens,
        completionTokens: tokenDifference.usedOutputTokens,
        contextTokens: contextSequence.contextTokens.length,
    };
}
/**
 * Binds a fresh LlamaCompletion to `contextSequence`, caches both on the
 * instance, and returns the new completion.
 */
function attachCompletionToInstance(instance, contextSequence) {
    instance.contextSequence = contextSequence;
    const completion = new LlamaCompletion({
        contextSequence,
    });
    instance.completion = completion;
    return completion;
}
/**
 * Computes embedding vectors for the given input text(s).
 *
 * Accepts either a single string or an array of strings / `{ type, content }`
 * items; image items are rejected. Inputs longer than the context window are
 * truncated (with a warning) before embedding.
 *
 * @param {object} task - Destructured task: { request, config, log }.
 * @param {object} instance - Model instance; `embeddingContext` is created lazily and cached on it.
 * @param {AbortSignal} [signal] - Stops processing further inputs once aborted.
 * @returns {Promise<object>} { embeddings: Float32Array[], inputTokens: number }
 * @throws {Error} If no input is given or an image input is encountered.
 */
export async function processEmbeddingTask({ request, config, log }, instance, signal) {
    if (!request.input) {
        throw new Error('Input is required for embedding.');
    }
    // Flatten the request input into a plain list of strings.
    const texts = [];
    if (typeof request.input === 'string') {
        texts.push(request.input);
    }
    else if (Array.isArray(request.input)) {
        for (const item of request.input) {
            if (typeof item === 'string') {
                texts.push(item);
            }
            else if (item.type === 'text') {
                texts.push(item.content);
            }
            else if (item.type === 'image') {
                throw new Error('Image inputs not implemented.');
            }
        }
    }
    // Create the embedding context on first use and keep it on the instance.
    if (!instance.embeddingContext) {
        instance.embeddingContext = await instance.model.createEmbeddingContext({
            batchSize: config.batchSize,
            createSignal: signal,
            threads: config.device?.cpuThreads,
            contextSize: config.contextSize,
        });
    }
    // @ts-ignore - private property
    const contextSize = instance.embeddingContext._llamaContext.contextSize;
    const embeddings = [];
    let inputTokens = 0;
    for (const text of texts) {
        let tokens = instance.model.tokenize(text);
        if (tokens.length > contextSize) {
            // Oversized inputs are cut down to the context window rather than rejected.
            log(LogLevels.warn, 'Truncated input that exceeds context size');
            tokens = tokens.slice(0, contextSize);
        }
        inputTokens += tokens.length;
        const result = await instance.embeddingContext.getEmbeddingFor(tokens);
        embeddings.push(new Float32Array(result.vector));
        // Stop early (after finishing the current item) when aborted.
        if (signal?.aborted) {
            break;
        }
    }
    return {
        embeddings,
        inputTokens,
    };
}
+ //# sourceMappingURL=engine.js.map