inference-server 1.0.0-beta.19

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (227) hide show
  1. package/README.md +216 -0
  2. package/dist/api/openai/enums.d.ts +4 -0
  3. package/dist/api/openai/enums.js +17 -0
  4. package/dist/api/openai/enums.js.map +1 -0
  5. package/dist/api/openai/handlers/chat.d.ts +3 -0
  6. package/dist/api/openai/handlers/chat.js +358 -0
  7. package/dist/api/openai/handlers/chat.js.map +1 -0
  8. package/dist/api/openai/handlers/completions.d.ts +3 -0
  9. package/dist/api/openai/handlers/completions.js +169 -0
  10. package/dist/api/openai/handlers/completions.js.map +1 -0
  11. package/dist/api/openai/handlers/embeddings.d.ts +3 -0
  12. package/dist/api/openai/handlers/embeddings.js +74 -0
  13. package/dist/api/openai/handlers/embeddings.js.map +1 -0
  14. package/dist/api/openai/handlers/images.d.ts +0 -0
  15. package/dist/api/openai/handlers/images.js +4 -0
  16. package/dist/api/openai/handlers/images.js.map +1 -0
  17. package/dist/api/openai/handlers/models.d.ts +3 -0
  18. package/dist/api/openai/handlers/models.js +23 -0
  19. package/dist/api/openai/handlers/models.js.map +1 -0
  20. package/dist/api/openai/handlers/transcription.d.ts +0 -0
  21. package/dist/api/openai/handlers/transcription.js +4 -0
  22. package/dist/api/openai/handlers/transcription.js.map +1 -0
  23. package/dist/api/openai/index.d.ts +7 -0
  24. package/dist/api/openai/index.js +14 -0
  25. package/dist/api/openai/index.js.map +1 -0
  26. package/dist/api/parseJSONRequestBody.d.ts +2 -0
  27. package/dist/api/parseJSONRequestBody.js +24 -0
  28. package/dist/api/parseJSONRequestBody.js.map +1 -0
  29. package/dist/api/v1/index.d.ts +2 -0
  30. package/dist/api/v1/index.js +29 -0
  31. package/dist/api/v1/index.js.map +1 -0
  32. package/dist/cli.d.ts +1 -0
  33. package/dist/cli.js +10 -0
  34. package/dist/cli.js.map +1 -0
  35. package/dist/engines/gpt4all/engine.d.ts +34 -0
  36. package/dist/engines/gpt4all/engine.js +357 -0
  37. package/dist/engines/gpt4all/engine.js.map +1 -0
  38. package/dist/engines/gpt4all/util.d.ts +3 -0
  39. package/dist/engines/gpt4all/util.js +29 -0
  40. package/dist/engines/gpt4all/util.js.map +1 -0
  41. package/dist/engines/index.d.ts +19 -0
  42. package/dist/engines/index.js +21 -0
  43. package/dist/engines/index.js.map +1 -0
  44. package/dist/engines/node-llama-cpp/engine.d.ts +49 -0
  45. package/dist/engines/node-llama-cpp/engine.js +666 -0
  46. package/dist/engines/node-llama-cpp/engine.js.map +1 -0
  47. package/dist/engines/node-llama-cpp/types.d.ts +13 -0
  48. package/dist/engines/node-llama-cpp/types.js +2 -0
  49. package/dist/engines/node-llama-cpp/types.js.map +1 -0
  50. package/dist/engines/node-llama-cpp/util.d.ts +15 -0
  51. package/dist/engines/node-llama-cpp/util.js +84 -0
  52. package/dist/engines/node-llama-cpp/util.js.map +1 -0
  53. package/dist/engines/node-llama-cpp/validateModelFile.d.ts +8 -0
  54. package/dist/engines/node-llama-cpp/validateModelFile.js +36 -0
  55. package/dist/engines/node-llama-cpp/validateModelFile.js.map +1 -0
  56. package/dist/engines/stable-diffusion-cpp/engine.d.ts +90 -0
  57. package/dist/engines/stable-diffusion-cpp/engine.js +294 -0
  58. package/dist/engines/stable-diffusion-cpp/engine.js.map +1 -0
  59. package/dist/engines/stable-diffusion-cpp/types.d.ts +3 -0
  60. package/dist/engines/stable-diffusion-cpp/types.js +2 -0
  61. package/dist/engines/stable-diffusion-cpp/types.js.map +1 -0
  62. package/dist/engines/stable-diffusion-cpp/util.d.ts +4 -0
  63. package/dist/engines/stable-diffusion-cpp/util.js +55 -0
  64. package/dist/engines/stable-diffusion-cpp/util.js.map +1 -0
  65. package/dist/engines/stable-diffusion-cpp/validateModelFiles.d.ts +19 -0
  66. package/dist/engines/stable-diffusion-cpp/validateModelFiles.js +91 -0
  67. package/dist/engines/stable-diffusion-cpp/validateModelFiles.js.map +1 -0
  68. package/dist/engines/transformers-js/engine.d.ts +37 -0
  69. package/dist/engines/transformers-js/engine.js +538 -0
  70. package/dist/engines/transformers-js/engine.js.map +1 -0
  71. package/dist/engines/transformers-js/types.d.ts +7 -0
  72. package/dist/engines/transformers-js/types.js +2 -0
  73. package/dist/engines/transformers-js/types.js.map +1 -0
  74. package/dist/engines/transformers-js/util.d.ts +7 -0
  75. package/dist/engines/transformers-js/util.js +36 -0
  76. package/dist/engines/transformers-js/util.js.map +1 -0
  77. package/dist/engines/transformers-js/validateModelFiles.d.ts +17 -0
  78. package/dist/engines/transformers-js/validateModelFiles.js +133 -0
  79. package/dist/engines/transformers-js/validateModelFiles.js.map +1 -0
  80. package/dist/experiments/ChatWithVision.d.ts +11 -0
  81. package/dist/experiments/ChatWithVision.js +91 -0
  82. package/dist/experiments/ChatWithVision.js.map +1 -0
  83. package/dist/experiments/StableDiffPromptGenerator.d.ts +0 -0
  84. package/dist/experiments/StableDiffPromptGenerator.js +4 -0
  85. package/dist/experiments/StableDiffPromptGenerator.js.map +1 -0
  86. package/dist/experiments/VoiceFunctionCall.d.ts +18 -0
  87. package/dist/experiments/VoiceFunctionCall.js +51 -0
  88. package/dist/experiments/VoiceFunctionCall.js.map +1 -0
  89. package/dist/http.d.ts +19 -0
  90. package/dist/http.js +54 -0
  91. package/dist/http.js.map +1 -0
  92. package/dist/index.d.ts +7 -0
  93. package/dist/index.js +8 -0
  94. package/dist/index.js.map +1 -0
  95. package/dist/instance.d.ts +88 -0
  96. package/dist/instance.js +594 -0
  97. package/dist/instance.js.map +1 -0
  98. package/dist/lib/acquireFileLock.d.ts +7 -0
  99. package/dist/lib/acquireFileLock.js +38 -0
  100. package/dist/lib/acquireFileLock.js.map +1 -0
  101. package/dist/lib/calculateContextIdentity.d.ts +7 -0
  102. package/dist/lib/calculateContextIdentity.js +39 -0
  103. package/dist/lib/calculateContextIdentity.js.map +1 -0
  104. package/dist/lib/calculateFileChecksum.d.ts +1 -0
  105. package/dist/lib/calculateFileChecksum.js +16 -0
  106. package/dist/lib/calculateFileChecksum.js.map +1 -0
  107. package/dist/lib/copyDirectory.d.ts +6 -0
  108. package/dist/lib/copyDirectory.js +27 -0
  109. package/dist/lib/copyDirectory.js.map +1 -0
  110. package/dist/lib/decodeAudio.d.ts +1 -0
  111. package/dist/lib/decodeAudio.js +26 -0
  112. package/dist/lib/decodeAudio.js.map +1 -0
  113. package/dist/lib/downloadModelFile.d.ts +10 -0
  114. package/dist/lib/downloadModelFile.js +58 -0
  115. package/dist/lib/downloadModelFile.js.map +1 -0
  116. package/dist/lib/flattenMessageTextContent.d.ts +2 -0
  117. package/dist/lib/flattenMessageTextContent.js +11 -0
  118. package/dist/lib/flattenMessageTextContent.js.map +1 -0
  119. package/dist/lib/getCacheDirPath.d.ts +12 -0
  120. package/dist/lib/getCacheDirPath.js +31 -0
  121. package/dist/lib/getCacheDirPath.js.map +1 -0
  122. package/dist/lib/loadImage.d.ts +12 -0
  123. package/dist/lib/loadImage.js +30 -0
  124. package/dist/lib/loadImage.js.map +1 -0
  125. package/dist/lib/logger.d.ts +12 -0
  126. package/dist/lib/logger.js +98 -0
  127. package/dist/lib/logger.js.map +1 -0
  128. package/dist/lib/math.d.ts +7 -0
  129. package/dist/lib/math.js +30 -0
  130. package/dist/lib/math.js.map +1 -0
  131. package/dist/lib/resolveModelFileLocation.d.ts +15 -0
  132. package/dist/lib/resolveModelFileLocation.js +41 -0
  133. package/dist/lib/resolveModelFileLocation.js.map +1 -0
  134. package/dist/lib/util.d.ts +7 -0
  135. package/dist/lib/util.js +61 -0
  136. package/dist/lib/util.js.map +1 -0
  137. package/dist/lib/validateModelFile.d.ts +9 -0
  138. package/dist/lib/validateModelFile.js +62 -0
  139. package/dist/lib/validateModelFile.js.map +1 -0
  140. package/dist/lib/validateModelOptions.d.ts +3 -0
  141. package/dist/lib/validateModelOptions.js +23 -0
  142. package/dist/lib/validateModelOptions.js.map +1 -0
  143. package/dist/pool.d.ts +61 -0
  144. package/dist/pool.js +512 -0
  145. package/dist/pool.js.map +1 -0
  146. package/dist/server.d.ts +59 -0
  147. package/dist/server.js +221 -0
  148. package/dist/server.js.map +1 -0
  149. package/dist/standalone.d.ts +1 -0
  150. package/dist/standalone.js +306 -0
  151. package/dist/standalone.js.map +1 -0
  152. package/dist/store.d.ts +60 -0
  153. package/dist/store.js +203 -0
  154. package/dist/store.js.map +1 -0
  155. package/dist/types/completions.d.ts +57 -0
  156. package/dist/types/completions.js +2 -0
  157. package/dist/types/completions.js.map +1 -0
  158. package/dist/types/index.d.ts +326 -0
  159. package/dist/types/index.js +2 -0
  160. package/dist/types/index.js.map +1 -0
  161. package/docs/engines.md +28 -0
  162. package/docs/gpu.md +72 -0
  163. package/docs/http-api.md +147 -0
  164. package/examples/all-options.js +108 -0
  165. package/examples/chat-cli.js +56 -0
  166. package/examples/chat-server.js +65 -0
  167. package/examples/concurrency.js +70 -0
  168. package/examples/express.js +70 -0
  169. package/examples/pool.js +91 -0
  170. package/package.json +113 -0
  171. package/src/api/openai/enums.ts +20 -0
  172. package/src/api/openai/handlers/chat.ts +408 -0
  173. package/src/api/openai/handlers/completions.ts +196 -0
  174. package/src/api/openai/handlers/embeddings.ts +92 -0
  175. package/src/api/openai/handlers/images.ts +3 -0
  176. package/src/api/openai/handlers/models.ts +33 -0
  177. package/src/api/openai/handlers/transcription.ts +2 -0
  178. package/src/api/openai/index.ts +16 -0
  179. package/src/api/parseJSONRequestBody.ts +26 -0
  180. package/src/api/v1/DRAFT.md +16 -0
  181. package/src/api/v1/index.ts +37 -0
  182. package/src/cli.ts +9 -0
  183. package/src/engines/gpt4all/engine.ts +441 -0
  184. package/src/engines/gpt4all/util.ts +31 -0
  185. package/src/engines/index.ts +28 -0
  186. package/src/engines/node-llama-cpp/engine.ts +811 -0
  187. package/src/engines/node-llama-cpp/types.ts +17 -0
  188. package/src/engines/node-llama-cpp/util.ts +126 -0
  189. package/src/engines/node-llama-cpp/validateModelFile.ts +46 -0
  190. package/src/engines/stable-diffusion-cpp/engine.ts +369 -0
  191. package/src/engines/stable-diffusion-cpp/types.ts +54 -0
  192. package/src/engines/stable-diffusion-cpp/util.ts +58 -0
  193. package/src/engines/stable-diffusion-cpp/validateModelFiles.ts +119 -0
  194. package/src/engines/transformers-js/engine.ts +659 -0
  195. package/src/engines/transformers-js/types.ts +25 -0
  196. package/src/engines/transformers-js/util.ts +40 -0
  197. package/src/engines/transformers-js/validateModelFiles.ts +168 -0
  198. package/src/experiments/ChatWithVision.ts +103 -0
  199. package/src/experiments/StableDiffPromptGenerator.ts +2 -0
  200. package/src/experiments/VoiceFunctionCall.ts +71 -0
  201. package/src/http.ts +72 -0
  202. package/src/index.ts +7 -0
  203. package/src/instance.ts +723 -0
  204. package/src/lib/acquireFileLock.ts +38 -0
  205. package/src/lib/calculateContextIdentity.ts +53 -0
  206. package/src/lib/calculateFileChecksum.ts +18 -0
  207. package/src/lib/copyDirectory.ts +29 -0
  208. package/src/lib/decodeAudio.ts +39 -0
  209. package/src/lib/downloadModelFile.ts +70 -0
  210. package/src/lib/flattenMessageTextContent.ts +19 -0
  211. package/src/lib/getCacheDirPath.ts +34 -0
  212. package/src/lib/loadImage.ts +46 -0
  213. package/src/lib/logger.ts +112 -0
  214. package/src/lib/math.ts +31 -0
  215. package/src/lib/resolveModelFileLocation.ts +49 -0
  216. package/src/lib/util.ts +75 -0
  217. package/src/lib/validateModelFile.ts +71 -0
  218. package/src/lib/validateModelOptions.ts +31 -0
  219. package/src/pool.ts +651 -0
  220. package/src/server.ts +270 -0
  221. package/src/standalone.ts +320 -0
  222. package/src/store.ts +278 -0
  223. package/src/types/completions.ts +86 -0
  224. package/src/types/index.ts +488 -0
  225. package/tsconfig.json +29 -0
  226. package/tsconfig.release.json +11 -0
  227. package/vitest.config.ts +18 -0
