inference-server 1.0.0-beta.19

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (227)
  1. package/README.md +216 -0
  2. package/dist/api/openai/enums.d.ts +4 -0
  3. package/dist/api/openai/enums.js +17 -0
  4. package/dist/api/openai/enums.js.map +1 -0
  5. package/dist/api/openai/handlers/chat.d.ts +3 -0
  6. package/dist/api/openai/handlers/chat.js +358 -0
  7. package/dist/api/openai/handlers/chat.js.map +1 -0
  8. package/dist/api/openai/handlers/completions.d.ts +3 -0
  9. package/dist/api/openai/handlers/completions.js +169 -0
  10. package/dist/api/openai/handlers/completions.js.map +1 -0
  11. package/dist/api/openai/handlers/embeddings.d.ts +3 -0
  12. package/dist/api/openai/handlers/embeddings.js +74 -0
  13. package/dist/api/openai/handlers/embeddings.js.map +1 -0
  14. package/dist/api/openai/handlers/images.d.ts +0 -0
  15. package/dist/api/openai/handlers/images.js +4 -0
  16. package/dist/api/openai/handlers/images.js.map +1 -0
  17. package/dist/api/openai/handlers/models.d.ts +3 -0
  18. package/dist/api/openai/handlers/models.js +23 -0
  19. package/dist/api/openai/handlers/models.js.map +1 -0
  20. package/dist/api/openai/handlers/transcription.d.ts +0 -0
  21. package/dist/api/openai/handlers/transcription.js +4 -0
  22. package/dist/api/openai/handlers/transcription.js.map +1 -0
  23. package/dist/api/openai/index.d.ts +7 -0
  24. package/dist/api/openai/index.js +14 -0
  25. package/dist/api/openai/index.js.map +1 -0
  26. package/dist/api/parseJSONRequestBody.d.ts +2 -0
  27. package/dist/api/parseJSONRequestBody.js +24 -0
  28. package/dist/api/parseJSONRequestBody.js.map +1 -0
  29. package/dist/api/v1/index.d.ts +2 -0
  30. package/dist/api/v1/index.js +29 -0
  31. package/dist/api/v1/index.js.map +1 -0
  32. package/dist/cli.d.ts +1 -0
  33. package/dist/cli.js +10 -0
  34. package/dist/cli.js.map +1 -0
  35. package/dist/engines/gpt4all/engine.d.ts +34 -0
  36. package/dist/engines/gpt4all/engine.js +357 -0
  37. package/dist/engines/gpt4all/engine.js.map +1 -0
  38. package/dist/engines/gpt4all/util.d.ts +3 -0
  39. package/dist/engines/gpt4all/util.js +29 -0
  40. package/dist/engines/gpt4all/util.js.map +1 -0
  41. package/dist/engines/index.d.ts +19 -0
  42. package/dist/engines/index.js +21 -0
  43. package/dist/engines/index.js.map +1 -0
  44. package/dist/engines/node-llama-cpp/engine.d.ts +49 -0
  45. package/dist/engines/node-llama-cpp/engine.js +666 -0
  46. package/dist/engines/node-llama-cpp/engine.js.map +1 -0
  47. package/dist/engines/node-llama-cpp/types.d.ts +13 -0
  48. package/dist/engines/node-llama-cpp/types.js +2 -0
  49. package/dist/engines/node-llama-cpp/types.js.map +1 -0
  50. package/dist/engines/node-llama-cpp/util.d.ts +15 -0
  51. package/dist/engines/node-llama-cpp/util.js +84 -0
  52. package/dist/engines/node-llama-cpp/util.js.map +1 -0
  53. package/dist/engines/node-llama-cpp/validateModelFile.d.ts +8 -0
  54. package/dist/engines/node-llama-cpp/validateModelFile.js +36 -0
  55. package/dist/engines/node-llama-cpp/validateModelFile.js.map +1 -0
  56. package/dist/engines/stable-diffusion-cpp/engine.d.ts +90 -0
  57. package/dist/engines/stable-diffusion-cpp/engine.js +294 -0
  58. package/dist/engines/stable-diffusion-cpp/engine.js.map +1 -0
  59. package/dist/engines/stable-diffusion-cpp/types.d.ts +3 -0
  60. package/dist/engines/stable-diffusion-cpp/types.js +2 -0
  61. package/dist/engines/stable-diffusion-cpp/types.js.map +1 -0
  62. package/dist/engines/stable-diffusion-cpp/util.d.ts +4 -0
  63. package/dist/engines/stable-diffusion-cpp/util.js +55 -0
  64. package/dist/engines/stable-diffusion-cpp/util.js.map +1 -0
  65. package/dist/engines/stable-diffusion-cpp/validateModelFiles.d.ts +19 -0
  66. package/dist/engines/stable-diffusion-cpp/validateModelFiles.js +91 -0
  67. package/dist/engines/stable-diffusion-cpp/validateModelFiles.js.map +1 -0
  68. package/dist/engines/transformers-js/engine.d.ts +37 -0
  69. package/dist/engines/transformers-js/engine.js +538 -0
  70. package/dist/engines/transformers-js/engine.js.map +1 -0
  71. package/dist/engines/transformers-js/types.d.ts +7 -0
  72. package/dist/engines/transformers-js/types.js +2 -0
  73. package/dist/engines/transformers-js/types.js.map +1 -0
  74. package/dist/engines/transformers-js/util.d.ts +7 -0
  75. package/dist/engines/transformers-js/util.js +36 -0
  76. package/dist/engines/transformers-js/util.js.map +1 -0
  77. package/dist/engines/transformers-js/validateModelFiles.d.ts +17 -0
  78. package/dist/engines/transformers-js/validateModelFiles.js +133 -0
  79. package/dist/engines/transformers-js/validateModelFiles.js.map +1 -0
  80. package/dist/experiments/ChatWithVision.d.ts +11 -0
  81. package/dist/experiments/ChatWithVision.js +91 -0
  82. package/dist/experiments/ChatWithVision.js.map +1 -0
  83. package/dist/experiments/StableDiffPromptGenerator.d.ts +0 -0
  84. package/dist/experiments/StableDiffPromptGenerator.js +4 -0
  85. package/dist/experiments/StableDiffPromptGenerator.js.map +1 -0
  86. package/dist/experiments/VoiceFunctionCall.d.ts +18 -0
  87. package/dist/experiments/VoiceFunctionCall.js +51 -0
  88. package/dist/experiments/VoiceFunctionCall.js.map +1 -0
  89. package/dist/http.d.ts +19 -0
  90. package/dist/http.js +54 -0
  91. package/dist/http.js.map +1 -0
  92. package/dist/index.d.ts +7 -0
  93. package/dist/index.js +8 -0
  94. package/dist/index.js.map +1 -0
  95. package/dist/instance.d.ts +88 -0
  96. package/dist/instance.js +594 -0
  97. package/dist/instance.js.map +1 -0
  98. package/dist/lib/acquireFileLock.d.ts +7 -0
  99. package/dist/lib/acquireFileLock.js +38 -0
  100. package/dist/lib/acquireFileLock.js.map +1 -0
  101. package/dist/lib/calculateContextIdentity.d.ts +7 -0
  102. package/dist/lib/calculateContextIdentity.js +39 -0
  103. package/dist/lib/calculateContextIdentity.js.map +1 -0
  104. package/dist/lib/calculateFileChecksum.d.ts +1 -0
  105. package/dist/lib/calculateFileChecksum.js +16 -0
  106. package/dist/lib/calculateFileChecksum.js.map +1 -0
  107. package/dist/lib/copyDirectory.d.ts +6 -0
  108. package/dist/lib/copyDirectory.js +27 -0
  109. package/dist/lib/copyDirectory.js.map +1 -0
  110. package/dist/lib/decodeAudio.d.ts +1 -0
  111. package/dist/lib/decodeAudio.js +26 -0
  112. package/dist/lib/decodeAudio.js.map +1 -0
  113. package/dist/lib/downloadModelFile.d.ts +10 -0
  114. package/dist/lib/downloadModelFile.js +58 -0
  115. package/dist/lib/downloadModelFile.js.map +1 -0
  116. package/dist/lib/flattenMessageTextContent.d.ts +2 -0
  117. package/dist/lib/flattenMessageTextContent.js +11 -0
  118. package/dist/lib/flattenMessageTextContent.js.map +1 -0
  119. package/dist/lib/getCacheDirPath.d.ts +12 -0
  120. package/dist/lib/getCacheDirPath.js +31 -0
  121. package/dist/lib/getCacheDirPath.js.map +1 -0
  122. package/dist/lib/loadImage.d.ts +12 -0
  123. package/dist/lib/loadImage.js +30 -0
  124. package/dist/lib/loadImage.js.map +1 -0
  125. package/dist/lib/logger.d.ts +12 -0
  126. package/dist/lib/logger.js +98 -0
  127. package/dist/lib/logger.js.map +1 -0
  128. package/dist/lib/math.d.ts +7 -0
  129. package/dist/lib/math.js +30 -0
  130. package/dist/lib/math.js.map +1 -0
  131. package/dist/lib/resolveModelFileLocation.d.ts +15 -0
  132. package/dist/lib/resolveModelFileLocation.js +41 -0
  133. package/dist/lib/resolveModelFileLocation.js.map +1 -0
  134. package/dist/lib/util.d.ts +7 -0
  135. package/dist/lib/util.js +61 -0
  136. package/dist/lib/util.js.map +1 -0
  137. package/dist/lib/validateModelFile.d.ts +9 -0
  138. package/dist/lib/validateModelFile.js +62 -0
  139. package/dist/lib/validateModelFile.js.map +1 -0
  140. package/dist/lib/validateModelOptions.d.ts +3 -0
  141. package/dist/lib/validateModelOptions.js +23 -0
  142. package/dist/lib/validateModelOptions.js.map +1 -0
  143. package/dist/pool.d.ts +61 -0
  144. package/dist/pool.js +512 -0
  145. package/dist/pool.js.map +1 -0
  146. package/dist/server.d.ts +59 -0
  147. package/dist/server.js +221 -0
  148. package/dist/server.js.map +1 -0
  149. package/dist/standalone.d.ts +1 -0
  150. package/dist/standalone.js +306 -0
  151. package/dist/standalone.js.map +1 -0
  152. package/dist/store.d.ts +60 -0
  153. package/dist/store.js +203 -0
  154. package/dist/store.js.map +1 -0
  155. package/dist/types/completions.d.ts +57 -0
  156. package/dist/types/completions.js +2 -0
  157. package/dist/types/completions.js.map +1 -0
  158. package/dist/types/index.d.ts +326 -0
  159. package/dist/types/index.js +2 -0
  160. package/dist/types/index.js.map +1 -0
  161. package/docs/engines.md +28 -0
  162. package/docs/gpu.md +72 -0
  163. package/docs/http-api.md +147 -0
  164. package/examples/all-options.js +108 -0
  165. package/examples/chat-cli.js +56 -0
  166. package/examples/chat-server.js +65 -0
  167. package/examples/concurrency.js +70 -0
  168. package/examples/express.js +70 -0
  169. package/examples/pool.js +91 -0
  170. package/package.json +113 -0
  171. package/src/api/openai/enums.ts +20 -0
  172. package/src/api/openai/handlers/chat.ts +408 -0
  173. package/src/api/openai/handlers/completions.ts +196 -0
  174. package/src/api/openai/handlers/embeddings.ts +92 -0
  175. package/src/api/openai/handlers/images.ts +3 -0
  176. package/src/api/openai/handlers/models.ts +33 -0
  177. package/src/api/openai/handlers/transcription.ts +2 -0
  178. package/src/api/openai/index.ts +16 -0
  179. package/src/api/parseJSONRequestBody.ts +26 -0
  180. package/src/api/v1/DRAFT.md +16 -0
  181. package/src/api/v1/index.ts +37 -0
  182. package/src/cli.ts +9 -0
  183. package/src/engines/gpt4all/engine.ts +441 -0
  184. package/src/engines/gpt4all/util.ts +31 -0
  185. package/src/engines/index.ts +28 -0
  186. package/src/engines/node-llama-cpp/engine.ts +811 -0
  187. package/src/engines/node-llama-cpp/types.ts +17 -0
  188. package/src/engines/node-llama-cpp/util.ts +126 -0
  189. package/src/engines/node-llama-cpp/validateModelFile.ts +46 -0
  190. package/src/engines/stable-diffusion-cpp/engine.ts +369 -0
  191. package/src/engines/stable-diffusion-cpp/types.ts +54 -0
  192. package/src/engines/stable-diffusion-cpp/util.ts +58 -0
  193. package/src/engines/stable-diffusion-cpp/validateModelFiles.ts +119 -0
  194. package/src/engines/transformers-js/engine.ts +659 -0
  195. package/src/engines/transformers-js/types.ts +25 -0
  196. package/src/engines/transformers-js/util.ts +40 -0
  197. package/src/engines/transformers-js/validateModelFiles.ts +168 -0
  198. package/src/experiments/ChatWithVision.ts +103 -0
  199. package/src/experiments/StableDiffPromptGenerator.ts +2 -0
  200. package/src/experiments/VoiceFunctionCall.ts +71 -0
  201. package/src/http.ts +72 -0
  202. package/src/index.ts +7 -0
  203. package/src/instance.ts +723 -0
  204. package/src/lib/acquireFileLock.ts +38 -0
  205. package/src/lib/calculateContextIdentity.ts +53 -0
  206. package/src/lib/calculateFileChecksum.ts +18 -0
  207. package/src/lib/copyDirectory.ts +29 -0
  208. package/src/lib/decodeAudio.ts +39 -0
  209. package/src/lib/downloadModelFile.ts +70 -0
  210. package/src/lib/flattenMessageTextContent.ts +19 -0
  211. package/src/lib/getCacheDirPath.ts +34 -0
  212. package/src/lib/loadImage.ts +46 -0
  213. package/src/lib/logger.ts +112 -0
  214. package/src/lib/math.ts +31 -0
  215. package/src/lib/resolveModelFileLocation.ts +49 -0
  216. package/src/lib/util.ts +75 -0
  217. package/src/lib/validateModelFile.ts +71 -0
  218. package/src/lib/validateModelOptions.ts +31 -0
  219. package/src/pool.ts +651 -0
  220. package/src/server.ts +270 -0
  221. package/src/standalone.ts +320 -0
  222. package/src/store.ts +278 -0
  223. package/src/types/completions.ts +86 -0
  224. package/src/types/index.ts +488 -0
  225. package/tsconfig.json +29 -0
  226. package/tsconfig.release.json +11 -0
  227. package/vitest.config.ts +18 -0
