inference-server 1.0.0-beta.19
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +216 -0
- package/dist/api/openai/enums.d.ts +4 -0
- package/dist/api/openai/enums.js +17 -0
- package/dist/api/openai/enums.js.map +1 -0
- package/dist/api/openai/handlers/chat.d.ts +3 -0
- package/dist/api/openai/handlers/chat.js +358 -0
- package/dist/api/openai/handlers/chat.js.map +1 -0
- package/dist/api/openai/handlers/completions.d.ts +3 -0
- package/dist/api/openai/handlers/completions.js +169 -0
- package/dist/api/openai/handlers/completions.js.map +1 -0
- package/dist/api/openai/handlers/embeddings.d.ts +3 -0
- package/dist/api/openai/handlers/embeddings.js +74 -0
- package/dist/api/openai/handlers/embeddings.js.map +1 -0
- package/dist/api/openai/handlers/images.d.ts +0 -0
- package/dist/api/openai/handlers/images.js +4 -0
- package/dist/api/openai/handlers/images.js.map +1 -0
- package/dist/api/openai/handlers/models.d.ts +3 -0
- package/dist/api/openai/handlers/models.js +23 -0
- package/dist/api/openai/handlers/models.js.map +1 -0
- package/dist/api/openai/handlers/transcription.d.ts +0 -0
- package/dist/api/openai/handlers/transcription.js +4 -0
- package/dist/api/openai/handlers/transcription.js.map +1 -0
- package/dist/api/openai/index.d.ts +7 -0
- package/dist/api/openai/index.js +14 -0
- package/dist/api/openai/index.js.map +1 -0
- package/dist/api/parseJSONRequestBody.d.ts +2 -0
- package/dist/api/parseJSONRequestBody.js +24 -0
- package/dist/api/parseJSONRequestBody.js.map +1 -0
- package/dist/api/v1/index.d.ts +2 -0
- package/dist/api/v1/index.js +29 -0
- package/dist/api/v1/index.js.map +1 -0
- package/dist/cli.d.ts +1 -0
- package/dist/cli.js +10 -0
- package/dist/cli.js.map +1 -0
- package/dist/engines/gpt4all/engine.d.ts +34 -0
- package/dist/engines/gpt4all/engine.js +357 -0
- package/dist/engines/gpt4all/engine.js.map +1 -0
- package/dist/engines/gpt4all/util.d.ts +3 -0
- package/dist/engines/gpt4all/util.js +29 -0
- package/dist/engines/gpt4all/util.js.map +1 -0
- package/dist/engines/index.d.ts +19 -0
- package/dist/engines/index.js +21 -0
- package/dist/engines/index.js.map +1 -0
- package/dist/engines/node-llama-cpp/engine.d.ts +49 -0
- package/dist/engines/node-llama-cpp/engine.js +666 -0
- package/dist/engines/node-llama-cpp/engine.js.map +1 -0
- package/dist/engines/node-llama-cpp/types.d.ts +13 -0
- package/dist/engines/node-llama-cpp/types.js +2 -0
- package/dist/engines/node-llama-cpp/types.js.map +1 -0
- package/dist/engines/node-llama-cpp/util.d.ts +15 -0
- package/dist/engines/node-llama-cpp/util.js +84 -0
- package/dist/engines/node-llama-cpp/util.js.map +1 -0
- package/dist/engines/node-llama-cpp/validateModelFile.d.ts +8 -0
- package/dist/engines/node-llama-cpp/validateModelFile.js +36 -0
- package/dist/engines/node-llama-cpp/validateModelFile.js.map +1 -0
- package/dist/engines/stable-diffusion-cpp/engine.d.ts +90 -0
- package/dist/engines/stable-diffusion-cpp/engine.js +294 -0
- package/dist/engines/stable-diffusion-cpp/engine.js.map +1 -0
- package/dist/engines/stable-diffusion-cpp/types.d.ts +3 -0
- package/dist/engines/stable-diffusion-cpp/types.js +2 -0
- package/dist/engines/stable-diffusion-cpp/types.js.map +1 -0
- package/dist/engines/stable-diffusion-cpp/util.d.ts +4 -0
- package/dist/engines/stable-diffusion-cpp/util.js +55 -0
- package/dist/engines/stable-diffusion-cpp/util.js.map +1 -0
- package/dist/engines/stable-diffusion-cpp/validateModelFiles.d.ts +19 -0
- package/dist/engines/stable-diffusion-cpp/validateModelFiles.js +91 -0
- package/dist/engines/stable-diffusion-cpp/validateModelFiles.js.map +1 -0
- package/dist/engines/transformers-js/engine.d.ts +37 -0
- package/dist/engines/transformers-js/engine.js +538 -0
- package/dist/engines/transformers-js/engine.js.map +1 -0
- package/dist/engines/transformers-js/types.d.ts +7 -0
- package/dist/engines/transformers-js/types.js +2 -0
- package/dist/engines/transformers-js/types.js.map +1 -0
- package/dist/engines/transformers-js/util.d.ts +7 -0
- package/dist/engines/transformers-js/util.js +36 -0
- package/dist/engines/transformers-js/util.js.map +1 -0
- package/dist/engines/transformers-js/validateModelFiles.d.ts +17 -0
- package/dist/engines/transformers-js/validateModelFiles.js +133 -0
- package/dist/engines/transformers-js/validateModelFiles.js.map +1 -0
- package/dist/experiments/ChatWithVision.d.ts +11 -0
- package/dist/experiments/ChatWithVision.js +91 -0
- package/dist/experiments/ChatWithVision.js.map +1 -0
- package/dist/experiments/StableDiffPromptGenerator.d.ts +0 -0
- package/dist/experiments/StableDiffPromptGenerator.js +4 -0
- package/dist/experiments/StableDiffPromptGenerator.js.map +1 -0
- package/dist/experiments/VoiceFunctionCall.d.ts +18 -0
- package/dist/experiments/VoiceFunctionCall.js +51 -0
- package/dist/experiments/VoiceFunctionCall.js.map +1 -0
- package/dist/http.d.ts +19 -0
- package/dist/http.js +54 -0
- package/dist/http.js.map +1 -0
- package/dist/index.d.ts +7 -0
- package/dist/index.js +8 -0
- package/dist/index.js.map +1 -0
- package/dist/instance.d.ts +88 -0
- package/dist/instance.js +594 -0
- package/dist/instance.js.map +1 -0
- package/dist/lib/acquireFileLock.d.ts +7 -0
- package/dist/lib/acquireFileLock.js +38 -0
- package/dist/lib/acquireFileLock.js.map +1 -0
- package/dist/lib/calculateContextIdentity.d.ts +7 -0
- package/dist/lib/calculateContextIdentity.js +39 -0
- package/dist/lib/calculateContextIdentity.js.map +1 -0
- package/dist/lib/calculateFileChecksum.d.ts +1 -0
- package/dist/lib/calculateFileChecksum.js +16 -0
- package/dist/lib/calculateFileChecksum.js.map +1 -0
- package/dist/lib/copyDirectory.d.ts +6 -0
- package/dist/lib/copyDirectory.js +27 -0
- package/dist/lib/copyDirectory.js.map +1 -0
- package/dist/lib/decodeAudio.d.ts +1 -0
- package/dist/lib/decodeAudio.js +26 -0
- package/dist/lib/decodeAudio.js.map +1 -0
- package/dist/lib/downloadModelFile.d.ts +10 -0
- package/dist/lib/downloadModelFile.js +58 -0
- package/dist/lib/downloadModelFile.js.map +1 -0
- package/dist/lib/flattenMessageTextContent.d.ts +2 -0
- package/dist/lib/flattenMessageTextContent.js +11 -0
- package/dist/lib/flattenMessageTextContent.js.map +1 -0
- package/dist/lib/getCacheDirPath.d.ts +12 -0
- package/dist/lib/getCacheDirPath.js +31 -0
- package/dist/lib/getCacheDirPath.js.map +1 -0
- package/dist/lib/loadImage.d.ts +12 -0
- package/dist/lib/loadImage.js +30 -0
- package/dist/lib/loadImage.js.map +1 -0
- package/dist/lib/logger.d.ts +12 -0
- package/dist/lib/logger.js +98 -0
- package/dist/lib/logger.js.map +1 -0
- package/dist/lib/math.d.ts +7 -0
- package/dist/lib/math.js +30 -0
- package/dist/lib/math.js.map +1 -0
- package/dist/lib/resolveModelFileLocation.d.ts +15 -0
- package/dist/lib/resolveModelFileLocation.js +41 -0
- package/dist/lib/resolveModelFileLocation.js.map +1 -0
- package/dist/lib/util.d.ts +7 -0
- package/dist/lib/util.js +61 -0
- package/dist/lib/util.js.map +1 -0
- package/dist/lib/validateModelFile.d.ts +9 -0
- package/dist/lib/validateModelFile.js +62 -0
- package/dist/lib/validateModelFile.js.map +1 -0
- package/dist/lib/validateModelOptions.d.ts +3 -0
- package/dist/lib/validateModelOptions.js +23 -0
- package/dist/lib/validateModelOptions.js.map +1 -0
- package/dist/pool.d.ts +61 -0
- package/dist/pool.js +512 -0
- package/dist/pool.js.map +1 -0
- package/dist/server.d.ts +59 -0
- package/dist/server.js +221 -0
- package/dist/server.js.map +1 -0
- package/dist/standalone.d.ts +1 -0
- package/dist/standalone.js +306 -0
- package/dist/standalone.js.map +1 -0
- package/dist/store.d.ts +60 -0
- package/dist/store.js +203 -0
- package/dist/store.js.map +1 -0
- package/dist/types/completions.d.ts +57 -0
- package/dist/types/completions.js +2 -0
- package/dist/types/completions.js.map +1 -0
- package/dist/types/index.d.ts +326 -0
- package/dist/types/index.js +2 -0
- package/dist/types/index.js.map +1 -0
- package/docs/engines.md +28 -0
- package/docs/gpu.md +72 -0
- package/docs/http-api.md +147 -0
- package/examples/all-options.js +108 -0
- package/examples/chat-cli.js +56 -0
- package/examples/chat-server.js +65 -0
- package/examples/concurrency.js +70 -0
- package/examples/express.js +70 -0
- package/examples/pool.js +91 -0
- package/package.json +113 -0
- package/src/api/openai/enums.ts +20 -0
- package/src/api/openai/handlers/chat.ts +408 -0
- package/src/api/openai/handlers/completions.ts +196 -0
- package/src/api/openai/handlers/embeddings.ts +92 -0
- package/src/api/openai/handlers/images.ts +3 -0
- package/src/api/openai/handlers/models.ts +33 -0
- package/src/api/openai/handlers/transcription.ts +2 -0
- package/src/api/openai/index.ts +16 -0
- package/src/api/parseJSONRequestBody.ts +26 -0
- package/src/api/v1/DRAFT.md +16 -0
- package/src/api/v1/index.ts +37 -0
- package/src/cli.ts +9 -0
- package/src/engines/gpt4all/engine.ts +441 -0
- package/src/engines/gpt4all/util.ts +31 -0
- package/src/engines/index.ts +28 -0
- package/src/engines/node-llama-cpp/engine.ts +811 -0
- package/src/engines/node-llama-cpp/types.ts +17 -0
- package/src/engines/node-llama-cpp/util.ts +126 -0
- package/src/engines/node-llama-cpp/validateModelFile.ts +46 -0
- package/src/engines/stable-diffusion-cpp/engine.ts +369 -0
- package/src/engines/stable-diffusion-cpp/types.ts +54 -0
- package/src/engines/stable-diffusion-cpp/util.ts +58 -0
- package/src/engines/stable-diffusion-cpp/validateModelFiles.ts +119 -0
- package/src/engines/transformers-js/engine.ts +659 -0
- package/src/engines/transformers-js/types.ts +25 -0
- package/src/engines/transformers-js/util.ts +40 -0
- package/src/engines/transformers-js/validateModelFiles.ts +168 -0
- package/src/experiments/ChatWithVision.ts +103 -0
- package/src/experiments/StableDiffPromptGenerator.ts +2 -0
- package/src/experiments/VoiceFunctionCall.ts +71 -0
- package/src/http.ts +72 -0
- package/src/index.ts +7 -0
- package/src/instance.ts +723 -0
- package/src/lib/acquireFileLock.ts +38 -0
- package/src/lib/calculateContextIdentity.ts +53 -0
- package/src/lib/calculateFileChecksum.ts +18 -0
- package/src/lib/copyDirectory.ts +29 -0
- package/src/lib/decodeAudio.ts +39 -0
- package/src/lib/downloadModelFile.ts +70 -0
- package/src/lib/flattenMessageTextContent.ts +19 -0
- package/src/lib/getCacheDirPath.ts +34 -0
- package/src/lib/loadImage.ts +46 -0
- package/src/lib/logger.ts +112 -0
- package/src/lib/math.ts +31 -0
- package/src/lib/resolveModelFileLocation.ts +49 -0
- package/src/lib/util.ts +75 -0
- package/src/lib/validateModelFile.ts +71 -0
- package/src/lib/validateModelOptions.ts +31 -0
- package/src/pool.ts +651 -0
- package/src/server.ts +270 -0
- package/src/standalone.ts +320 -0
- package/src/store.ts +278 -0
- package/src/types/completions.ts +86 -0
- package/src/types/index.ts +488 -0
- package/tsconfig.json +29 -0
- package/tsconfig.release.json +11 -0
- package/vitest.config.ts +18 -0
|
@@ -0,0 +1,811 @@
|
|
|
1
|
+
import path from 'node:path'
|
|
2
|
+
import fs from 'node:fs'
|
|
3
|
+
import { nanoid } from 'nanoid'
|
|
4
|
+
import {
|
|
5
|
+
getLlama,
|
|
6
|
+
LlamaOptions,
|
|
7
|
+
LlamaChat,
|
|
8
|
+
LlamaModel,
|
|
9
|
+
LlamaContext,
|
|
10
|
+
LlamaCompletion,
|
|
11
|
+
LlamaLogLevel,
|
|
12
|
+
LlamaChatResponseFunctionCall,
|
|
13
|
+
TokenBias,
|
|
14
|
+
Token,
|
|
15
|
+
LlamaContextSequence,
|
|
16
|
+
LlamaGrammar,
|
|
17
|
+
ChatHistoryItem,
|
|
18
|
+
LlamaChatResponse,
|
|
19
|
+
ChatModelResponse,
|
|
20
|
+
LlamaEmbeddingContext,
|
|
21
|
+
defineChatSessionFunction,
|
|
22
|
+
GbnfJsonSchema,
|
|
23
|
+
ChatSessionModelFunction,
|
|
24
|
+
createModelDownloader,
|
|
25
|
+
readGgufFileInfo,
|
|
26
|
+
GgufFileInfo,
|
|
27
|
+
LlamaJsonSchemaGrammar,
|
|
28
|
+
LLamaChatContextShiftOptions,
|
|
29
|
+
LlamaContextOptions,
|
|
30
|
+
} from 'node-llama-cpp'
|
|
31
|
+
import { StopGenerationTrigger } from 'node-llama-cpp/dist/utils/StopGenerationDetector'
|
|
32
|
+
import {
|
|
33
|
+
EngineChatCompletionResult,
|
|
34
|
+
EngineTextCompletionResult,
|
|
35
|
+
EngineTextCompletionArgs,
|
|
36
|
+
EngineChatCompletionArgs,
|
|
37
|
+
EngineContext,
|
|
38
|
+
ToolDefinition,
|
|
39
|
+
ToolCallResultMessage,
|
|
40
|
+
AssistantMessage,
|
|
41
|
+
EngineEmbeddingArgs,
|
|
42
|
+
EngineEmbeddingResult,
|
|
43
|
+
FileDownloadProgress,
|
|
44
|
+
ModelConfig,
|
|
45
|
+
TextCompletionParams,
|
|
46
|
+
TextCompletionGrammar,
|
|
47
|
+
ChatMessage,
|
|
48
|
+
} from '#package/types/index.js'
|
|
49
|
+
import { LogLevels } from '#package/lib/logger.js'
|
|
50
|
+
import { flattenMessageTextContent } from '#package/lib/flattenMessageTextContent.js'
|
|
51
|
+
import { acquireFileLock } from '#package/lib/acquireFileLock.js'
|
|
52
|
+
import { getRandomNumber } from '#package/lib/util.js'
|
|
53
|
+
import { validateModelFile } from '#package/lib/validateModelFile.js'
|
|
54
|
+
import { createChatMessageArray, addFunctionCallToChatHistory, mapFinishReason } from './util.js'
|
|
55
|
+
import { LlamaChatResult } from './types.js'
|
|
56
|
+
|
|
57
|
+
/**
 * Live, per-model state kept by the node-llama-cpp engine between requests.
 */