@@ -0,0 +1,811 @@
1
+ import path from 'node:path'
2
+ import fs from 'node:fs'
3
+ import { nanoid } from 'nanoid'
4
+ import {
5
+ getLlama,
6
+ LlamaOptions,
7
+ LlamaChat,
8
+ LlamaModel,
9
+ LlamaContext,
10
+ LlamaCompletion,
11
+ LlamaLogLevel,
12
+ LlamaChatResponseFunctionCall,
13
+ TokenBias,
14
+ Token,
15
+ LlamaContextSequence,
16
+ LlamaGrammar,
17
+ ChatHistoryItem,
18
+ LlamaChatResponse,
19
+ ChatModelResponse,
20
+ LlamaEmbeddingContext,
21
+ defineChatSessionFunction,
22
+ GbnfJsonSchema,
23
+ ChatSessionModelFunction,
24
+ createModelDownloader,
25
+ readGgufFileInfo,
26
+ GgufFileInfo,
27
+ LlamaJsonSchemaGrammar,
28
+ LLamaChatContextShiftOptions,
29
+ LlamaContextOptions,
30
+ } from 'node-llama-cpp'
31
+ import { StopGenerationTrigger } from 'node-llama-cpp/dist/utils/StopGenerationDetector'
32
+ import {
33
+ EngineChatCompletionResult,
34
+ EngineTextCompletionResult,
35
+ EngineTextCompletionArgs,
36
+ EngineChatCompletionArgs,
37
+ EngineContext,
38
+ ToolDefinition,
39
+ ToolCallResultMessage,
40
+ AssistantMessage,
41
+ EngineEmbeddingArgs,
42
+ EngineEmbeddingResult,
43
+ FileDownloadProgress,
44
+ ModelConfig,
45
+ TextCompletionParams,
46
+ TextCompletionGrammar,
47
+ ChatMessage,
48
+ } from '#package/types/index.js'
49
+ import { LogLevels } from '#package/lib/logger.js'
50
+ import { flattenMessageTextContent } from '#package/lib/flattenMessageTextContent.js'
51
+ import { acquireFileLock } from '#package/lib/acquireFileLock.js'
52
+ import { getRandomNumber } from '#package/lib/util.js'
53
+ import { validateModelFile } from '#package/lib/validateModelFile.js'
54
+ import { createChatMessageArray, addFunctionCallToChatHistory, mapFinishReason } from './util.js'
55
+ import { LlamaChatResult } from './types.js'
56
+
57
/**
 * Runtime state held for one loaded node-llama-cpp model instance.
 * Bundles the loaded model, its inference context, and the mutable
 * chat/completion state that the engine's task handlers read and update.
 */