package/src/engines/gpt4all/engine.ts
@@ -0,0 +1,441 @@
+ import path from 'node:path'
+ import fs from 'node:fs'
+ import {
+   loadModel,
+   createCompletion,
+   createEmbedding,
+   InferenceModel,
+   LoadModelOptions,
+   CompletionInput,
+   EmbeddingModel,
+   DEFAULT_MODEL_LIST_URL,
+ } from 'gpt4all'
+ import {
+   EngineTextCompletionArgs,
+   EngineChatCompletionArgs,
+   EngineChatCompletionResult,
+   EngineTextCompletionResult,
+   CompletionFinishReason,
+   EngineContext,
+   EngineEmbeddingArgs,
+   EngineEmbeddingResult,
+   FileDownloadProgress,
+   ModelConfig,
+   TextCompletionParams,
+   ChatMessage,
+ } from '#package/types/index.js'
+ import { LogLevels } from '#package/lib/logger.js'
+ import { downloadModelFile } from '#package/lib/downloadModelFile.js'
+ import { acquireFileLock } from '#package/lib/acquireFileLock.js'
+ import { validateModelFile } from '#package/lib/validateModelFile.js'
+ import { createChatMessageArray } from './util.js'
+
+ export type GPT4AllInstance = InferenceModel | EmbeddingModel
+
+ export interface GPT4AllModelMeta {
+   url: string
+   md5sum: string
+   filename: string
+   promptTemplate: string
+   systemPrompt: string
+   filesize: number
+   ramrequired: number
+ }
+
+ export interface GPT4AllModelConfig extends ModelConfig {
+   location: string
+   md5?: string
+   url?: string
+   contextSize?: number
+   batchSize?: number
+   task: 'text-completion' | 'embedding'
+   initialMessages?: ChatMessage[]
+   completionDefaults?: TextCompletionParams
+   device?: {
+     gpu?: boolean | 'auto' | (string & {})
+     gpuLayers?: number
+     cpuThreads?: number
+   }
+ }
+
+ export const autoGpu = true
+
+ export async function prepareModel(
+   { config, log }: EngineContext<GPT4AllModelConfig>,
+   onProgress?: (progress: FileDownloadProgress) => void,
+   signal?: AbortSignal,
+ ) {
+   fs.mkdirSync(path.dirname(config.location), { recursive: true })
+   const releaseFileLock = await acquireFileLock(config.location)
+   if (signal?.aborted) {
+     releaseFileLock()
+     return
+   }
+   log(LogLevels.info, `Preparing gpt4all model at ${config.location}`, {
+     model: config.id,
+   })
+   let modelMeta: GPT4AllModelMeta | undefined
+   let modelList: GPT4AllModelMeta[]
+   const modelMetaPath = path.join(path.dirname(config.location), 'models.json')
+   try {
+     if (!fs.existsSync(modelMetaPath)) {
+       const res = await fetch(DEFAULT_MODEL_LIST_URL)
+       modelList = (await res.json()) as GPT4AllModelMeta[]
+       fs.writeFileSync(modelMetaPath, JSON.stringify(modelList, null, 2))
+     } else {
+       modelList = JSON.parse(fs.readFileSync(modelMetaPath, 'utf-8'))
+     }
+     const foundModelMeta = modelList.find((item) => {
+       if (config.md5 && item.md5sum) {
+         return item.md5sum === config.md5
+       }
+       if (config.url && item.url) {
+         return item.url === config.url
+       }
+       return item.filename === path.basename(config.location)
+     })
+     if (foundModelMeta) {
+       modelMeta = foundModelMeta
+     }
+
+     const validationError = await validateModelFile({
+       ...config,
+       md5: config.md5 || modelMeta?.md5sum,
+     })
+     if (signal?.aborted) {
+       return
+     }
+     if (validationError) {
+       if (config.url) {
+         log(LogLevels.info, 'Downloading', {
+           model: config.id,
+           url: config.url,
+           location: config.location,
+           error: validationError,
+         })
+         await downloadModelFile({
+           url: config.url,
+           filePath: config.location,
+           modelsCachePath: config.modelsCachePath,
+           onProgress,
+           signal,
+         })
+       } else {
+         throw new Error(`${validationError} - No URL provided`)
+       }
+     }
+
+     const finalValidationError = await validateModelFile({
+       ...config,
+       md5: config.md5 || modelMeta?.md5sum,
+     })
+     if (finalValidationError) {
+       throw new Error(`Downloaded files are invalid: ${finalValidationError}`)
+     }
+     if (signal?.aborted) {
+       return
+     }
+
+     return modelMeta
+   } catch (error) {
+     throw error
+   } finally {
+     releaseFileLock()
+   }
+ }
+
+ export async function createInstance({ config, log }: EngineContext<GPT4AllModelConfig>, signal?: AbortSignal) {
+   log(LogLevels.info, `Load GPT4All model ${config.location}`)
+   let device = config.device?.gpu ?? 'cpu'
+   if (typeof device === 'boolean') {
+     device = device ? 'gpu' : 'cpu'
+   } else if (device === 'auto') {
+     device = 'cpu'
+   }
+   const loadOpts: LoadModelOptions = {
+     modelPath: path.dirname(config.location),
+     // file: config.file,
+     modelConfigFile: path.dirname(config.location) + '/models.json',
+     allowDownload: false,
+     device: device,
+     ngl: config.device?.gpuLayers ?? 100,
+     nCtx: config.contextSize ?? 2048,
+     // verbose: true,
+     // signal?: // TODO no way to cancel load
+   }
+
+   let modelType: 'inference' | 'embedding'
+   if (config.task === 'text-completion') {
+     modelType = 'inference'
+   } else if (config.task === 'embedding') {
+     modelType = 'embedding'
+   } else {
+     throw new Error(`Unsupported task type: ${config.task}`)
+   }
+
+   const instance = await loadModel(path.basename(config.location), {
+     ...loadOpts,
+     type: modelType,
+   })
+   if (config.device?.cpuThreads) {
+     instance.llm.setThreadCount(config.device.cpuThreads)
+   }
+
+   if ('generate' in instance) {
+     if (config.initialMessages?.length) {
+       let messages = createChatMessageArray(config.initialMessages)
+       let systemPrompt
+       if (messages[0].role === 'system') {
+         systemPrompt = messages[0].content
+         messages = messages.slice(1)
+       }
+       await instance.createChatSession({
+         systemPrompt,
+         messages,
+       })
+     } else if (config.prefix) {
+       await instance.generate(config.prefix, {
+         nPredict: 0,
+       })
+     } else {
+       await instance.generate('', {
+         nPredict: 0,
+       })
+     }
+   }
+
+   return instance
+ }
+
+ export async function disposeInstance(instance: GPT4AllInstance) {
+   instance.dispose()
+ }
+
+ export async function processTextCompletionTask(
+   { request, config, onChunk }: EngineTextCompletionArgs<GPT4AllModelConfig>,
+   instance: GPT4AllInstance,
+   signal?: AbortSignal,
+ ): Promise<EngineTextCompletionResult> {
+   if (!('generate' in instance)) {
+     throw new Error('Instance does not support text completion.')
+   }
+   if (!request.prompt) {
+     throw new Error('Prompt is required for text completion.')
+   }
+
+   let finishReason: CompletionFinishReason = 'eogToken'
+   let suffixToRemove: string | undefined
+
+   const defaults = config.completionDefaults ?? {}
+   const stopTriggers = request.stop ?? defaults.stop ?? []
+   const includesStopTriggers = (text: string) => stopTriggers.find((t) => text.includes(t))
+   const result = await instance.generate(request.prompt, {
+     // @ts-ignore
+     special: true, // allows passing in raw prompt (including <|start|> etc.)
+     promptTemplate: '%1',
+     temperature: request.temperature ?? defaults.temperature,
+     nPredict: request.maxTokens ?? defaults.maxTokens,
+     topP: request.topP ?? defaults.topP,
+     topK: request.topK ?? defaults.topK,
+     minP: request.minP ?? defaults.minP,
+     nBatch: config?.batchSize,
+     repeatLastN: request.repeatPenaltyNum ?? defaults.repeatPenaltyNum,
+     // repeat penalty is doing something different than both frequency and presence penalty
+     // so not falling back to them here.
+     repeatPenalty: request.repeatPenalty ?? defaults.repeatPenalty,
+     // seed: args.seed, // https://github.com/nomic-ai/gpt4all/issues/1952
+     // @ts-ignore
+     onResponseToken: (tokenId, text) => {
+       const matchingTrigger = includesStopTriggers(text)
+       if (matchingTrigger) {
+         finishReason = 'stopTrigger'
+         suffixToRemove = text
+         return false
+       }
+       if (onChunk) {
+         onChunk({
+           text,
+           tokens: [tokenId],
+         })
+       }
+       return !signal?.aborted
+     },
+     // @ts-ignore
+     onResponseTokens: ({ tokenIds, text }) => {
+       const matchingTrigger = includesStopTriggers(text)
+       if (matchingTrigger) {
+         finishReason = 'stopTrigger'
+         suffixToRemove = text
+         return false
+       }
+       if (onChunk) {
+         onChunk({
+           text,
+           tokens: tokenIds,
+         })
+       }
+       return !signal?.aborted
+     },
+   })
+
+   if (result.tokensGenerated === request.maxTokens) {
+     finishReason = 'maxTokens'
+   }
+
+   let responseText = result.text
+   if (suffixToRemove) {
+     responseText = responseText.slice(0, -suffixToRemove.length)
+   }
+
+   return {
+     finishReason,
+     text: responseText,
+     promptTokens: result.tokensIngested,
+     completionTokens: result.tokensGenerated,
+     contextTokens: instance.activeChatSession?.promptContext.nPast ?? 0,
+   }
+ }
+
+ export async function processChatCompletionTask(
+   { request, config, resetContext, log, onChunk }: EngineChatCompletionArgs<GPT4AllModelConfig>,
+   instance: GPT4AllInstance,
+   signal?: AbortSignal,
+ ): Promise<EngineChatCompletionResult> {
+   if (!('createChatSession' in instance)) {
+     throw new Error('Instance does not support chat completion.')
+   }
+   let session = instance.activeChatSession
+   if (!session || resetContext) {
+     log(LogLevels.debug, 'Resetting chat context')
+     let messages = createChatMessageArray(request.messages)
+     let systemPrompt
+     if (messages[0].role === 'system') {
+       systemPrompt = messages[0].content
+       messages = messages.slice(1)
+     }
+     // drop last user message
+     if (messages[messages.length - 1].role === 'user') {
+       messages = messages.slice(0, -1)
+     }
+
+     session = await instance.createChatSession({
+       systemPrompt,
+       messages,
+     })
+   }
+
+   const conversationMessages = createChatMessageArray(request.messages).filter((m) => m.role !== 'system')
+
+   const lastMessage = conversationMessages[conversationMessages.length - 1]
+   if (!(lastMessage.role === 'user' && lastMessage.content)) {
+     throw new Error('Chat completions require a final user message.')
+   }
+   const input: CompletionInput = lastMessage.content
+
+   let finishReason: CompletionFinishReason = 'eogToken'
+   let suffixToRemove: string | undefined
+
+   const defaults = config.completionDefaults ?? {}
+   const stopTriggers = request.stop ?? defaults.stop ?? []
+   const includesStopTriggers = (text: string) => stopTriggers.find((t) => text.includes(t))
+   const result = await createCompletion(session, input, {
+     temperature: request.temperature ?? defaults.temperature,
+     nPredict: request.maxTokens ?? defaults.maxTokens,
+     topP: request.topP ?? defaults.topP,
+     topK: request.topK ?? defaults.topK,
+     minP: request.minP ?? defaults.minP,
+     nBatch: config.batchSize,
+     repeatLastN: request.repeatPenaltyNum ?? defaults.repeatPenaltyNum,
+     repeatPenalty: request.repeatPenalty ?? defaults.repeatPenalty,
+     // seed: args.seed, // see https://github.com/nomic-ai/gpt4all/issues/1952
+     // @ts-ignore
+     onResponseToken: (tokenId, text) => {
+       const matchingTrigger = includesStopTriggers(text)
+       if (matchingTrigger) {
+         finishReason = 'stopTrigger'
+         suffixToRemove = text
+         return false
+       }
+       if (onChunk) {
+         onChunk({
+           text,
+           tokens: [tokenId],
+         })
+       }
+       return !signal?.aborted
+     },
+     // @ts-ignore
+     onResponseTokens: ({ tokenIds, text }) => {
+       const matchingTrigger = includesStopTriggers(text)
+       if (matchingTrigger) {
+         finishReason = 'stopTrigger'
+         suffixToRemove = text
+         return false
+       }
+       if (onChunk) {
+         onChunk({
+           tokens: tokenIds,
+           text,
+         })
+       }
+
+       return !signal?.aborted
+     },
+   })
+
+   if (result.usage.completion_tokens === request.maxTokens) {
+     finishReason = 'maxTokens'
+   }
+
+   let response = result.choices[0].message.content
+   if (suffixToRemove) {
+     response = response.slice(0, -suffixToRemove.length)
+   }
+
+   return {
+     finishReason,
+     message: {
+       role: 'assistant',
+       content: response,
+     },
+     promptTokens: result.usage.prompt_tokens,
+     completionTokens: result.usage.completion_tokens,
+     contextTokens: session.promptContext.nPast,
+   }
+ }
+
+ export async function processEmbeddingTask(
+   { request, config }: EngineEmbeddingArgs,
+   instance: GPT4AllInstance,
+   signal?: AbortSignal,
+ ): Promise<EngineEmbeddingResult> {
+   if (!('embed' in instance)) {
+     throw new Error('Instance does not support embedding.')
+   }
+   if (!request.input) {
+     throw new Error('Input is required for embedding.')
+   }
+   const texts: string[] = []
+   if (typeof request.input === 'string') {
+     texts.push(request.input)
+   } else if (Array.isArray(request.input)) {
+     for (const input of request.input) {
+       if (typeof input === 'string') {
+         texts.push(input)
+       } else if (input.type === 'text') {
+         texts.push(input.content)
+       } else if (input.type === 'image') {
+         throw new Error('Image inputs not implemented.')
+       }
+     }
+   }
+
+   const res = await createEmbedding(instance, texts, {
+     dimensionality: request.dimensions,
+   })
+
+   return {
+     embeddings: res.embeddings,
+     inputTokens: res.n_prompt_tokens,
+   }
+ }
package/src/engines/gpt4all/util.ts
@@ -0,0 +1,31 @@
+ import { ChatMessage as GPT4AllChatMessage } from 'gpt4all'
+ import { ChatMessage } from '#package/types/index.js'
+ import { flattenMessageTextContent } from '#package/lib/flattenMessageTextContent.js'
+
+ export function createChatMessageArray(
+   messages: ChatMessage[],
+ ): GPT4AllChatMessage[] {
+   const chatMessages: GPT4AllChatMessage[] = []
+   let systemPrompt: string | undefined
+   for (const message of messages) {
+     if (message.role === 'user' || message.role === 'assistant') {
+       chatMessages.push({
+         role: message.role,
+         content: flattenMessageTextContent(message.content),
+       })
+     } else if (message.role === 'system') {
+       if (systemPrompt) {
+         systemPrompt += '\n\n' + message.content
+       } else {
+         systemPrompt = flattenMessageTextContent(message.content)
+       }
+     }
+   }
+   if (systemPrompt) {
+     chatMessages.unshift({
+       role: 'system',
+       content: systemPrompt,
+     })
+   }
+   return chatMessages
+ }
package/src/engines/index.ts
@@ -0,0 +1,28 @@
+ import type { ModelPool } from '#package/pool.js'
+ import type { ModelStore } from '#package/store.js'
+ import { ModelEngine, EngineStartContext, ModelOptions, BuiltInModelOptions } from '#package/types/index.js'
+
+ export const BuiltInEngines = {
+   gpt4all: 'gpt4all',
+   nodeLlamaCpp: 'node-llama-cpp',
+   transformersJs: 'transformers-js',
+   stableDiffusionCpp: 'stable-diffusion-cpp',
+ } as const
+
+ export type BuiltInEngineName = typeof BuiltInEngines[keyof typeof BuiltInEngines];
+
+ export const builtInEngineNames: string[] = [
+   ...Object.values(BuiltInEngines),
+ ] as const
+
+ export class CustomEngine implements ModelEngine {
+   pool!: ModelPool
+   store!: ModelStore
+   async start({ pool, store }: EngineStartContext) {
+     this.pool = pool
+     this.store = store
+   }
+   async prepareModel() {}
+   async createInstance() {}
+   async disposeInstance() {}
+ }