export interface NodeLlamaCppInstance {
	/** The loaded model. */
	model: LlamaModel
	/** Context created from `model` (`createInstance` requests a single sequence). */
	context: LlamaContext
	/** Sequence of `context` that chat and text completion generate on. */
	contextSequence: LlamaContextSequence

	/** Compiled grammars by name; `createInstance` always registers `json`. */
	grammars: Record<string, LlamaGrammar>

	/** Chat session bound to `contextSequence`; recreated when missing or when a context reset is requested. */
	chat?: LlamaChat
	/** Chat history as last ingested into `chat`. */
	chatHistory: ChatHistoryItem[]
	/** Evaluation state from the previous chat generation, reused for the next one. */
	lastEvaluation?: LlamaChatResponse['lastEvaluation']
	/** Function calls emitted by the model that still await a tool result, keyed by call id. */
	pendingFunctionCalls: Record<string, any>

	/** Text completion helper; set by `createInstance` when a `prefix` is configured. */
	completion?: LlamaCompletion
	/** Embedding context, if one has been created. */
	embeddingContext?: LlamaEmbeddingContext
}
|
|
69
|
+
|
|
70
|
+
/**
 * Metadata produced when a node-llama-cpp model file is prepared.
 */
export interface NodeLlamaCppModelMeta {
	/** GGUF file info (prepareModel reads it with the large tokenizer arrays skipped). */
	gguf: GgufFileInfo
}
|
|
73
|
+
|
|
74
|
+
/**
 * Model configuration understood by the node-llama-cpp engine.
 */