export interface NodeLlamaCppInstance {
  // The loaded llama model (created via llama.loadModel in createInstance).
  model: LlamaModel
  // Inference context created from the model (single sequence; see createInstance).
  context: LlamaContext
  // Chat session; undefined until a chat completion runs or
  // config.initialMessages are ingested at creation time.
  chat?: LlamaChat
  // Accumulated chat history mirrored from completed turns.
  chatHistory: ChatHistoryItem[]
  // Named grammars available for constrained generation ("json" is always registered).
  grammars: Record<string, LlamaGrammar>
  // Function calls emitted by the model that still await a tool result, keyed by
  // call id. NOTE(review): values look like raw function-call records — confirm shape.
  pendingFunctionCalls: Record<string, any>
  // Evaluation state from the last chat generation, used to resume/reuse context.
  lastEvaluation?: LlamaChatResponse['lastEvaluation']
  // Context used for embedding tasks. NOTE(review): created lazily elsewhere — confirm.
  embeddingContext?: LlamaEmbeddingContext
  // Completion session; created at instance creation when config.prefix is set.
  completion?: LlamaCompletion
  // The context sequence currently used for generation.
  contextSequence: LlamaContextSequence
}
69
+
70
/**
 * Metadata produced by prepareModel for a node-llama-cpp model.
 */
export interface NodeLlamaCppModelMeta {
  // Parsed GGUF header info read via readGgufFileInfo
  // (large tokenizer arrays are skipped via ignoreKeys).
  gguf: GgufFileInfo
}
73
+
74
/**
 * Configuration for a node-llama-cpp backed model.
 */
export interface NodeLlamaCppModelConfig extends ModelConfig {
  // Path of the GGUF model file on disk.
  location: string
  // Extra named grammars — either a GBNF grammar string or a JSON schema
  // object — registered alongside the built-in "json" grammar.
  grammars?: Record<string, TextCompletionGrammar>
  // Expected SHA-256 checksum of the model file.
  // NOTE(review): presumably consumed by validateModelFile — confirm.
  sha256?: string
  // Default completion/sampling parameters applied when a request omits them.
  completionDefaults?: TextCompletionParams
  // Messages ingested into a chat session right after instance creation.
  initialMessages?: ChatMessage[]
  // Text prefix pre-evaluated into a completion session at instance creation.
  prefix?: string
  // Tool / function-calling configuration.
  tools?: {
    // Available tool definitions by name.
    definitions: Record<string, ToolDefinition>
    // Whether function parameter docs are exposed to the model.
    includeToolDocumentation?: boolean
    // Maximum number of parallel function calls the model may emit.
    parallelism?: number
  }
  // Context window size in tokens; engine/library default applies when omitted.
  contextSize?: number
  // Batch size forwarded to context creation.
  batchSize?: number
  // LoRA adapter options forwarded to context creation.
  lora?: LlamaContextOptions['lora']
  // Context-shift strategy forwarded to chat generation.
  contextShiftStrategy?: LLamaChatContextShiftOptions['strategy']
  // Device / backend selection.
  device?: {
    // GPU backend: boolean, "auto", or a backend name such as "metal"/"cuda"/"vulkan".
    gpu?: boolean | 'auto' | (string & {})
    // Number of model layers to offload to the GPU.
    gpuLayers?: number
    // CPU thread count used for inference.
    cpuThreads?: number
    // Lock model memory (mlock) to prevent it being swapped out.
    memLock?: boolean
  }
}
97
+
98
// Engine capability flag: this engine can auto-select a GPU backend.
// NOTE(review): presumably read by the engine registry/host — confirm against engines/index.
export const autoGpu = true
99
+
100
+ export async function prepareModel(
101
+ { config, log }: EngineContext<NodeLlamaCppModelConfig>,
102
+ onProgress?: (progress: FileDownloadProgress) => void,
103
+ signal?: AbortSignal,
104
+ ) {
105
+ fs.mkdirSync(path.dirname(config.location), { recursive: true })
106
+ const releaseFileLock = await acquireFileLock(config.location, signal)
107
+
108
+ if (signal?.aborted) {
109
+ releaseFileLock()
110
+ return
111
+ }
112
+ log(LogLevels.info, `Preparing node-llama-cpp model at ${config.location}`, {
113
+ model: config.id,
114
+ })
115
+ const downloadModel = async (url: string, validationResult: string) => {
116
+ log(LogLevels.info, `Downloading model files`, {
117
+ model: config.id,
118
+ url: url,
119
+ location: config.location,
120
+ error: validationResult,
121
+ })
122
+
123
+ const downloader = await createModelDownloader({
124
+ modelUrl: url,
125
+ dirPath: path.dirname(config.location),
126
+ fileName: path.basename(config.location),
127
+ deleteTempFileOnCancel: false,
128
+ onProgress: (status) => {
129
+ if (onProgress) {
130
+ onProgress({
131
+ file: config.location,
132
+ loadedBytes: status.downloadedSize,
133
+ totalBytes: status.totalSize,
134
+ })
135
+ }
136
+ },
137
+ })
138
+ await downloader.download()
139
+ }
140
+ try {
141
+ if (signal?.aborted) {
142
+ return
143
+ }
144
+
145
+ const validationError = await validateModelFile(config)
146
+ if (signal?.aborted) {
147
+ return
148
+ }
149
+ if (validationError) {
150
+ if (config.url) {
151
+ await downloadModel(config.url, validationError)
152
+ } else {
153
+ throw new Error(`${validationError} - No URL provided`)
154
+ }
155
+ }
156
+
157
+ const finalValidationError = await validateModelFile(config)
158
+ if (finalValidationError) {
159
+ throw new Error(`Downloaded files are invalid: ${finalValidationError}`)
160
+ }
161
+ const gguf = await readGgufFileInfo(config.location, {
162
+ signal,
163
+ ignoreKeys: [
164
+ 'gguf.tokenizer.ggml.merges',
165
+ 'gguf.tokenizer.ggml.tokens',
166
+ 'gguf.tokenizer.ggml.scores',
167
+ 'gguf.tokenizer.ggml.token_type',
168
+ ],
169
+ })
170
+ return {
171
+ gguf,
172
+ }
173
+ } catch (err) {
174
+ throw err
175
+ } finally {
176
+ releaseFileLock()
177
+ }
178
+ }
179
+
180
/**
 * Load the model and build a ready-to-use engine instance.
 *
 * Initializes the llama backend (with log forwarding), registers grammars,
 * loads the model file, creates an inference context, and optionally
 * pre-ingests `config.initialMessages` into a chat session and/or
 * pre-evaluates `config.prefix` into a completion session.
 *
 * @param signal - Forwarded to model load and context creation for cancellation.
 * @returns The populated NodeLlamaCppInstance.
 */