export interface NodeLlamaCppModelConfig extends ModelConfig {
	/** Full absolute path of the GGUF model file. */
	location: string
	/** Optional SHA-256 checksum of the model file (presumably verified by `validateModelFile` — confirm). */
	sha256?: string

	/** Context window size passed to `createContext`. */
	contextSize?: number
	/** Batch size passed to `createContext`. */
	batchSize?: number
	/** LoRA adapters passed to `createContext`. */
	lora?: LlamaContextOptions['lora']
	/** Context shift strategy used when a chat exceeds the context window. */
	contextShiftStrategy?: LLamaChatContextShiftOptions['strategy']

	/** Hardware placement options. */
	device?: {
		/** GPU backend: 'auto', a backend name such as 'metal' / 'cuda' / 'vulkan', or a boolean. */
		gpu?: boolean | 'auto' | (string & {})
		/** Number of model layers to offload to the GPU. */
		gpuLayers?: number
		/** CPU threads used by the context. */
		cpuThreads?: number
		/** Lock model memory (passed as `useMlock`). */
		memLock?: boolean
	}

	/** Default sampling/completion parameters, overridable per request. */
	completionDefaults?: TextCompletionParams
	/** Named grammars: a string is treated as GBNF, anything else as a JSON schema. */
	grammars?: Record<string, TextCompletionGrammar>

	/** Chat messages evaluated into the context when the instance is created. */
	initialMessages?: ChatMessage[]
	/** Raw text evaluated into the context sequence when the instance is created. */
	prefix?: string

	/** Tool (function calling) configuration. */
	tools?: {
		/** Tool definitions keyed by function name. */
		definitions: Record<string, ToolDefinition>
		/** Whether function parameters are documented in the prompt. */
		includeToolDocumentation?: boolean
		/** Maximum number of parallel function calls per response. */
		parallelism?: number
	}
}
|
|
97
|
+
|
|
98
|
+
/**
 * Engine capability flag.
 * NOTE(review): presumably tells engine consumers that this engine picks a GPU backend on its own
 * (createInstance defaults `device.gpu` to 'auto') — confirm against where `autoGpu` is read.
 */
export const autoGpu = true
|
|
99
|
+
|
|
100
|
+
/**
 * Makes sure the GGUF file at `config.location` exists and is valid, then reads its GGUF metadata.
 *
 * If validation fails and `config.url` is set, the file is downloaded and validated again.
 * A file lock on `config.location` is held for the whole operation, so concurrent
 * preparations of the same file do not download or validate it at the same time.
 *
 * @param ctx - Engine context carrying the model config and logger.
 * @param onProgress - Optional callback that receives download progress updates.
 * @param signal - Optional abort signal. It is forwarded to the download and to GGUF parsing.
 *   When it is aborted the function resolves to `undefined` instead of throwing.
 * @returns `{ gguf }` with the parsed GGUF info, or `undefined` when aborted.
 * @throws When the file is invalid and no `url` is configured, or when it is still invalid
 *   after downloading.
 */