export async function createInstance({ config, log }: EngineContext<NodeLlamaCppModelConfig>, signal?: AbortSignal) {
  log(LogLevels.debug, 'Load Llama model', config.device)
  // takes "auto" | "metal" | "cuda" | "vulkan"
  const gpuSetting = (config.device?.gpu ?? 'auto') as LlamaOptions['gpu']
  const llama = await getLlama({
    gpu: gpuSetting,
    // forwarding llama logger
    logLevel: LlamaLogLevel.debug,
    logger: (level, message) => {
      // Map the library's log levels onto this package's logger levels;
      // info/debug are demoted to verbose to keep normal output quiet.
      if (level === LlamaLogLevel.warn) {
        log(LogLevels.warn, message)
      } else if (level === LlamaLogLevel.error || level === LlamaLogLevel.fatal) {
        log(LogLevels.error, message)
      } else if (level === LlamaLogLevel.info || level === LlamaLogLevel.debug) {
        log(LogLevels.verbose, message)
      }
    },
  })

  // The built-in "json" grammar is always available for constrained output.
  const llamaGrammars: Record<string, LlamaGrammar> = {
    json: await LlamaGrammar.getFor(llama, 'json'),
  }

  // Register user-supplied grammars: strings are GBNF, anything else is
  // treated as a JSON schema.
  if (config.grammars) {
    for (const key in config.grammars) {
      const input = config.grammars[key]
      if (typeof input === 'string') {
        llamaGrammars[key] = new LlamaGrammar(llama, {
          grammar: input,
        })
      } else {
        // assume input is a JSON schema object
        llamaGrammars[key] = new LlamaJsonSchemaGrammar(llama, input as GbnfJsonSchema)
      }
    }
  }

  const llamaModel = await llama.loadModel({
    modelPath: config.location, // full model absolute path
    loadSignal: signal,
    useMlock: config.device?.memLock ?? false,
    gpuLayers: config.device?.gpuLayers,
    // onLoadProgress: (percent) => {}
  })

  // Single-sequence context; the one sequence is grabbed below and shared by
  // the chat/completion sessions created here.
  const context = await llamaModel.createContext({
    sequences: 1,
    lora: config.lora,
    threads: config.device?.cpuThreads,
    batchSize: config.batchSize,
    contextSize: config.contextSize,
    flashAttention: true,
    createSignal: signal,
  })

  const instance: NodeLlamaCppInstance = {
    model: llamaModel,
    context,
    grammars: llamaGrammars,
    chat: undefined,
    chatHistory: [],
    pendingFunctionCalls: {},
    lastEvaluation: undefined,
    completion: undefined,
    contextSequence: context.getSequence(),
  }

  // Optionally pre-ingest an initial conversation so the first real chat
  // request can reuse the evaluated context.
  if (config.initialMessages) {
    const initialChatHistory = createChatMessageArray(config.initialMessages)
    const chat = new LlamaChat({
      contextSequence: instance.contextSequence!,
      // autoDisposeSequence: true,
    })

    let inputFunctions: Record<string, ChatSessionModelFunction> | undefined

    // Expose configured tool definitions to the chat session, substituting a
    // no-op handler when none was provided.
    if (config.tools?.definitions && Object.keys(config.tools.definitions).length > 0) {
      const functionDefs = config.tools.definitions
      inputFunctions = {}
      for (const functionName in functionDefs) {
        const functionDef = functionDefs[functionName]
        inputFunctions[functionName] = defineChatSessionFunction<any>({
          description: functionDef.description,
          params: functionDef.parameters,
          handler: functionDef.handler || (() => {}),
        }) as ChatSessionModelFunction
      }
    }

    // Empty user prompt: evaluates the history without generating a response.
    const loadMessagesRes = await chat.loadChatAndCompleteUserMessage(initialChatHistory, {
      initialUserPrompt: '',
      functions: inputFunctions,
      documentFunctionParams: config.tools?.includeToolDocumentation,
    })

    instance.chat = chat
    instance.chatHistory = initialChatHistory
    instance.lastEvaluation = {
      cleanHistory: initialChatHistory,
      contextWindow: loadMessagesRes.lastEvaluation.contextWindow,
      contextShiftMetadata: loadMessagesRes.lastEvaluation.contextShiftMetadata,
    }
  }

  // Optionally pre-evaluate a text prefix into a completion session
  // (maxTokens: 0 evaluates the prompt without generating).
  if (config.prefix) {
    const contextSequence = instance.contextSequence!
    const completion = new LlamaCompletion({
      contextSequence: contextSequence,
    })
    await completion.generateCompletion(config.prefix, {
      maxTokens: 0,
    })
    instance.completion = completion
    instance.contextSequence = contextSequence
  }

  return instance
}
298
+
299
+ export async function disposeInstance(instance: NodeLlamaCppInstance) {
300
+ await instance.model.dispose()
301
+ }
302
+
303
+ export async function processChatCompletionTask(
304
+ { request, config, resetContext, log, onChunk }: EngineChatCompletionArgs<NodeLlamaCppModelConfig>,
305
+ instance: NodeLlamaCppInstance,
306
+ signal?: AbortSignal,
307
+ ): Promise<EngineChatCompletionResult> {
308
+ if (!instance.chat || resetContext) {
309
+ log(LogLevels.debug, 'Recreating chat context', {
310
+ resetContext,
311
+ willDisposeChat: !!instance.chat,
312
+ })
313
+ // if context reset is requested, dispose the chat instance
314
+ if (instance.chat) {
315
+ await instance.chat.dispose()
316
+ }
317
+ let contextSequence = instance.contextSequence
318
+ if (!contextSequence || contextSequence.disposed) {
319
+ if (instance.context.sequencesLeft) {
320
+ contextSequence = instance.context.getSequence()
321
+ instance.contextSequence = contextSequence
322
+ } else {
323
+ throw new Error('No context sequence available')
324
+ }
325
+ } else {
326
+ contextSequence.clearHistory()
327
+ }
328
+ instance.chat = new LlamaChat({
329
+ contextSequence: contextSequence,
330
+ // autoDisposeSequence: true,
331
+ })
332
+ // reset state and reingest the conversation history
333
+ instance.lastEvaluation = undefined
334
+ instance.pendingFunctionCalls = {}
335
+ instance.chatHistory = createChatMessageArray(request.messages)
336
+ // drop last user message. its gonna be added later, after resolved function calls
337
+ if (instance.chatHistory[instance.chatHistory.length - 1].type === 'user') {
338
+ instance.chatHistory.pop()
339
+ }
340
+ }
341
+
342
+ // set additional stop generation triggers for this completion
343
+ const customStopTriggers: StopGenerationTrigger[] = []
344
+ const stopTrigger = request.stop ?? config.completionDefaults?.stop
345
+ if (stopTrigger) {
346
+ customStopTriggers.push(...stopTrigger.map((t) => [t]))
347
+ }
348
+ // setting up logit/token bias dictionary
349
+ let tokenBias: TokenBias | undefined
350
+ const completionTokenBias = request.tokenBias ?? config.completionDefaults?.tokenBias
351
+ if (completionTokenBias) {
352
+ tokenBias = new TokenBias(instance.model.tokenizer)
353
+ for (const key in completionTokenBias) {
354
+ const bias = completionTokenBias[key] / 10
355
+ const tokenId = parseInt(key) as Token
356
+ if (!isNaN(tokenId)) {
357
+ tokenBias.set(tokenId, bias)
358
+ } else {
359
+ tokenBias.set(key, bias)
360
+ }
361
+ }
362
+ }
363
+
364
+ // setting up available function definitions
365
+ const functionDefinitions: Record<string, ToolDefinition> = {
366
+ ...config.tools?.definitions,
367
+ ...request.tools,
368
+ }
369
+
370
+ // see if the user submitted any function call results
371
+ const supportsParallelFunctionCalling =
372
+ instance.chat.chatWrapper.settings.functions.parallelism != null && !!config.tools?.parallelism
373
+ const resolvedFunctionCalls = []
374
+ const functionCallResultMessages = request.messages.filter((m) => m.role === 'tool') as ToolCallResultMessage[]
375
+ for (const message of functionCallResultMessages) {
376
+ if (!instance.pendingFunctionCalls[message.callId]) {
377
+ log(LogLevels.warn, `Received function result for non-existing call id "${message.callId}`)
378
+ continue
379
+ }
380
+ log(LogLevels.debug, 'Resolving pending function call', {
381
+ id: message.callId,
382
+ result: message.content,
383
+ })
384
+ const functionCall = instance.pendingFunctionCalls[message.callId]
385
+ const functionDef = functionDefinitions[functionCall.functionName]
386
+ resolvedFunctionCalls.push({
387
+ name: functionCall.functionName,
388
+ description: functionDef?.description,
389
+ params: functionCall.params,
390
+ result: message.content,
391
+ rawCall: functionCall.raw,
392
+ startsNewChunk: supportsParallelFunctionCalling,
393
+ })
394
+ delete instance.pendingFunctionCalls[message.callId]
395
+ }
396
+ // if we resolved any results, add them to history
397
+ if (resolvedFunctionCalls.length) {
398
+ instance.chatHistory.push({
399
+ type: 'model',
400
+ response: resolvedFunctionCalls.map((call) => {
401
+ return {
402
+ type: 'functionCall',
403
+ ...call,
404
+ }
405
+ }),
406
+ })
407
+ }
408
+
409
+ // add the new user message to the chat history
410
+ let assistantPrefill: string = ''
411
+ const lastMessage = request.messages[request.messages.length - 1]
412
+ if (lastMessage.role === 'user' && lastMessage.content) {
413
+ const newUserText = flattenMessageTextContent(lastMessage.content)
414
+ if (newUserText) {
415
+ instance.chatHistory.push({
416
+ type: 'user',
417
+ text: newUserText,
418
+ })
419
+ }
420
+ } else if (lastMessage.role === 'assistant') {
421
+ // use last message as prefill for response, if its an assistant message
422
+ assistantPrefill = flattenMessageTextContent(lastMessage.content)
423
+ } else if (!resolvedFunctionCalls.length) {
424
+ log(LogLevels.warn, 'Last message is not valid for chat completion. This is likely a mistake.', lastMessage)
425
+ throw new Error('Invalid last chat message')
426
+ }
427
+
428
+ // only grammar or functions can be used, not both.
429
+ // currently ignoring function definitions if grammar is provided
430
+
431
+ let inputGrammar: LlamaGrammar | undefined
432
+ let inputFunctions: Record<string, ChatSessionModelFunction> | undefined
433
+
434
+ if (request.grammar) {
435
+ if (!instance.grammars[request.grammar]) {
436
+ throw new Error(`Grammar "${request.grammar}" not found.`)
437
+ }
438
+ inputGrammar = instance.grammars[request.grammar]
439
+ } else if (Object.keys(functionDefinitions).length > 0) {
440
+ inputFunctions = {}
441
+ for (const functionName in functionDefinitions) {
442
+ const functionDef = functionDefinitions[functionName]
443
+ inputFunctions[functionName] = defineChatSessionFunction<any>({
444
+ description: functionDef.description,
445
+ params: functionDef.parameters,
446
+ handler: functionDef.handler || (() => {}),
447
+ })
448
+ }
449
+ }
450
+ const defaults = config.completionDefaults ?? {}
451
+ let lastEvaluation: LlamaChatResponse['lastEvaluation'] | undefined = instance.lastEvaluation
452
+ let newChatHistory = instance.chatHistory.slice()
453
+ let newContextWindowChatHistory = !lastEvaluation?.contextWindow ? undefined : instance.chatHistory.slice()
454
+
455
+ if (instance.chatHistory[instance.chatHistory.length - 1].type !== 'model' || assistantPrefill) {
456
+ const newModelResponse = assistantPrefill ? [assistantPrefill] : []
457
+ newChatHistory.push({
458
+ type: 'model',
459
+ response: newModelResponse,
460
+ })
461
+ if (newContextWindowChatHistory) {
462
+ newContextWindowChatHistory.push({
463
+ type: 'model',
464
+ response: newModelResponse,
465
+ })
466
+ }
467
+ }
468
+
469
+ const functionsOrGrammar = inputFunctions
470
+ ? {
471
+ functions: inputFunctions,
472
+ documentFunctionParams: config.tools?.includeToolDocumentation ?? true,
473
+ maxParallelFunctionCalls: config.tools?.parallelism ?? 1,
474
+ onFunctionCall: (functionCall: LlamaChatResponseFunctionCall<any>) => {
475
+ // log(LogLevels.debug, 'Called function', functionCall)
476
+ },
477
+ }
478
+ : {
479
+ grammar: inputGrammar,
480
+ }
481
+
482
+ const initialTokenMeterState = instance.chat.sequence.tokenMeter.getState()
483
+ let completionResult: LlamaChatResult
484
+ while (true) {
485
+ const {
486
+ functionCalls,
487
+ lastEvaluation: currentLastEvaluation,
488
+ metadata,
489
+ } = await instance.chat.generateResponse(newChatHistory, {
490
+ signal,
491
+ stopOnAbortSignal: true, // this will make aborted completions resolve (with a partial response)
492
+ maxTokens: request.maxTokens ?? defaults.maxTokens,
493
+ temperature: request.temperature ?? defaults.temperature,
494
+ topP: request.topP ?? defaults.topP,
495
+ topK: request.topK ?? defaults.topK,
496
+ minP: request.minP ?? defaults.minP,
497
+ seed: request.seed ?? config.completionDefaults?.seed ?? getRandomNumber(0, 1000000),
498
+ tokenBias,
499
+ customStopTriggers,
500
+ trimWhitespaceSuffix: false,
501
+ ...functionsOrGrammar,
502
+ repeatPenalty: {
503
+ lastTokens: request.repeatPenaltyNum ?? defaults.repeatPenaltyNum,
504
+ frequencyPenalty: request.frequencyPenalty ?? defaults.frequencyPenalty,
505
+ presencePenalty: request.presencePenalty ?? defaults.presencePenalty,
506
+ },
507
+ contextShift: {
508
+ strategy: config.contextShiftStrategy,
509
+ lastEvaluationMetadata: lastEvaluation?.contextShiftMetadata,
510
+ },
511
+ lastEvaluationContextWindow: {
512
+ history: newContextWindowChatHistory,
513
+ minimumOverlapPercentageToPreventContextShift: 0.5,
514
+ },
515
+ onToken: (tokens) => {
516
+ const text = instance.model.detokenize(tokens)
517
+ if (onChunk) {
518
+ onChunk({
519
+ tokens,
520
+ text,
521
+ })
522
+ }
523
+ },
524
+ })
525
+
526
+ lastEvaluation = currentLastEvaluation
527
+ newChatHistory = lastEvaluation.cleanHistory
528
+
529
+ if (functionCalls) {
530
+ // find leading immediately evokable function calls (=have a handler function)
531
+ const evokableFunctionCalls = []
532
+ for (const functionCall of functionCalls) {
533
+ const functionDef = functionDefinitions[functionCall.functionName]
534
+ if (functionDef.handler) {
535
+ evokableFunctionCalls.push(functionCall)
536
+ } else {
537
+ break
538
+ }
539
+ }
540
+
541
+ // resolve their results.
542
+ const results = await Promise.all(
543
+ evokableFunctionCalls.map(async (functionCall) => {
544
+ const functionDef = functionDefinitions[functionCall.functionName]
545
+ if (!functionDef) {
546
+ throw new Error(`The model tried to call undefined function "${functionCall.functionName}"`)
547
+ }
548
+ const functionCallResult = await functionDef.handler!(functionCall.params)
549
+ log(LogLevels.debug, 'Function handler resolved', {
550
+ function: functionCall.functionName,
551
+ args: functionCall.params,
552
+ result: functionCallResult,
553
+ })
554
+ return {
555
+ functionDef,
556
+ functionCall,
557
+ functionCallResult,
558
+ }
559
+ }),
560
+ )
561
+ newContextWindowChatHistory = lastEvaluation.contextWindow
562
+ let startNewChunk = true
563
+ // add results to chat history in the order they were called
564
+ for (const callResult of results) {
565
+ newChatHistory = addFunctionCallToChatHistory({
566
+ chatHistory: newChatHistory,
567
+ functionName: callResult.functionCall.functionName,
568
+ functionDescription: callResult.functionDef.description,
569
+ callParams: callResult.functionCall.params,
570
+ callResult: callResult.functionCallResult,
571
+ rawCall: callResult.functionCall.raw,
572
+ startsNewChunk: startNewChunk,
573
+ })
574
+ newContextWindowChatHistory = addFunctionCallToChatHistory({
575
+ chatHistory: newContextWindowChatHistory,
576
+ functionName: callResult.functionCall.functionName,
577
+ functionDescription: callResult.functionDef.description,
578
+ callParams: callResult.functionCall.params,
579
+ callResult: callResult.functionCallResult,
580
+ rawCall: callResult.functionCall.raw,
581
+ startsNewChunk: startNewChunk,
582
+ })
583
+ startNewChunk = false
584
+ }
585
+
586
+ // check if all function calls were immediately evokable
587
+ const remainingFunctionCalls = functionCalls.slice(evokableFunctionCalls.length)
588
+
589
+ if (remainingFunctionCalls.length === 0) {
590
+ // if yes, continue with generation
591
+ lastEvaluation.cleanHistory = newChatHistory
592
+ lastEvaluation.contextWindow = newContextWindowChatHistory!
593
+ continue
594
+ } else {
595
+ // if no, return the function calls and skip generation
596
+ completionResult = {
597
+ responseText: null,
598
+ stopReason: 'functionCalls',
599
+ functionCalls: remainingFunctionCalls,
600
+ }
601
+ break
602
+ }
603
+ }
604
+
605
+ // no function calls happened, we got a model response.
606
+ instance.lastEvaluation = lastEvaluation
607
+ instance.chatHistory = newChatHistory
608
+ const lastMessage = instance.chatHistory[instance.chatHistory.length - 1] as ChatModelResponse
609
+ const responseText = lastMessage.response.filter((item: any) => typeof item === 'string').join('')
610
+ completionResult = {
611
+ responseText,
612
+ stopReason: metadata.stopReason,
613
+ }
614
+ break
615
+ }
616
+
617
+ const assistantMessage: AssistantMessage = {
618
+ role: 'assistant',
619
+ content: completionResult.responseText || '',
620
+ }
621
+
622
+ if (completionResult.functionCalls) {
623
+ // TODO its possible that there are trailing immediately-evaluatable function calls.
624
+ // function call results need to be added in the order the functions were called, so
625
+ // we need to wait for the pending calls to complete before we can add the trailing calls.
626
+ // as is, these may never resolve
627
+ const pendingFunctionCalls = completionResult.functionCalls.filter((call) => {
628
+ const functionDef = functionDefinitions[call.functionName]
629
+ return !functionDef.handler
630
+ })
631
+
632
+ // TODO write a test that triggers a parallel call to a deferred function and to an IE function
633
+ const trailingFunctionCalls = completionResult.functionCalls.filter((call) => {
634
+ const functionDef = functionDefinitions[call.functionName]
635
+ return functionDef.handler
636
+ })
637
+ if (trailingFunctionCalls.length) {
638
+ console.debug(trailingFunctionCalls)
639
+ log(LogLevels.warn, 'Trailing function calls not resolved')
640
+ }
641
+
642
+ assistantMessage.toolCalls = pendingFunctionCalls.map((call) => {
643
+ const callId = nanoid()
644
+ instance.pendingFunctionCalls[callId] = call
645
+ log(LogLevels.debug, 'Saving pending tool call', {
646
+ id: callId,
647
+ function: call.functionName,
648
+ args: call.params,
649
+ })
650
+ return {
651
+ id: callId,
652
+ name: call.functionName,
653
+ parameters: call.params,
654
+ }
655
+ })
656
+ }
657
+
658
+ const tokenDifference = instance.chat.sequence.tokenMeter.diff(initialTokenMeterState)
659
+ return {
660
+ finishReason: mapFinishReason(completionResult.stopReason),
661
+ message: assistantMessage,
662
+ promptTokens: tokenDifference.usedInputTokens,
663
+ completionTokens: tokenDifference.usedOutputTokens,
664
+ contextTokens: instance.chat.sequence.contextTokens.length,
665
+ }
666
+ }
667
+
668
+ export async function processTextCompletionTask(
669
+ { request, config, resetContext, log, onChunk }: EngineTextCompletionArgs<NodeLlamaCppModelConfig>,
670
+ instance: NodeLlamaCppInstance,
671
+ signal?: AbortSignal,
672
+ ): Promise<EngineTextCompletionResult> {
673
+ if (!request.prompt) {
674
+ throw new Error('Prompt is required for text completion.')
675
+ }
676
+
677
+ let completion: LlamaCompletion
678
+ let contextSequence: LlamaContextSequence
679
+
680
+ if (resetContext && instance.contextSequence) {
681
+ instance.contextSequence.clearHistory()
682
+ }
683
+
684
+ if (!instance.completion || instance.completion.disposed) {
685
+ if (instance.contextSequence) {
686
+ contextSequence = instance.contextSequence
687
+ } else if (instance.context.sequencesLeft) {
688
+ contextSequence = instance.context.getSequence()
689
+ } else {
690
+ throw new Error('No context sequence available')
691
+ }
692
+ instance.contextSequence = contextSequence
693
+ completion = new LlamaCompletion({
694
+ contextSequence,
695
+ })
696
+ instance.completion = completion
697
+ } else {
698
+ completion = instance.completion
699
+ contextSequence = instance.contextSequence!
700
+ }
701
+
702
+ if (!contextSequence || contextSequence.disposed) {
703
+ contextSequence = instance.context.getSequence()
704
+ instance.contextSequence = contextSequence
705
+ completion = new LlamaCompletion({
706
+ contextSequence,
707
+ })
708
+ instance.completion = completion
709
+ }
710
+
711
+ const stopGenerationTriggers: StopGenerationTrigger[] = []
712
+ const stopTrigger = request.stop ?? config.completionDefaults?.stop
713
+ if (stopTrigger) {
714
+ stopGenerationTriggers.push(...stopTrigger.map((t) => [t]))
715
+ }
716
+
717
+ const initialTokenMeterState = contextSequence.tokenMeter.getState()
718
+ const defaults = config.completionDefaults ?? {}
719
+ const result = await completion.generateCompletionWithMeta(request.prompt, {
720
+ maxTokens: request.maxTokens ?? defaults.maxTokens,
721
+ temperature: request.temperature ?? defaults.temperature,
722
+ topP: request.topP ?? defaults.topP,
723
+ topK: request.topK ?? defaults.topK,
724
+ minP: request.minP ?? defaults.minP,
725
+ repeatPenalty: {
726
+ lastTokens: request.repeatPenaltyNum ?? defaults.repeatPenaltyNum,
727
+ frequencyPenalty: request.frequencyPenalty ?? defaults.frequencyPenalty,
728
+ presencePenalty: request.presencePenalty ?? defaults.presencePenalty,
729
+ },
730
+ signal: signal,
731
+ customStopTriggers: stopGenerationTriggers.length ? stopGenerationTriggers : undefined,
732
+ seed: request.seed ?? config.completionDefaults?.seed ?? getRandomNumber(0, 1000000),
733
+ onToken: (tokens) => {
734
+ const text = instance.model.detokenize(tokens)
735
+ if (onChunk) {
736
+ onChunk({
737
+ tokens,
738
+ text,
739
+ })
740
+ }
741
+ },
742
+ })
743
+
744
+ const tokenDifference = contextSequence.tokenMeter.diff(initialTokenMeterState)
745
+
746
+ return {
747
+ finishReason: mapFinishReason(result.metadata.stopReason),
748
+ text: result.response,
749
+ promptTokens: tokenDifference.usedInputTokens,
750
+ completionTokens: tokenDifference.usedOutputTokens,
751
+ contextTokens: contextSequence.contextTokens.length,
752
+ }
753
+ }
754
+
755
+ export async function processEmbeddingTask(
756
+ { request, config, log }: EngineEmbeddingArgs<NodeLlamaCppModelConfig>,
757
+ instance: NodeLlamaCppInstance,
758
+ signal?: AbortSignal,
759
+ ): Promise<EngineEmbeddingResult> {
760
+ if (!request.input) {
761
+ throw new Error('Input is required for embedding.')
762
+ }
763
+ const texts: string[] = []
764
+ if (typeof request.input === 'string') {
765
+ texts.push(request.input)
766
+ } else if (Array.isArray(request.input)) {
767
+ for (const input of request.input) {
768
+ if (typeof input === 'string') {
769
+ texts.push(input)
770
+ } else if (input.type === 'text') {
771
+ texts.push(input.content)
772
+ } else if (input.type === 'image') {
773
+ throw new Error('Image inputs not implemented.')
774
+ }
775
+ }
776
+ }
777
+
778
+ if (!instance.embeddingContext) {
779
+ instance.embeddingContext = await instance.model.createEmbeddingContext({
780
+ batchSize: config.batchSize,
781
+ createSignal: signal,
782
+ threads: config.device?.cpuThreads,
783
+ contextSize: config.contextSize,
784
+ })
785
+ }
786
+
787
+ // @ts-ignore - private property
788
+ const contextSize = instance.embeddingContext._llamaContext.contextSize
789
+
790
+ const embeddings: Float32Array[] = []
791
+ let inputTokens = 0
792
+
793
+ for (const text of texts) {
794
+ let tokenizedInput = instance.model.tokenize(text)
795
+ if (tokenizedInput.length > contextSize) {
796
+ log(LogLevels.warn, 'Truncated input that exceeds context size')
797
+ tokenizedInput = tokenizedInput.slice(0, contextSize)
798
+ }
799
+ inputTokens += tokenizedInput.length
800
+ const embedding = await instance.embeddingContext.getEmbeddingFor(tokenizedInput)
801
+ embeddings.push(new Float32Array(embedding.vector))
802
+ if (signal?.aborted) {
803
+ break
804
+ }
805
+ }
806
+
807
+ return {
808
+ embeddings,
809
+ inputTokens,
810
+ }
811
+ }