export async function prepareModel(
	{ config, log }: EngineContext<NodeLlamaCppModelConfig>,
	onProgress?: (progress: FileDownloadProgress) => void,
	signal?: AbortSignal,
) {
	fs.mkdirSync(path.dirname(config.location), { recursive: true })
	const releaseFileLock = await acquireFileLock(config.location, signal)

	const downloadModel = async (url: string, validationResult: string) => {
		log(LogLevels.info, `Downloading model files`, {
			model: config.id,
			url: url,
			location: config.location,
			error: validationResult,
		})

		const downloader = await createModelDownloader({
			modelUrl: url,
			dirPath: path.dirname(config.location),
			fileName: path.basename(config.location),
			// keep the partially downloaded temp file when the download is cancelled
			deleteTempFileOnCancel: false,
			onProgress: (status) => {
				onProgress?.({
					file: config.location,
					loadedBytes: status.downloadedSize,
					totalBytes: status.totalSize,
				})
			},
		})
		try {
			// forward the signal so aborting actually stops the transfer
			await downloader.download({ signal })
		} catch (err) {
			// an abort surfaces as a rejection; the caller checks `signal.aborted` and returns quietly
			if (!signal?.aborted) {
				throw err
			}
		}
	}

	try {
		if (signal?.aborted) {
			return
		}
		log(LogLevels.info, `Preparing node-llama-cpp model at ${config.location}`, {
			model: config.id,
		})

		const validationError = await validateModelFile(config)
		if (signal?.aborted) {
			return
		}

		if (validationError) {
			if (!config.url) {
				throw new Error(`${validationError} - No URL provided`)
			}
			await downloadModel(config.url, validationError)
			if (signal?.aborted) {
				return
			}
			// only re-validate after a download; a file that already passed does not need a second (possibly checksummed) pass
			const finalValidationError = await validateModelFile(config)
			if (finalValidationError) {
				throw new Error(`Downloaded files are invalid: ${finalValidationError}`)
			}
		}

		const gguf = await readGgufFileInfo(config.location, {
			signal,
			// skip the large tokenizer arrays, they are not needed as model metadata
			ignoreKeys: [
				'gguf.tokenizer.ggml.merges',
				'gguf.tokenizer.ggml.tokens',
				'gguf.tokenizer.ggml.scores',
				'gguf.tokenizer.ggml.token_type',
			],
		})
		return {
			gguf,
		}
	} finally {
		releaseFileLock()
	}
}
|
|
179
|
+
|
|
180
|
+
/**
 * Loads a node-llama-cpp model and builds an inference instance around it.
 *
 * The steps are:
 * 1. Initialise the llama backend, forwarding its log output.
 * 2. Compile the built-in `json` grammar plus any configured grammars.
 * 3. Load the model and create a single-sequence context.
 * 4. Optionally pre-evaluate `initialMessages` (as chat history) and/or `prefix`
 *    (as raw text) on that sequence.
 *
 * @param ctx - Engine context carrying the model config and logger.
 * @param signal - Optional abort signal, forwarded to model loading and context creation.
 * @returns The ready-to-use instance.
 * @throws Whatever node-llama-cpp throws while loading. If a failure happens after the
 *   model was loaded, the model is disposed before the error is rethrown.
 */
export async function createInstance({ config, log }: EngineContext<NodeLlamaCppModelConfig>, signal?: AbortSignal) {
	log(LogLevels.debug, 'Load Llama model', config.device)
	// getLlama takes "auto" | "metal" | "cuda" | "vulkan" | false. The config type also
	// allows `true`, which getLlama does not understand, so treat it as "auto".
	const requestedGpu = config.device?.gpu ?? 'auto'
	const gpuSetting = (requestedGpu === true ? 'auto' : requestedGpu) as LlamaOptions['gpu']
	const llama = await getLlama({
		gpu: gpuSetting,
		// forwarding llama logger
		logLevel: LlamaLogLevel.debug,
		logger: (level, message) => {
			if (level === LlamaLogLevel.warn) {
				log(LogLevels.warn, message)
			} else if (level === LlamaLogLevel.error || level === LlamaLogLevel.fatal) {
				log(LogLevels.error, message)
			} else if (level === LlamaLogLevel.info || level === LlamaLogLevel.debug) {
				log(LogLevels.verbose, message)
			}
		},
	})

	const llamaGrammars: Record<string, LlamaGrammar> = {
		json: await LlamaGrammar.getFor(llama, 'json'),
	}

	if (config.grammars) {
		for (const key in config.grammars) {
			const input = config.grammars[key]
			if (typeof input === 'string') {
				llamaGrammars[key] = new LlamaGrammar(llama, {
					grammar: input,
				})
			} else {
				// assume input is a JSON schema object
				llamaGrammars[key] = new LlamaJsonSchemaGrammar(llama, input as GbnfJsonSchema)
			}
		}
	}

	const llamaModel = await llama.loadModel({
		modelPath: config.location, // full model absolute path
		loadSignal: signal,
		useMlock: config.device?.memLock ?? false,
		gpuLayers: config.device?.gpuLayers,
		// onLoadProgress: (percent) => {}
	})

	try {
		const context = await llamaModel.createContext({
			sequences: 1,
			lora: config.lora,
			threads: config.device?.cpuThreads,
			batchSize: config.batchSize,
			contextSize: config.contextSize,
			flashAttention: true,
			createSignal: signal,
		})

		const instance: NodeLlamaCppInstance = {
			model: llamaModel,
			context,
			grammars: llamaGrammars,
			chat: undefined,
			chatHistory: [],
			pendingFunctionCalls: {},
			lastEvaluation: undefined,
			completion: undefined,
			contextSequence: context.getSequence(),
		}

		if (config.initialMessages) {
			const initialChatHistory = createChatMessageArray(config.initialMessages)
			const chat = new LlamaChat({
				contextSequence: instance.contextSequence,
				// autoDisposeSequence: true,
			})

			let inputFunctions: Record<string, ChatSessionModelFunction> | undefined

			if (config.tools?.definitions && Object.keys(config.tools.definitions).length > 0) {
				const functionDefs = config.tools.definitions
				inputFunctions = {}
				for (const functionName in functionDefs) {
					const functionDef = functionDefs[functionName]
					inputFunctions[functionName] = defineChatSessionFunction<any>({
						description: functionDef.description,
						params: functionDef.parameters,
						handler: functionDef.handler || (() => {}),
					}) as ChatSessionModelFunction
				}
			}

			const loadMessagesRes = await chat.loadChatAndCompleteUserMessage(initialChatHistory, {
				initialUserPrompt: '',
				functions: inputFunctions,
				// same default as processChatCompletionTask, so the pre-evaluated prompt matches later requests
				documentFunctionParams: config.tools?.includeToolDocumentation ?? true,
			})

			instance.chat = chat
			instance.chatHistory = initialChatHistory
			instance.lastEvaluation = {
				cleanHistory: initialChatHistory,
				contextWindow: loadMessagesRes.lastEvaluation.contextWindow,
				contextShiftMetadata: loadMessagesRes.lastEvaluation.contextShiftMetadata,
			}
		}

		if (config.prefix) {
			const completion = new LlamaCompletion({
				contextSequence: instance.contextSequence,
			})
			// maxTokens: 0 only evaluates the prefix into the sequence without generating
			await completion.generateCompletion(config.prefix, {
				maxTokens: 0,
			})
			instance.completion = completion
		}

		return instance
	} catch (err) {
		// do not leak the loaded model when context creation or pre-evaluation fails (or is aborted)
		try {
			await llamaModel.dispose()
		} catch (disposeErr) {
			log(LogLevels.warn, 'Failed to dispose model after instance setup error', { error: disposeErr })
		}
		throw err
	}
}
|
|
298
|
+
|
|
299
|
+
/**
 * Releases the native resources held by an instance by disposing its model.
 *
 * @param instance - The instance to tear down.
 */
export async function disposeInstance(instance: NodeLlamaCppInstance) {
	const { model } = instance
	await model.dispose()
}
|
|
302
|
+
|
|
303
|
+
export async function processChatCompletionTask(
|
|
304
|
+
{ request, config, resetContext, log, onChunk }: EngineChatCompletionArgs<NodeLlamaCppModelConfig>,
|
|
305
|
+
instance: NodeLlamaCppInstance,
|
|
306
|
+
signal?: AbortSignal,
|
|
307
|
+
): Promise<EngineChatCompletionResult> {
|
|
308
|
+
if (!instance.chat || resetContext) {
|
|
309
|
+
log(LogLevels.debug, 'Recreating chat context', {
|
|
310
|
+
resetContext,
|
|
311
|
+
willDisposeChat: !!instance.chat,
|
|
312
|
+
})
|
|
313
|
+
// if context reset is requested, dispose the chat instance
|
|
314
|
+
if (instance.chat) {
|
|
315
|
+
await instance.chat.dispose()
|
|
316
|
+
}
|
|
317
|
+
let contextSequence = instance.contextSequence
|
|
318
|
+
if (!contextSequence || contextSequence.disposed) {
|
|
319
|
+
if (instance.context.sequencesLeft) {
|
|
320
|
+
contextSequence = instance.context.getSequence()
|
|
321
|
+
instance.contextSequence = contextSequence
|
|
322
|
+
} else {
|
|
323
|
+
throw new Error('No context sequence available')
|
|
324
|
+
}
|
|
325
|
+
} else {
|
|
326
|
+
contextSequence.clearHistory()
|
|
327
|
+
}
|
|
328
|
+
instance.chat = new LlamaChat({
|
|
329
|
+
contextSequence: contextSequence,
|
|
330
|
+
// autoDisposeSequence: true,
|
|
331
|
+
})
|
|
332
|
+
// reset state and reingest the conversation history
|
|
333
|
+
instance.lastEvaluation = undefined
|
|
334
|
+
instance.pendingFunctionCalls = {}
|
|
335
|
+
instance.chatHistory = createChatMessageArray(request.messages)
|
|
336
|
+
// drop last user message. its gonna be added later, after resolved function calls
|
|
337
|
+
if (instance.chatHistory[instance.chatHistory.length - 1].type === 'user') {
|
|
338
|
+
instance.chatHistory.pop()
|
|
339
|
+
}
|
|
340
|
+
}
|
|
341
|
+
|
|
342
|
+
// set additional stop generation triggers for this completion
|
|
343
|
+
const customStopTriggers: StopGenerationTrigger[] = []
|
|
344
|
+
const stopTrigger = request.stop ?? config.completionDefaults?.stop
|
|
345
|
+
if (stopTrigger) {
|
|
346
|
+
customStopTriggers.push(...stopTrigger.map((t) => [t]))
|
|
347
|
+
}
|
|
348
|
+
// setting up logit/token bias dictionary
|
|
349
|
+
let tokenBias: TokenBias | undefined
|
|
350
|
+
const completionTokenBias = request.tokenBias ?? config.completionDefaults?.tokenBias
|
|
351
|
+
if (completionTokenBias) {
|
|
352
|
+
tokenBias = new TokenBias(instance.model.tokenizer)
|
|
353
|
+
for (const key in completionTokenBias) {
|
|
354
|
+
const bias = completionTokenBias[key] / 10
|
|
355
|
+
const tokenId = parseInt(key) as Token
|
|
356
|
+
if (!isNaN(tokenId)) {
|
|
357
|
+
tokenBias.set(tokenId, bias)
|
|
358
|
+
} else {
|
|
359
|
+
tokenBias.set(key, bias)
|
|
360
|
+
}
|
|
361
|
+
}
|
|
362
|
+
}
|
|
363
|
+
|
|
364
|
+
// setting up available function definitions
|
|
365
|
+
const functionDefinitions: Record<string, ToolDefinition> = {
|
|
366
|
+
...config.tools?.definitions,
|
|
367
|
+
...request.tools,
|
|
368
|
+
}
|
|
369
|
+
|
|
370
|
+
// see if the user submitted any function call results
|
|
371
|
+
const supportsParallelFunctionCalling =
|
|
372
|
+
instance.chat.chatWrapper.settings.functions.parallelism != null && !!config.tools?.parallelism
|
|
373
|
+
const resolvedFunctionCalls = []
|
|
374
|
+
const functionCallResultMessages = request.messages.filter((m) => m.role === 'tool') as ToolCallResultMessage[]
|
|
375
|
+
for (const message of functionCallResultMessages) {
|
|
376
|
+
if (!instance.pendingFunctionCalls[message.callId]) {
|
|
377
|
+
log(LogLevels.warn, `Received function result for non-existing call id "${message.callId}`)
|
|
378
|
+
continue
|
|
379
|
+
}
|
|
380
|
+
log(LogLevels.debug, 'Resolving pending function call', {
|
|
381
|
+
id: message.callId,
|
|
382
|
+
result: message.content,
|
|
383
|
+
})
|
|
384
|
+
const functionCall = instance.pendingFunctionCalls[message.callId]
|
|
385
|
+
const functionDef = functionDefinitions[functionCall.functionName]
|
|
386
|
+
resolvedFunctionCalls.push({
|
|
387
|
+
name: functionCall.functionName,
|
|
388
|
+
description: functionDef?.description,
|
|
389
|
+
params: functionCall.params,
|
|
390
|
+
result: message.content,
|
|
391
|
+
rawCall: functionCall.raw,
|
|
392
|
+
startsNewChunk: supportsParallelFunctionCalling,
|
|
393
|
+
})
|
|
394
|
+
delete instance.pendingFunctionCalls[message.callId]
|
|
395
|
+
}
|
|
396
|
+
// if we resolved any results, add them to history
|
|
397
|
+
if (resolvedFunctionCalls.length) {
|
|
398
|
+
instance.chatHistory.push({
|
|
399
|
+
type: 'model',
|
|
400
|
+
response: resolvedFunctionCalls.map((call) => {
|
|
401
|
+
return {
|
|
402
|
+
type: 'functionCall',
|
|
403
|
+
...call,
|
|
404
|
+
}
|
|
405
|
+
}),
|
|
406
|
+
})
|
|
407
|
+
}
|
|
408
|
+
|
|
409
|
+
// add the new user message to the chat history
|
|
410
|
+
let assistantPrefill: string = ''
|
|
411
|
+
const lastMessage = request.messages[request.messages.length - 1]
|
|
412
|
+
if (lastMessage.role === 'user' && lastMessage.content) {
|
|
413
|
+
const newUserText = flattenMessageTextContent(lastMessage.content)
|
|
414
|
+
if (newUserText) {
|
|
415
|
+
instance.chatHistory.push({
|
|
416
|
+
type: 'user',
|
|
417
|
+
text: newUserText,
|
|
418
|
+
})
|
|
419
|
+
}
|
|
420
|
+
} else if (lastMessage.role === 'assistant') {
|
|
421
|
+
// use last message as prefill for response, if its an assistant message
|
|
422
|
+
assistantPrefill = flattenMessageTextContent(lastMessage.content)
|
|
423
|
+
} else if (!resolvedFunctionCalls.length) {
|
|
424
|
+
log(LogLevels.warn, 'Last message is not valid for chat completion. This is likely a mistake.', lastMessage)
|
|
425
|
+
throw new Error('Invalid last chat message')
|
|
426
|
+
}
|
|
427
|
+
|
|
428
|
+
// only grammar or functions can be used, not both.
|
|
429
|
+
// currently ignoring function definitions if grammar is provided
|
|
430
|
+
|
|
431
|
+
let inputGrammar: LlamaGrammar | undefined
|
|
432
|
+
let inputFunctions: Record<string, ChatSessionModelFunction> | undefined
|
|
433
|
+
|
|
434
|
+
if (request.grammar) {
|
|
435
|
+
if (!instance.grammars[request.grammar]) {
|
|
436
|
+
throw new Error(`Grammar "${request.grammar}" not found.`)
|
|
437
|
+
}
|
|
438
|
+
inputGrammar = instance.grammars[request.grammar]
|
|
439
|
+
} else if (Object.keys(functionDefinitions).length > 0) {
|
|
440
|
+
inputFunctions = {}
|
|
441
|
+
for (const functionName in functionDefinitions) {
|
|
442
|
+
const functionDef = functionDefinitions[functionName]
|
|
443
|
+
inputFunctions[functionName] = defineChatSessionFunction<any>({
|
|
444
|
+
description: functionDef.description,
|
|
445
|
+
params: functionDef.parameters,
|
|
446
|
+
handler: functionDef.handler || (() => {}),
|
|
447
|
+
})
|
|
448
|
+
}
|
|
449
|
+
}
|
|
450
|
+
const defaults = config.completionDefaults ?? {}
|
|
451
|
+
let lastEvaluation: LlamaChatResponse['lastEvaluation'] | undefined = instance.lastEvaluation
|
|
452
|
+
let newChatHistory = instance.chatHistory.slice()
|
|
453
|
+
let newContextWindowChatHistory = !lastEvaluation?.contextWindow ? undefined : instance.chatHistory.slice()
|
|
454
|
+
|
|
455
|
+
if (instance.chatHistory[instance.chatHistory.length - 1].type !== 'model' || assistantPrefill) {
|
|
456
|
+
const newModelResponse = assistantPrefill ? [assistantPrefill] : []
|
|
457
|
+
newChatHistory.push({
|
|
458
|
+
type: 'model',
|
|
459
|
+
response: newModelResponse,
|
|
460
|
+
})
|
|
461
|
+
if (newContextWindowChatHistory) {
|
|
462
|
+
newContextWindowChatHistory.push({
|
|
463
|
+
type: 'model',
|
|
464
|
+
response: newModelResponse,
|
|
465
|
+
})
|
|
466
|
+
}
|
|
467
|
+
}
|
|
468
|
+
|
|
469
|
+
const functionsOrGrammar = inputFunctions
|
|
470
|
+
? {
|
|
471
|
+
functions: inputFunctions,
|
|
472
|
+
documentFunctionParams: config.tools?.includeToolDocumentation ?? true,
|
|
473
|
+
maxParallelFunctionCalls: config.tools?.parallelism ?? 1,
|
|
474
|
+
onFunctionCall: (functionCall: LlamaChatResponseFunctionCall<any>) => {
|
|
475
|
+
// log(LogLevels.debug, 'Called function', functionCall)
|
|
476
|
+
},
|
|
477
|
+
}
|
|
478
|
+
: {
|
|
479
|
+
grammar: inputGrammar,
|
|
480
|
+
}
|
|
481
|
+
|
|
482
|
+
const initialTokenMeterState = instance.chat.sequence.tokenMeter.getState()
|
|
483
|
+
let completionResult: LlamaChatResult
|
|
484
|
+
while (true) {
|
|
485
|
+
const {
|
|
486
|
+
functionCalls,
|
|
487
|
+
lastEvaluation: currentLastEvaluation,
|
|
488
|
+
metadata,
|
|
489
|
+
} = await instance.chat.generateResponse(newChatHistory, {
|
|
490
|
+
signal,
|
|
491
|
+
stopOnAbortSignal: true, // this will make aborted completions resolve (with a partial response)
|
|
492
|
+
maxTokens: request.maxTokens ?? defaults.maxTokens,
|
|
493
|
+
temperature: request.temperature ?? defaults.temperature,
|
|
494
|
+
topP: request.topP ?? defaults.topP,
|
|
495
|
+
topK: request.topK ?? defaults.topK,
|
|
496
|
+
minP: request.minP ?? defaults.minP,
|
|
497
|
+
seed: request.seed ?? config.completionDefaults?.seed ?? getRandomNumber(0, 1000000),
|
|
498
|
+
tokenBias,
|
|
499
|
+
customStopTriggers,
|
|
500
|
+
trimWhitespaceSuffix: false,
|
|
501
|
+
...functionsOrGrammar,
|
|
502
|
+
repeatPenalty: {
|
|
503
|
+
lastTokens: request.repeatPenaltyNum ?? defaults.repeatPenaltyNum,
|
|
504
|
+
frequencyPenalty: request.frequencyPenalty ?? defaults.frequencyPenalty,
|
|
505
|
+
presencePenalty: request.presencePenalty ?? defaults.presencePenalty,
|
|
506
|
+
},
|
|
507
|
+
contextShift: {
|
|
508
|
+
strategy: config.contextShiftStrategy,
|
|
509
|
+
lastEvaluationMetadata: lastEvaluation?.contextShiftMetadata,
|
|
510
|
+
},
|
|
511
|
+
lastEvaluationContextWindow: {
|
|
512
|
+
history: newContextWindowChatHistory,
|
|
513
|
+
minimumOverlapPercentageToPreventContextShift: 0.5,
|
|
514
|
+
},
|
|
515
|
+
onToken: (tokens) => {
|
|
516
|
+
const text = instance.model.detokenize(tokens)
|
|
517
|
+
if (onChunk) {
|
|
518
|
+
onChunk({
|
|
519
|
+
tokens,
|
|
520
|
+
text,
|
|
521
|
+
})
|
|
522
|
+
}
|
|
523
|
+
},
|
|
524
|
+
})
|
|
525
|
+
|
|
526
|
+
lastEvaluation = currentLastEvaluation
|
|
527
|
+
newChatHistory = lastEvaluation.cleanHistory
|
|
528
|
+
|
|
529
|
+
if (functionCalls) {
|
|
530
|
+
// find leading immediately evokable function calls (=have a handler function)
|
|
531
|
+
const evokableFunctionCalls = []
|
|
532
|
+
for (const functionCall of functionCalls) {
|
|
533
|
+
const functionDef = functionDefinitions[functionCall.functionName]
|
|
534
|
+
if (functionDef.handler) {
|
|
535
|
+
evokableFunctionCalls.push(functionCall)
|
|
536
|
+
} else {
|
|
537
|
+
break
|
|
538
|
+
}
|
|
539
|
+
}
|
|
540
|
+
|
|
541
|
+
// resolve their results.
|
|
542
|
+
const results = await Promise.all(
|
|
543
|
+
evokableFunctionCalls.map(async (functionCall) => {
|
|
544
|
+
const functionDef = functionDefinitions[functionCall.functionName]
|
|
545
|
+
if (!functionDef) {
|
|
546
|
+
throw new Error(`The model tried to call undefined function "${functionCall.functionName}"`)
|
|
547
|
+
}
|
|
548
|
+
const functionCallResult = await functionDef.handler!(functionCall.params)
|
|
549
|
+
log(LogLevels.debug, 'Function handler resolved', {
|
|
550
|
+
function: functionCall.functionName,
|
|
551
|
+
args: functionCall.params,
|
|
552
|
+
result: functionCallResult,
|
|
553
|
+
})
|
|
554
|
+
return {
|
|
555
|
+
functionDef,
|
|
556
|
+
functionCall,
|
|
557
|
+
functionCallResult,
|
|
558
|
+
}
|
|
559
|
+
}),
|
|
560
|
+
)
|
|
561
|
+
newContextWindowChatHistory = lastEvaluation.contextWindow
|
|
562
|
+
let startNewChunk = true
|
|
563
|
+
// add results to chat history in the order they were called
|
|
564
|
+
for (const callResult of results) {
|
|
565
|
+
newChatHistory = addFunctionCallToChatHistory({
|
|
566
|
+
chatHistory: newChatHistory,
|
|
567
|
+
functionName: callResult.functionCall.functionName,
|
|
568
|
+
functionDescription: callResult.functionDef.description,
|
|
569
|
+
callParams: callResult.functionCall.params,
|
|
570
|
+
callResult: callResult.functionCallResult,
|
|
571
|
+
rawCall: callResult.functionCall.raw,
|
|
572
|
+
startsNewChunk: startNewChunk,
|
|
573
|
+
})
|
|
574
|
+
newContextWindowChatHistory = addFunctionCallToChatHistory({
|
|
575
|
+
chatHistory: newContextWindowChatHistory,
|
|
576
|
+
functionName: callResult.functionCall.functionName,
|
|
577
|
+
functionDescription: callResult.functionDef.description,
|
|
578
|
+
callParams: callResult.functionCall.params,
|
|
579
|
+
callResult: callResult.functionCallResult,
|
|
580
|
+
rawCall: callResult.functionCall.raw,
|
|
581
|
+
startsNewChunk: startNewChunk,
|
|
582
|
+
})
|
|
583
|
+
startNewChunk = false
|
|
584
|
+
}
|
|
585
|
+
|
|
586
|
+
// check if all function calls were immediately evokable
|
|
587
|
+
const remainingFunctionCalls = functionCalls.slice(evokableFunctionCalls.length)
|
|
588
|
+
|
|
589
|
+
if (remainingFunctionCalls.length === 0) {
|
|
590
|
+
// if yes, continue with generation
|
|
591
|
+
lastEvaluation.cleanHistory = newChatHistory
|
|
592
|
+
lastEvaluation.contextWindow = newContextWindowChatHistory!
|
|
593
|
+
continue
|
|
594
|
+
} else {
|
|
595
|
+
// if no, return the function calls and skip generation
|
|
596
|
+
completionResult = {
|
|
597
|
+
responseText: null,
|
|
598
|
+
stopReason: 'functionCalls',
|
|
599
|
+
functionCalls: remainingFunctionCalls,
|
|
600
|
+
}
|
|
601
|
+
break
|
|
602
|
+
}
|
|
603
|
+
}
|
|
604
|
+
|
|
605
|
+
// no function calls happened, we got a model response.
|
|
606
|
+
instance.lastEvaluation = lastEvaluation
|
|
607
|
+
instance.chatHistory = newChatHistory
|
|
608
|
+
const lastMessage = instance.chatHistory[instance.chatHistory.length - 1] as ChatModelResponse
|
|
609
|
+
const responseText = lastMessage.response.filter((item: any) => typeof item === 'string').join('')
|
|
610
|
+
completionResult = {
|
|
611
|
+
responseText,
|
|
612
|
+
stopReason: metadata.stopReason,
|
|
613
|
+
}
|
|
614
|
+
break
|
|
615
|
+
}
|
|
616
|
+
|
|
617
|
+
const assistantMessage: AssistantMessage = {
|
|
618
|
+
role: 'assistant',
|
|
619
|
+
content: completionResult.responseText || '',
|
|
620
|
+
}
|
|
621
|
+
|
|
622
|
+
if (completionResult.functionCalls) {
|
|
623
|
+
// TODO its possible that there are trailing immediately-evaluatable function calls.
|
|
624
|
+
// function call results need to be added in the order the functions were called, so
|
|
625
|
+
// we need to wait for the pending calls to complete before we can add the trailing calls.
|
|
626
|
+
// as is, these may never resolve
|
|
627
|
+
const pendingFunctionCalls = completionResult.functionCalls.filter((call) => {
|
|
628
|
+
const functionDef = functionDefinitions[call.functionName]
|
|
629
|
+
return !functionDef.handler
|
|
630
|
+
})
|
|
631
|
+
|
|
632
|
+
// TODO write a test that triggers a parallel call to a deferred function and to an IE function
|
|
633
|
+
const trailingFunctionCalls = completionResult.functionCalls.filter((call) => {
|
|
634
|
+
const functionDef = functionDefinitions[call.functionName]
|
|
635
|
+
return functionDef.handler
|
|
636
|
+
})
|
|
637
|
+
if (trailingFunctionCalls.length) {
|
|
638
|
+
console.debug(trailingFunctionCalls)
|
|
639
|
+
log(LogLevels.warn, 'Trailing function calls not resolved')
|
|
640
|
+
}
|
|
641
|
+
|
|
642
|
+
assistantMessage.toolCalls = pendingFunctionCalls.map((call) => {
|
|
643
|
+
const callId = nanoid()
|
|
644
|
+
instance.pendingFunctionCalls[callId] = call
|
|
645
|
+
log(LogLevels.debug, 'Saving pending tool call', {
|
|
646
|
+
id: callId,
|
|
647
|
+
function: call.functionName,
|
|
648
|
+
args: call.params,
|
|
649
|
+
})
|
|
650
|
+
return {
|
|
651
|
+
id: callId,
|
|
652
|
+
name: call.functionName,
|
|
653
|
+
parameters: call.params,
|
|
654
|
+
}
|
|
655
|
+
})
|
|
656
|
+
}
|
|
657
|
+
|
|
658
|
+
const tokenDifference = instance.chat.sequence.tokenMeter.diff(initialTokenMeterState)
|
|
659
|
+
return {
|
|
660
|
+
finishReason: mapFinishReason(completionResult.stopReason),
|
|
661
|
+
message: assistantMessage,
|
|
662
|
+
promptTokens: tokenDifference.usedInputTokens,
|
|
663
|
+
completionTokens: tokenDifference.usedOutputTokens,
|
|
664
|
+
contextTokens: instance.chat.sequence.contextTokens.length,
|
|
665
|
+
}
|
|
666
|
+
}
|
|
667
|
+
|
|
668
|
+
/**
 * Runs a raw (non-chat) text completion of `request.prompt` on the given instance.
 *
 * The instance caches one context sequence (`instance.contextSequence`) and one
 * `LlamaCompletion` bound to it (`instance.completion`). Both are reused while alive. A
 * disposed/missing sequence is replaced with a fresh one from `instance.context`, and the
 * completion wrapper is rebuilt whenever it is missing, disposed, or its sequence was replaced.
 * When `resetContext` is set, the history of the existing live sequence is cleared first.
 *
 * Sampling options come from the request and fall back to `config.completionDefaults`; the
 * seed additionally falls back to a random number. Every generated token batch is forwarded to
 * `onChunk` (if provided) together with its detokenized text.
 *
 * @param signal - forwarded to the generation call as its abort signal.
 * @returns the generated text, the mapped finish reason, and token usage measured on the
 *   sequence's token meter across this call.
 * @throws Error when the prompt is empty, or when no context sequence can be obtained.
 */
export async function processTextCompletionTask(
	{ request, config, resetContext, onChunk }: EngineTextCompletionArgs<NodeLlamaCppModelConfig>,
	instance: NodeLlamaCppInstance,
	signal?: AbortSignal,
): Promise<EngineTextCompletionResult> {
	if (!request.prompt) {
		throw new Error('Prompt is required for text completion.')
	}

	// clearHistory() returns a promise: await it so generation cannot race the erase and so a
	// failure is reported to the caller instead of surfacing as an unhandled rejection.
	// A disposed sequence is replaced below, so there is nothing to clear on it.
	if (resetContext && instance.contextSequence && !instance.contextSequence.disposed) {
		await instance.contextSequence.clearHistory()
	}

	// Ensure we hold a live context sequence, acquiring a new one if the cached one is unusable.
	let contextSequence = instance.contextSequence
	let sequenceReplaced = false
	if (!contextSequence || contextSequence.disposed) {
		if (!instance.context.sequencesLeft) {
			throw new Error('No context sequence available')
		}
		contextSequence = instance.context.getSequence()
		instance.contextSequence = contextSequence
		sequenceReplaced = true
	}

	// The completion wrapper is bound to a single sequence, so rebuild it when that changed.
	let completion = instance.completion
	if (!completion || completion.disposed || sequenceReplaced) {
		completion = new LlamaCompletion({
			contextSequence,
		})
		instance.completion = completion
	}

	// Each configured stop string becomes its own single-element stop trigger.
	const stopGenerationTriggers: StopGenerationTrigger[] = []
	const stopTrigger = request.stop ?? config.completionDefaults?.stop
	if (stopTrigger) {
		stopGenerationTriggers.push(...stopTrigger.map((t) => [t]))
	}

	const initialTokenMeterState = contextSequence.tokenMeter.getState()
	const defaults = config.completionDefaults ?? {}
	// NOTE(review): the chat path passes `stopOnAbortSignal: true`; this call does not, so abort
	// handling may differ between the two task types — confirm this is intended.
	const result = await completion.generateCompletionWithMeta(request.prompt, {
		maxTokens: request.maxTokens ?? defaults.maxTokens,
		temperature: request.temperature ?? defaults.temperature,
		topP: request.topP ?? defaults.topP,
		topK: request.topK ?? defaults.topK,
		minP: request.minP ?? defaults.minP,
		repeatPenalty: {
			lastTokens: request.repeatPenaltyNum ?? defaults.repeatPenaltyNum,
			frequencyPenalty: request.frequencyPenalty ?? defaults.frequencyPenalty,
			presencePenalty: request.presencePenalty ?? defaults.presencePenalty,
		},
		signal,
		customStopTriggers: stopGenerationTriggers.length ? stopGenerationTriggers : undefined,
		seed: request.seed ?? config.completionDefaults?.seed ?? getRandomNumber(0, 1000000),
		onToken: (tokens) => {
			// Only pay for detokenization when someone is listening.
			if (onChunk) {
				onChunk({
					tokens,
					text: instance.model.detokenize(tokens),
				})
			}
		},
	})

	const tokenDifference = contextSequence.tokenMeter.diff(initialTokenMeterState)

	return {
		finishReason: mapFinishReason(result.metadata.stopReason),
		text: result.response,
		promptTokens: tokenDifference.usedInputTokens,
		completionTokens: tokenDifference.usedOutputTokens,
		contextTokens: contextSequence.contextTokens.length,
	}
}
|
|
754
|
+
|
|
755
|
+
/**
 * Embeds every text item of `request.input` with the instance's embedding context.
 *
 * Accepted input is a single string, or an array whose items are strings or
 * `{ type: 'text', content }` objects. An `image` item rejects the whole request; any other
 * item type is skipped. The embedding context is created on first use (with `signal` passed
 * as its `createSignal`) and cached on `instance`. A text whose token count exceeds the
 * context size is cut to that size and a warning is logged.
 *
 * @param signal - checked after each text; once aborted, the embeddings computed so far are returned.
 * @returns one `Float32Array` per processed text, plus the total number of tokens embedded.
 * @throws Error when `request.input` is falsy (missing or empty string), or contains an image item.
 */
export async function processEmbeddingTask(
	{ request, config, log }: EngineEmbeddingArgs<NodeLlamaCppModelConfig>,
	instance: NodeLlamaCppInstance,
	signal?: AbortSignal,
): Promise<EngineEmbeddingResult> {
	if (!request.input) {
		throw new Error('Input is required for embedding.')
	}

	// Flatten the request into plain strings before any model work happens.
	const rawItems = typeof request.input === 'string'
		? [request.input]
		: Array.isArray(request.input)
			? request.input
			: []
	const textInputs: string[] = []
	for (const item of rawItems) {
		if (typeof item === 'string') {
			textInputs.push(item)
			continue
		}
		switch (item.type) {
			case 'text':
				textInputs.push(item.content)
				break
			case 'image':
				throw new Error('Image inputs not implemented.')
		}
	}

	// Lazily create the embedding context once and keep it on the instance for later calls.
	if (!instance.embeddingContext) {
		instance.embeddingContext = await instance.model.createEmbeddingContext({
			batchSize: config.batchSize,
			createSignal: signal,
			threads: config.device?.cpuThreads,
			contextSize: config.contextSize,
		})
	}

	// @ts-ignore - private property
	const maxTokens = instance.embeddingContext._llamaContext.contextSize

	const vectors: Float32Array[] = []
	let tokenTotal = 0

	for (const text of textInputs) {
		const fullTokens = instance.model.tokenize(text)
		const tooLong = fullTokens.length > maxTokens
		if (tooLong) {
			log(LogLevels.warn, 'Truncated input that exceeds context size')
		}
		const tokens = tooLong ? fullTokens.slice(0, maxTokens) : fullTokens
		tokenTotal += tokens.length

		const { vector } = await instance.embeddingContext.getEmbeddingFor(tokens)
		vectors.push(new Float32Array(vector))

		// Abort is honoured between texts; whatever was embedded so far is still returned.
		if (signal?.aborted) {
			break
		}
	}

	return {
		embeddings: vectors,
		inputTokens: tokenTotal,
	}
}
|