inference-server 1.0.0-beta.19

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (227) hide show
  1. package/README.md +216 -0
  2. package/dist/api/openai/enums.d.ts +4 -0
  3. package/dist/api/openai/enums.js +17 -0
  4. package/dist/api/openai/enums.js.map +1 -0
  5. package/dist/api/openai/handlers/chat.d.ts +3 -0
  6. package/dist/api/openai/handlers/chat.js +358 -0
  7. package/dist/api/openai/handlers/chat.js.map +1 -0
  8. package/dist/api/openai/handlers/completions.d.ts +3 -0
  9. package/dist/api/openai/handlers/completions.js +169 -0
  10. package/dist/api/openai/handlers/completions.js.map +1 -0
  11. package/dist/api/openai/handlers/embeddings.d.ts +3 -0
  12. package/dist/api/openai/handlers/embeddings.js +74 -0
  13. package/dist/api/openai/handlers/embeddings.js.map +1 -0
  14. package/dist/api/openai/handlers/images.d.ts +0 -0
  15. package/dist/api/openai/handlers/images.js +4 -0
  16. package/dist/api/openai/handlers/images.js.map +1 -0
  17. package/dist/api/openai/handlers/models.d.ts +3 -0
  18. package/dist/api/openai/handlers/models.js +23 -0
  19. package/dist/api/openai/handlers/models.js.map +1 -0
  20. package/dist/api/openai/handlers/transcription.d.ts +0 -0
  21. package/dist/api/openai/handlers/transcription.js +4 -0
  22. package/dist/api/openai/handlers/transcription.js.map +1 -0
  23. package/dist/api/openai/index.d.ts +7 -0
  24. package/dist/api/openai/index.js +14 -0
  25. package/dist/api/openai/index.js.map +1 -0
  26. package/dist/api/parseJSONRequestBody.d.ts +2 -0
  27. package/dist/api/parseJSONRequestBody.js +24 -0
  28. package/dist/api/parseJSONRequestBody.js.map +1 -0
  29. package/dist/api/v1/index.d.ts +2 -0
  30. package/dist/api/v1/index.js +29 -0
  31. package/dist/api/v1/index.js.map +1 -0
  32. package/dist/cli.d.ts +1 -0
  33. package/dist/cli.js +10 -0
  34. package/dist/cli.js.map +1 -0
  35. package/dist/engines/gpt4all/engine.d.ts +34 -0
  36. package/dist/engines/gpt4all/engine.js +357 -0
  37. package/dist/engines/gpt4all/engine.js.map +1 -0
  38. package/dist/engines/gpt4all/util.d.ts +3 -0
  39. package/dist/engines/gpt4all/util.js +29 -0
  40. package/dist/engines/gpt4all/util.js.map +1 -0
  41. package/dist/engines/index.d.ts +19 -0
  42. package/dist/engines/index.js +21 -0
  43. package/dist/engines/index.js.map +1 -0
  44. package/dist/engines/node-llama-cpp/engine.d.ts +49 -0
  45. package/dist/engines/node-llama-cpp/engine.js +666 -0
  46. package/dist/engines/node-llama-cpp/engine.js.map +1 -0
  47. package/dist/engines/node-llama-cpp/types.d.ts +13 -0
  48. package/dist/engines/node-llama-cpp/types.js +2 -0
  49. package/dist/engines/node-llama-cpp/types.js.map +1 -0
  50. package/dist/engines/node-llama-cpp/util.d.ts +15 -0
  51. package/dist/engines/node-llama-cpp/util.js +84 -0
  52. package/dist/engines/node-llama-cpp/util.js.map +1 -0
  53. package/dist/engines/node-llama-cpp/validateModelFile.d.ts +8 -0
  54. package/dist/engines/node-llama-cpp/validateModelFile.js +36 -0
  55. package/dist/engines/node-llama-cpp/validateModelFile.js.map +1 -0
  56. package/dist/engines/stable-diffusion-cpp/engine.d.ts +90 -0
  57. package/dist/engines/stable-diffusion-cpp/engine.js +294 -0
  58. package/dist/engines/stable-diffusion-cpp/engine.js.map +1 -0
  59. package/dist/engines/stable-diffusion-cpp/types.d.ts +3 -0
  60. package/dist/engines/stable-diffusion-cpp/types.js +2 -0
  61. package/dist/engines/stable-diffusion-cpp/types.js.map +1 -0
  62. package/dist/engines/stable-diffusion-cpp/util.d.ts +4 -0
  63. package/dist/engines/stable-diffusion-cpp/util.js +55 -0
  64. package/dist/engines/stable-diffusion-cpp/util.js.map +1 -0
  65. package/dist/engines/stable-diffusion-cpp/validateModelFiles.d.ts +19 -0
  66. package/dist/engines/stable-diffusion-cpp/validateModelFiles.js +91 -0
  67. package/dist/engines/stable-diffusion-cpp/validateModelFiles.js.map +1 -0
  68. package/dist/engines/transformers-js/engine.d.ts +37 -0
  69. package/dist/engines/transformers-js/engine.js +538 -0
  70. package/dist/engines/transformers-js/engine.js.map +1 -0
  71. package/dist/engines/transformers-js/types.d.ts +7 -0
  72. package/dist/engines/transformers-js/types.js +2 -0
  73. package/dist/engines/transformers-js/types.js.map +1 -0
  74. package/dist/engines/transformers-js/util.d.ts +7 -0
  75. package/dist/engines/transformers-js/util.js +36 -0
  76. package/dist/engines/transformers-js/util.js.map +1 -0
  77. package/dist/engines/transformers-js/validateModelFiles.d.ts +17 -0
  78. package/dist/engines/transformers-js/validateModelFiles.js +133 -0
  79. package/dist/engines/transformers-js/validateModelFiles.js.map +1 -0
  80. package/dist/experiments/ChatWithVision.d.ts +11 -0
  81. package/dist/experiments/ChatWithVision.js +91 -0
  82. package/dist/experiments/ChatWithVision.js.map +1 -0
  83. package/dist/experiments/StableDiffPromptGenerator.d.ts +0 -0
  84. package/dist/experiments/StableDiffPromptGenerator.js +4 -0
  85. package/dist/experiments/StableDiffPromptGenerator.js.map +1 -0
  86. package/dist/experiments/VoiceFunctionCall.d.ts +18 -0
  87. package/dist/experiments/VoiceFunctionCall.js +51 -0
  88. package/dist/experiments/VoiceFunctionCall.js.map +1 -0
  89. package/dist/http.d.ts +19 -0
  90. package/dist/http.js +54 -0
  91. package/dist/http.js.map +1 -0
  92. package/dist/index.d.ts +7 -0
  93. package/dist/index.js +8 -0
  94. package/dist/index.js.map +1 -0
  95. package/dist/instance.d.ts +88 -0
  96. package/dist/instance.js +594 -0
  97. package/dist/instance.js.map +1 -0
  98. package/dist/lib/acquireFileLock.d.ts +7 -0
  99. package/dist/lib/acquireFileLock.js +38 -0
  100. package/dist/lib/acquireFileLock.js.map +1 -0
  101. package/dist/lib/calculateContextIdentity.d.ts +7 -0
  102. package/dist/lib/calculateContextIdentity.js +39 -0
  103. package/dist/lib/calculateContextIdentity.js.map +1 -0
  104. package/dist/lib/calculateFileChecksum.d.ts +1 -0
  105. package/dist/lib/calculateFileChecksum.js +16 -0
  106. package/dist/lib/calculateFileChecksum.js.map +1 -0
  107. package/dist/lib/copyDirectory.d.ts +6 -0
  108. package/dist/lib/copyDirectory.js +27 -0
  109. package/dist/lib/copyDirectory.js.map +1 -0
  110. package/dist/lib/decodeAudio.d.ts +1 -0
  111. package/dist/lib/decodeAudio.js +26 -0
  112. package/dist/lib/decodeAudio.js.map +1 -0
  113. package/dist/lib/downloadModelFile.d.ts +10 -0
  114. package/dist/lib/downloadModelFile.js +58 -0
  115. package/dist/lib/downloadModelFile.js.map +1 -0
  116. package/dist/lib/flattenMessageTextContent.d.ts +2 -0
  117. package/dist/lib/flattenMessageTextContent.js +11 -0
  118. package/dist/lib/flattenMessageTextContent.js.map +1 -0
  119. package/dist/lib/getCacheDirPath.d.ts +12 -0
  120. package/dist/lib/getCacheDirPath.js +31 -0
  121. package/dist/lib/getCacheDirPath.js.map +1 -0
  122. package/dist/lib/loadImage.d.ts +12 -0
  123. package/dist/lib/loadImage.js +30 -0
  124. package/dist/lib/loadImage.js.map +1 -0
  125. package/dist/lib/logger.d.ts +12 -0
  126. package/dist/lib/logger.js +98 -0
  127. package/dist/lib/logger.js.map +1 -0
  128. package/dist/lib/math.d.ts +7 -0
  129. package/dist/lib/math.js +30 -0
  130. package/dist/lib/math.js.map +1 -0
  131. package/dist/lib/resolveModelFileLocation.d.ts +15 -0
  132. package/dist/lib/resolveModelFileLocation.js +41 -0
  133. package/dist/lib/resolveModelFileLocation.js.map +1 -0
  134. package/dist/lib/util.d.ts +7 -0
  135. package/dist/lib/util.js +61 -0
  136. package/dist/lib/util.js.map +1 -0
  137. package/dist/lib/validateModelFile.d.ts +9 -0
  138. package/dist/lib/validateModelFile.js +62 -0
  139. package/dist/lib/validateModelFile.js.map +1 -0
  140. package/dist/lib/validateModelOptions.d.ts +3 -0
  141. package/dist/lib/validateModelOptions.js +23 -0
  142. package/dist/lib/validateModelOptions.js.map +1 -0
  143. package/dist/pool.d.ts +61 -0
  144. package/dist/pool.js +512 -0
  145. package/dist/pool.js.map +1 -0
  146. package/dist/server.d.ts +59 -0
  147. package/dist/server.js +221 -0
  148. package/dist/server.js.map +1 -0
  149. package/dist/standalone.d.ts +1 -0
  150. package/dist/standalone.js +306 -0
  151. package/dist/standalone.js.map +1 -0
  152. package/dist/store.d.ts +60 -0
  153. package/dist/store.js +203 -0
  154. package/dist/store.js.map +1 -0
  155. package/dist/types/completions.d.ts +57 -0
  156. package/dist/types/completions.js +2 -0
  157. package/dist/types/completions.js.map +1 -0
  158. package/dist/types/index.d.ts +326 -0
  159. package/dist/types/index.js +2 -0
  160. package/dist/types/index.js.map +1 -0
  161. package/docs/engines.md +28 -0
  162. package/docs/gpu.md +72 -0
  163. package/docs/http-api.md +147 -0
  164. package/examples/all-options.js +108 -0
  165. package/examples/chat-cli.js +56 -0
  166. package/examples/chat-server.js +65 -0
  167. package/examples/concurrency.js +70 -0
  168. package/examples/express.js +70 -0
  169. package/examples/pool.js +91 -0
  170. package/package.json +113 -0
  171. package/src/api/openai/enums.ts +20 -0
  172. package/src/api/openai/handlers/chat.ts +408 -0
  173. package/src/api/openai/handlers/completions.ts +196 -0
  174. package/src/api/openai/handlers/embeddings.ts +92 -0
  175. package/src/api/openai/handlers/images.ts +3 -0
  176. package/src/api/openai/handlers/models.ts +33 -0
  177. package/src/api/openai/handlers/transcription.ts +2 -0
  178. package/src/api/openai/index.ts +16 -0
  179. package/src/api/parseJSONRequestBody.ts +26 -0
  180. package/src/api/v1/DRAFT.md +16 -0
  181. package/src/api/v1/index.ts +37 -0
  182. package/src/cli.ts +9 -0
  183. package/src/engines/gpt4all/engine.ts +441 -0
  184. package/src/engines/gpt4all/util.ts +31 -0
  185. package/src/engines/index.ts +28 -0
  186. package/src/engines/node-llama-cpp/engine.ts +811 -0
  187. package/src/engines/node-llama-cpp/types.ts +17 -0
  188. package/src/engines/node-llama-cpp/util.ts +126 -0
  189. package/src/engines/node-llama-cpp/validateModelFile.ts +46 -0
  190. package/src/engines/stable-diffusion-cpp/engine.ts +369 -0
  191. package/src/engines/stable-diffusion-cpp/types.ts +54 -0
  192. package/src/engines/stable-diffusion-cpp/util.ts +58 -0
  193. package/src/engines/stable-diffusion-cpp/validateModelFiles.ts +119 -0
  194. package/src/engines/transformers-js/engine.ts +659 -0
  195. package/src/engines/transformers-js/types.ts +25 -0
  196. package/src/engines/transformers-js/util.ts +40 -0
  197. package/src/engines/transformers-js/validateModelFiles.ts +168 -0
  198. package/src/experiments/ChatWithVision.ts +103 -0
  199. package/src/experiments/StableDiffPromptGenerator.ts +2 -0
  200. package/src/experiments/VoiceFunctionCall.ts +71 -0
  201. package/src/http.ts +72 -0
  202. package/src/index.ts +7 -0
  203. package/src/instance.ts +723 -0
  204. package/src/lib/acquireFileLock.ts +38 -0
  205. package/src/lib/calculateContextIdentity.ts +53 -0
  206. package/src/lib/calculateFileChecksum.ts +18 -0
  207. package/src/lib/copyDirectory.ts +29 -0
  208. package/src/lib/decodeAudio.ts +39 -0
  209. package/src/lib/downloadModelFile.ts +70 -0
  210. package/src/lib/flattenMessageTextContent.ts +19 -0
  211. package/src/lib/getCacheDirPath.ts +34 -0
  212. package/src/lib/loadImage.ts +46 -0
  213. package/src/lib/logger.ts +112 -0
  214. package/src/lib/math.ts +31 -0
  215. package/src/lib/resolveModelFileLocation.ts +49 -0
  216. package/src/lib/util.ts +75 -0
  217. package/src/lib/validateModelFile.ts +71 -0
  218. package/src/lib/validateModelOptions.ts +31 -0
  219. package/src/pool.ts +651 -0
  220. package/src/server.ts +270 -0
  221. package/src/standalone.ts +320 -0
  222. package/src/store.ts +278 -0
  223. package/src/types/completions.ts +86 -0
  224. package/src/types/index.ts +488 -0
  225. package/tsconfig.json +29 -0
  226. package/tsconfig.release.json +11 -0
  227. package/vitest.config.ts +18 -0
@@ -0,0 +1,17 @@
1
+ import {
2
+ ChatHistoryItem,
3
+ ChatModelFunctions,
4
+ LlamaChatResponse,
5
+ LlamaChatResponseFunctionCall,
6
+ } from 'node-llama-cpp'
7
+
8
/**
 * Result shape used by the node-llama-cpp engine for one chat generation.
 *
 * @typeParam T - Function definitions that the entries of `functionCalls`
 * are typed against.
 */
export interface LlamaChatResult<T extends ChatModelFunctions = any> {
  /** Generated text; the type allows `null`. */
  responseText: string | null
  /** Function calls attached to the result, when present. */
  functionCalls?: Array<LlamaChatResponseFunctionCall<T>>
  /** Stop reason, typed exactly as `LlamaChatResponse['metadata']['stopReason']`. */
  stopReason: LlamaChatResponse['metadata']['stopReason']
}
13
+
14
/**
 * Optional callback that receives a chat history plus opaque metadata and
 * returns a (possibly rewritten) history and metadata. `null` is an allowed
 * value of this type.
 */
export type ContextShiftStrategy =
  | null
  | ((options: { chatHistory: ChatHistoryItem[]; metadata: any }) => {
      chatHistory: ChatHistoryItem[]
      metadata: any
    })
@@ -0,0 +1,126 @@
1
+ import fs from 'node:fs'
2
+ import path from 'node:path'
3
+ import {
4
+ ChatHistoryItem,
5
+ ChatModelResponse,
6
+ LlamaTextJSON,
7
+ ChatModelFunctionCall,
8
+ } from 'node-llama-cpp'
9
+ import { CompletionFinishReason, ChatMessage } from '#package/types/index.js'
10
+ import { flattenMessageTextContent } from '#package/lib/flattenMessageTextContent.js'
11
+ import { LlamaChatResult } from './types.js'
12
+
13
/**
 * Translates a node-llama-cpp stop reason into this package's
 * `CompletionFinishReason`.
 *
 * `functionCalls` becomes `toolCalls`; both `stopGenerationTrigger` and
 * `customStopTrigger` become `stopTrigger`; every other value is returned
 * unchanged.
 */
export function mapFinishReason(
  nodeLlamaCppFinishReason: LlamaChatResult['stopReason'],
): CompletionFinishReason {
  if (nodeLlamaCppFinishReason === 'functionCalls') {
    return 'toolCalls'
  }
  if (
    nodeLlamaCppFinishReason === 'stopGenerationTrigger' ||
    nodeLlamaCppFinishReason === 'customStopTrigger'
  ) {
    return 'stopTrigger'
  }
  return nodeLlamaCppFinishReason
}
27
+
28
/**
 * Returns a copy of `chatHistory` with a function-call entry appended to the
 * trailing model response.
 *
 * The input array and its items are not mutated: the array, the last model
 * item and that item's `response` list are each copied before the call is
 * added. When the history is empty or does not end with a model item, an
 * empty model item is appended first.
 *
 * @returns The new chat history array.
 */
export function addFunctionCallToChatHistory({
  chatHistory,
  functionName,
  functionDescription,
  callParams,
  callResult,
  rawCall,
  startsNewChunk,
}: {
  chatHistory: ChatHistoryItem[]
  functionName: string
  functionDescription?: string
  callParams: any
  callResult: any
  rawCall?: LlamaTextJSON
  startsNewChunk?: boolean
}) {
  const history = [...chatHistory]

  const endsWithModelItem = history.length > 0 && history[history.length - 1]!.type === 'model'
  if (!endsWithModelItem) {
    history.push({ type: 'model', response: [] })
  }

  const call: ChatModelFunctionCall = {
    type: 'functionCall',
    name: functionName,
    description: functionDescription,
    params: callParams,
    result: callResult,
    rawCall,
  }
  // The flag is only set when requested, so the key is absent otherwise.
  if (startsNewChunk) {
    call.startsNewChunk = true
  }

  // Swap in a copy of the last model item whose response list ends with the call.
  const lastIndex = history.length - 1
  const previous = history[lastIndex] as ChatModelResponse
  history[lastIndex] = { ...previous, response: [...previous.response, call] }

  return history
}
75
+
76
/**
 * Converts this package's chat messages into node-llama-cpp chat history items.
 *
 * - `user` messages become `user` items with their content flattened to text.
 * - `assistant` messages become `model` items whose response holds the
 *   message content as its single entry.
 * - `system` messages are merged, separated by a blank line, into one
 *   `system` item that is placed at the front of the result.
 * - Messages with any other role are ignored.
 */
export function createChatMessageArray(
  messages: ChatMessage[],
): ChatHistoryItem[] {
  const history: ChatHistoryItem[] = []
  let systemText: string | undefined

  for (const message of messages) {
    switch (message.role) {
      case 'user':
        history.push({
          type: 'user',
          text: flattenMessageTextContent(message.content),
        })
        break
      case 'assistant':
        history.push({
          type: 'model',
          response: [message.content],
        })
        break
      case 'system': {
        const text = flattenMessageTextContent(message.content)
        // A still-empty prompt is replaced rather than extended.
        systemText = systemText ? systemText + '\n\n' + text : text
        break
      }
    }
  }

  if (!systemText) {
    return history
  }
  return [{ type: 'system', text: systemText }, ...history]
}
110
+
111
/**
 * Reads every `*.gbnf` file directly inside `directoryPath`.
 *
 * @param directoryPath - Directory to scan (not recursive).
 * @returns A map from file name without the trailing `.gbnf` extension to
 * the file's UTF-8 text.
 * @throws Rejects with the underlying fs error when the directory or one of
 * the files cannot be read.
 */
export async function readGBNFFiles(directoryPath: string): Promise<Record<string, string>> {
  const extension = '.gbnf'
  // Non-blocking directory listing; the function is async anyway.
  const entries = await fs.promises.readdir(directoryPath)
  const grammarFiles = entries.filter((entry) => entry.endsWith(extension))

  const namedGrammars = await Promise.all(
    grammarFiles.map(async (file): Promise<[string, string]> => {
      const text = await fs.promises.readFile(path.join(directoryPath, file), 'utf-8')
      // Strip only the trailing extension; `.gbnf` may also occur earlier in the name.
      return [file.slice(0, -extension.length), text]
    }),
  )
  return Object.fromEntries(namedGrammars)
}
125
+
126
+
@@ -0,0 +1,46 @@
1
+
2
+
3
+ import fs from 'node:fs'
4
+ import { calculateFileChecksum } from '#package/lib/calculateFileChecksum.js'
5
+ import { resolveModelFileLocation } from '#package/lib/resolveModelFileLocation.js'
6
+
7
/** The subset of model configuration that `validateModelFile` reads. */
interface ValidatableModelConfig {
  /** Forwarded to `resolveModelFileLocation` as `url`. */
  url?: string
  /** Forwarded to `resolveModelFileLocation` as `filePath`. */
  location?: string
  /** Forwarded to `resolveModelFileLocation` as `modelsCachePath`. */
  modelsCachePath: string
  /** Expected sha256 of the model file; the checksum is only verified when set. */
  sha256?: string
}
13
+
14
/**
 * Checks that the model file described by `config` exists and looks complete.
 *
 * A sibling `<file>.ipull` marker is treated as a sign of an unfinished
 * download: the file is only accepted if `sha256` is configured and matches,
 * in which case the marker is deleted. Without a marker, the checksum is
 * compared only when `sha256` is configured.
 *
 * @returns A description of the problem, or `undefined` when the file is valid.
 */
export async function validateModelFile(config: ValidatableModelConfig): Promise<string | undefined> {
  const { url, location, modelsCachePath, sha256 } = config
  const fileLocation = resolveModelFileLocation({
    url,
    filePath: location,
    modelsCachePath,
  })
  if (!fs.existsSync(fileLocation)) {
    return `Model file missing at ${fileLocation}`
  }

  const ipullFile = fileLocation + '.ipull'
  const hasDownloadMarker = fs.existsSync(ipullFile)

  if (!sha256) {
    // Nothing to verify against: a leftover marker cannot be cleared.
    return hasDownloadMarker ? `Model file with incomplete download` : undefined
  }

  const fileHash = await calculateFileChecksum(fileLocation, 'sha256')
  if (hasDownloadMarker) {
    if (fileHash !== sha256) {
      return `Model file with incomplete download`
    }
    // if we have a valid file at the download destination, we can remove the ipull file
    fs.unlinkSync(ipullFile)
    return undefined
  }

  if (fileHash !== sha256) {
    return `File sha256 checksum mismatch: expected ${sha256} got ${fileHash} at ${fileLocation}`
  }
  return undefined
}
@@ -0,0 +1,369 @@
1
+ import StableDiffusion from '@lmagder/node-stable-diffusion-cpp'
2
+ import { gguf } from '@huggingface/gguf'
3
+ import sharp from 'sharp'
4
+ import fs from 'node:fs'
5
+ import path from 'node:path'
6
+ import {
7
+ EngineContext,
8
+ FileDownloadProgress,
9
+ ModelConfig,
10
+ EngineTextToImageResult,
11
+ ModelFileSource,
12
+ EngineTextToImageArgs,
13
+ Image,
14
+ EngineImageToImageArgs,
15
+ } from '#package/types/index.js'
16
+ import { LogLevel, LogLevels } from '#package/lib/logger.js'
17
+ import { downloadModelFile } from '#package/lib/downloadModelFile.js'
18
+ import { resolveModelFileLocation } from '#package/lib/resolveModelFileLocation.js'
19
+ import { acquireFileLock } from '#package/lib/acquireFileLock.js'
20
+ import { getRandomNumber } from '#package/lib/util.js'
21
+ import { StableDiffusionSamplingMethod, StableDiffusionSchedule, StableDiffusionWeightType } from './types.js'
22
+ import { validateModelFiles, ModelValidationResult } from './validateModelFiles.js'
23
+ import { parseQuantization, getWeightType, getSamplingMethod } from './util.js'
24
+
25
/** Engine instance handle: wraps the context returned by `StableDiffusion.createContext`. */
export interface StableDiffusionInstance {
  /** Context on which `txt2img` / `img2img` are invoked. */
  context: StableDiffusion.Context
}
28
+
29
/**
 * Model configuration accepted by the stable-diffusion-cpp engine.
 *
 * Each optional component source (`clipL`, `clipG`, `vae`, `t5xxl`,
 * `controlNet`, `taesd`) is resolved to a file path and handed to the context
 * under the same name when the instance is created.
 */
export interface StableDiffusionModelConfig extends ModelConfig {
  /** Path of the main model file. */
  location: string
  /** Expected sha256 — presumably of the file at `location`; verify against validateModelFiles. */
  sha256?: string
  clipL?: ModelFileSource
  clipG?: ModelFileSource
  vae?: ModelFileSource
  t5xxl?: ModelFileSource
  controlNet?: ModelFileSource
  taesd?: ModelFileSource
  /** When true, `location` is passed to the context as `diffusionModel` instead of `model`. */
  diffusionModel?: boolean
  model?: ModelFileSource
  /** LoRA sources; those with a URL are downloaded during model preparation. */
  loras?: ModelFileSource[]
  /** Default sampling method, used when a request does not name one. */
  samplingMethod?: StableDiffusionSamplingMethod
  /** Explicit weight type; when it cannot be resolved, the file name of `location` is parsed instead. */
  weightType?: StableDiffusionWeightType
  schedule?: StableDiffusionSchedule
  device?: {
    gpu?: boolean | 'auto' | (string & {})
    /** Forwarded to the context as `numThreads`. */
    cpuThreads?: number
  }
}
49
+
50
/** Metadata produced while preparing a model. */
interface StableDiffusionModelMeta {
  /** GGUF metadata read from the model file; only populated for `.gguf` locations. */
  gguf: any
}
53
+
54
/** Engine-level flag — presumably tells the host that this engine picks its GPU automatically; confirm against the engine loader. */
export const autoGpu = true
55
+
56
/**
 * Makes sure all files of a stable-diffusion model are present and valid,
 * downloading whatever validation reports as broken or missing.
 *
 * A file lock on `config.location` is held for the duration of the call and
 * is always released, including on abort and on error.
 *
 * @param onProgress - Forwarded to every file download.
 * @param signal - When already aborted at one of the checkpoints, the
 * function returns `undefined` without further work; also forwarded to downloads.
 * @returns An object carrying `gguf` metadata when `config.location` ends in
 * `.gguf` (an empty object otherwise), or `undefined` when aborted.
 * @throws When validation fails and `config.url` is not set, or when the
 * files are still invalid after downloading.
 */
export async function prepareModel(
  { config, log }: EngineContext<StableDiffusionModelConfig, StableDiffusionModelMeta>,
  onProgress?: (progress: FileDownloadProgress) => void,
  signal?: AbortSignal,
) {
  fs.mkdirSync(path.dirname(config.location), { recursive: true })
  const releaseFileLock = await acquireFileLock(config.location)
  if (signal?.aborted) {
    releaseFileLock()
    return
  }
  log(LogLevels.info, `Preparing stable-diffusion model at ${config.location}`, {
    model: config.id,
  })

  // Starts downloads for every file that validation flagged and resolves once all finished.
  const downloadModel = (url: string, validationResult: ModelValidationResult) => {
    log(LogLevels.info, `${validationResult.message} - Downloading model files`, {
      model: config.id,
      url: config.url,
      location: config.location,
      errors: validationResult.errors,
    })
    const downloadPromises: Array<ReturnType<typeof downloadModelFile>> = []
    if (validationResult.errors.model && config.location) {
      downloadPromises.push(
        downloadModelFile({
          url: url,
          filePath: config.location,
          modelsCachePath: config.modelsCachePath,
          onProgress,
          signal,
        }),
      )
    }
    // Sources without a URL cannot be fetched and are skipped.
    const pushDownload = (src: ModelFileSource) => {
      if (!src.url) {
        return
      }
      downloadPromises.push(
        downloadModelFile({
          url: src.url,
          filePath: src.file,
          modelsCachePath: config.modelsCachePath,
          onProgress,
          signal,
        }),
      )
    }
    if (validationResult.errors.clipG && config.clipG) {
      pushDownload(config.clipG)
    }
    if (validationResult.errors.clipL && config.clipL) {
      pushDownload(config.clipL)
    }
    if (validationResult.errors.vae && config.vae) {
      pushDownload(config.vae)
    }
    if (validationResult.errors.t5xxl && config.t5xxl) {
      pushDownload(config.t5xxl)
    }
    if (validationResult.errors.controlNet && config.controlNet) {
      pushDownload(config.controlNet)
    }
    if (validationResult.errors.taesd && config.taesd) {
      pushDownload(config.taesd)
    }
    // LoRAs are requested whenever a download pass runs, independent of validation errors.
    for (const lora of config.loras ?? []) {
      pushDownload(lora)
    }
    return Promise.all(downloadPromises)
  }

  try {
    if (signal?.aborted) {
      return
    }

    const validationResults = await validateModelFiles(config)
    if (signal?.aborted) {
      return
    }
    if (validationResults) {
      if (config.url) {
        await downloadModel(config.url, validationResults)
      } else {
        throw new Error(`${validationResults.message} - No URL provided`)
      }
    }

    const finalValidationError = await validateModelFiles(config)
    if (finalValidationError) {
      // The validation result is an object; interpolating it directly would print "[object Object]".
      throw new Error(`Downloaded files are invalid: ${finalValidationError.message}`)
    }

    const result: any = {}
    if (config.location.endsWith('.gguf')) {
      const { metadata } = await gguf(config.location, {
        allowLocalFile: true,
      })
      result.gguf = metadata
    }
    return result
  } finally {
    releaseFileLock()
  }
}
168
+
169
/**
 * Creates a stable-diffusion context for the configured model.
 *
 * Component sources are resolved to file paths, the weight type is taken from
 * `config.weightType` or, failing that, parsed from the model file name, and
 * LoRAs are expected in a `loras` directory next to the model file.
 *
 * @param signal - Accepted for interface parity; not read by this function.
 * @returns An object holding the created `context`.
 */
export async function createInstance({ config, log }: EngineContext<StableDiffusionModelConfig>, signal?: AbortSignal) {
  log(LogLevels.debug, 'Load Stable Diffusion model', config)

  // Maps an optional component source to its resolved file path.
  const resolveComponentLocation = (src?: ModelFileSource) =>
    src
      ? resolveModelFileLocation({
          url: src.url,
          filePath: src.file,
          modelsCachePath: config.modelsCachePath,
        })
      : undefined

  const vaeFilePath = resolveComponentLocation(config.vae)
  const clipLFilePath = resolveComponentLocation(config.clipL)
  const clipGFilePath = resolveComponentLocation(config.clipG)
  const t5xxlFilePath = resolveComponentLocation(config.t5xxl)
  const controlNetFilePath = resolveComponentLocation(config.controlNet)
  const taesdFilePath = resolveComponentLocation(config.taesd)

  // Configured weight type wins; otherwise try the quantization tag in the file name.
  const pickWeightType = () => {
    const configured = config.weightType ? getWeightType(config.weightType) : undefined
    if (typeof configured !== 'undefined') {
      return configured
    }
    const quantization = parseQuantization(config.location)
    return quantization ? getWeightType(quantization) : undefined
  }
  const weightType = pickWeightType()
  if (typeof weightType === 'undefined') {
    log(LogLevels.warn, 'Failed to parse model weight type (quantization) from file name, falling back to f32', {
      file: config.location,
    })
  }

  const usesDiffusionModel = Boolean(config.diffusionModel)
  const contextParams = {
    model: usesDiffusionModel ? undefined : config.location,
    diffusionModel: usesDiffusionModel ? config.location : undefined,
    numThreads: config.device?.cpuThreads,
    vae: vaeFilePath,
    clipL: clipLFilePath,
    clipG: clipGFilePath,
    t5xxl: t5xxlFilePath,
    controlNet: controlNetFilePath,
    taesd: taesdFilePath,
    weightType: weightType,
    loraDir: path.join(path.dirname(config.location), 'loras'),
    // TODO how to expose?
    // keepClipOnCpu: true,
    // keepControlNetOnCpu: true,
    // keepVaeOnCpu: true,
  }
  log(LogLevels.debug, 'Creating context with', contextParams)

  // Relay native log lines and step progress through the engine logger.
  const forwardLog = (level: string, message: string) => {
    log(level as LogLevel, message)
  }
  const reportProgress = (step: number, steps: number, time: number) => {
    log(LogLevels.debug, `Progress: ${step}/${steps} (${time}ms)`)
  }

  const context = await StableDiffusion.createContext(
    // @ts-ignore
    contextParams,
    forwardLog,
    reportProgress,
  )

  return { context }
}
240
+
241
/**
 * Runs a text-to-image request on the given instance.
 *
 * Width and height default to 512 when missing or zero; the seed defaults to
 * a random number from `getRandomNumber(0, 1000000)`. The sampling method
 * comes from the request, falling back to the model config.
 *
 * @returns The generated images, wrapped as raw `sharp` handles, and the seed used.
 * @throws When the context returns no images.
 */
export async function processTextToImageTask(
  { request, config, log }: EngineTextToImageArgs<StableDiffusionModelConfig>,
  instance: StableDiffusionInstance,
  signal?: AbortSignal,
): Promise<EngineTextToImageResult> {
  const seed = request.seed ?? getRandomNumber(0, 1000000)
  const generated = await instance.context.txt2img({
    prompt: request.prompt,
    negativePrompt: request.negativePrompt,
    width: request.width || 512,
    height: request.height || 512,
    batchCount: request.batchCount,
    sampleMethod: getSamplingMethod(request.samplingMethod || config.samplingMethod),
    sampleSteps: request.sampleSteps,
    cfgScale: request.cfgScale,
    // @ts-ignore
    guidance: request.guidance,
    styleRatio: request.styleRatio,
    controlStrength: request.controlStrength,
    normalizeInput: false,
    seed,
  })

  // Wrap each raw pixel buffer in a sharp handle, keeping its dimensions alongside.
  const images: Image[] = Array.from(generated.values(), (img) => {
    const { width, height, channel: channels } = img
    return {
      handle: sharp(img.data, { raw: { width, height, channels } }),
      width,
      height,
      channels,
    }
  })
  if (images.length === 0) {
    throw new Error('No images generated')
  }
  return { images, seed }
}
287
+
288
/**
 * Runs an image-to-image request on the given instance.
 *
 * The request image is converted to a raw pixel buffer and passed as the
 * init image. Width and height default to 512 when missing or zero; the seed
 * defaults to a random number from `getRandomNumber(0, 1000000)`. The
 * sampling method comes from the request, falling back to the model config.
 *
 * Diagnostic output goes through the engine `log` at debug level, matching
 * the other functions in this file.
 *
 * @returns The generated images, wrapped as raw `sharp` handles, and the seed used.
 * @throws When the context returns no images.
 */
export async function processImageToImageTask(
  { request, config, log }: EngineImageToImageArgs<StableDiffusionModelConfig>,
  instance: StableDiffusionInstance,
  signal?: AbortSignal,
): Promise<EngineTextToImageResult> {
  const seed = request.seed ?? getRandomNumber(0, 1000000)
  log(LogLevels.debug, 'processImageToImageTask', {
    width: request.image.width,
    height: request.image.height,
    channel: request.image.channels as 3 | 4,
  })
  const initImage = {
    data: await request.image.handle.raw().toBuffer(),
    width: request.image.width,
    height: request.image.height,
    channel: request.image.channels as 3 | 4,
  }
  const results = await instance.context.img2img({
    initImage,
    prompt: request.prompt,
    width: request.width || 512,
    height: request.height || 512,
    batchCount: request.batchCount,
    sampleMethod: getSamplingMethod(request.samplingMethod || config.samplingMethod),
    cfgScale: request.cfgScale,
    sampleSteps: request.sampleSteps,
    // @ts-ignore
    guidance: request.guidance,
    strength: request.strength,
    styleRatio: request.styleRatio,
    controlStrength: request.controlStrength,
    seed,
  })

  const images: Image[] = []
  for (const [idx, img] of results.entries()) {
    log(LogLevels.debug, 'img', {
      id: idx,
      width: img.width,
      height: img.height,
      channels: img.channel,
    })

    // Wrap the raw pixel buffer in a sharp handle, keeping its dimensions alongside.
    images.push({
      handle: sharp(img.data, {
        raw: {
          width: img.width,
          height: img.height,
          channels: img.channel,
        },
      }),
      width: img.width,
      height: img.height,
      channels: img.channel,
    })
  }
  if (!images.length) {
    throw new Error('No images generated')
  }
  return {
    images: images,
    seed,
  }
}
@@ -0,0 +1,54 @@
1
/**
 * Weight type names accepted in the model config. Known spellings are listed
 * for editor completion; any other string is also allowed by the type.
 */
export type StableDiffusionWeightType =
  // f* / bf16 spellings
  | 'f16' | 'f32' | 'f64' | 'bf16'
  // q<N>_<N> spellings
  | 'q4_0' | 'q4_1' | 'q5_0' | 'q5_1' | 'q8_0'
  // q<N>_k spellings
  | 'q2_k' | 'q3_k' | 'q4_k' | 'q5_k' | 'q6_k' | 'q8_k'
  // iq* spellings
  | 'iq1_s' | 'iq1_m' | 'iq2_xxs' | 'iq2_xs' | 'iq2_s'
  | 'iq3_xxs' | 'iq3_s' | 'iq4_nl' | 'iq4_xs'
  // i<N> spellings
  | 'i8' | 'i16' | 'i32' | 'i64'
  // q4_0_<N>_<N> spellings
  | 'q4_0_4_4' | 'q4_0_4_8' | 'q4_0_8_8'
  // keeps the union open to arbitrary strings without losing completions
  | (string & {})
34
+
35
/**
 * Schedule names accepted in the model config. Known spellings are listed for
 * editor completion; any other string is also allowed by the type.
 */
export type StableDiffusionSchedule =
  | 'discrete' | 'karras' | 'exponential' | 'ays' | 'gits'
  // keeps the union open to arbitrary strings without losing completions
  | (string & {})
42
+
43
/**
 * Sampling method names accepted in requests and the model config. Known
 * spellings are listed for editor completion; any other string is also
 * allowed by the type.
 */
export type StableDiffusionSamplingMethod =
  | 'euler' | 'euler_a'
  | 'lcm' | 'heun'
  | 'dpm2' | 'dpm++2s_a' | 'dpm++2m' | 'dpm++2mv2'
  | 'ipndm' | 'ipndm_v'
  // keeps the union open to arbitrary strings without losing completions
  | (string & {})
@@ -0,0 +1,58 @@
1
+ import StableDiffusion from '@lmagder/node-stable-diffusion-cpp'
2
+
3
/**
 * Extracts a quantization / precision tag from a model file name.
 *
 * Recognizes `q<N>_<N>` anywhere in the name, and `f16`, `f32`, `int8`,
 * `int4`, `fp16`, `fp32` when preceded by `-`, `_` or `.`. The first pattern
 * that matches wins.
 *
 * @param filename - File name (or path) to inspect.
 * @returns The tag in lower case without its leading separator and with
 * `fp` shortened to `f` (e.g. `q4_0`, `f16`), or `null` when nothing matches.
 */
export function parseQuantization(filename: string): string | null {
  // Regular expressions to match different quantization patterns
  const regexPatterns = [
    /q(\d+)_(\d+)/i, // q4_0
    /[-_\.](f16|f32|int8|int4)/i, // f16, f32, int8, int4
    /[-_\.](fp16|fp32)/i, // fp16, fp32
  ]

  for (const regex of regexPatterns) {
    const match = filename.match(regex)
    if (match) {
      // Strip whichever separator the pattern consumed (including '.'),
      // normalize fp16/fp32 to f16/f32 and lower-case the result.
      return match[0].replace(/^[-_.]/, '').replace(/fp/i, 'f').toLowerCase()
    }
  }
  return null
}
21
+
22
/**
 * Looks up a weight type by name in `StableDiffusion.Type`.
 *
 * @param key - Weight type name; upper-cased before the lookup.
 * @returns The matching `StableDiffusion.Type` entry, or `undefined` (after a
 * console warning) when the name is not a key of `StableDiffusion.Type`.
 */
export function getWeightType(key: string): number | undefined {
  const weightKey = key.toUpperCase() as keyof typeof StableDiffusion.Type
  if (!(weightKey in StableDiffusion.Type)) {
    console.warn('Unknown weight type', weightKey)
    return undefined
  }
  return StableDiffusion.Type[weightKey]
}
30
+
31
/**
 * Maps a sampling method name to the corresponding
 * `StableDiffusion.SampleMethod` member.
 *
 * @param method - One of the known method names.
 * @returns The mapped member, or `undefined` (after a console warning) when
 * `method` is missing or not one of the known names.
 */
export function getSamplingMethod(method?: string): StableDiffusion.SampleMethod | undefined {
  const samplersByName = new Map<string, StableDiffusion.SampleMethod | undefined>([
    ['euler', StableDiffusion.SampleMethod.Euler],
    ['euler_a', StableDiffusion.SampleMethod.EulerA],
    ['lcm', StableDiffusion.SampleMethod.LCM],
    ['heun', StableDiffusion.SampleMethod.Heun],
    ['dpm2', StableDiffusion.SampleMethod.DPM2],
    ['dpm++2s_a', StableDiffusion.SampleMethod.DPMPP2SA],
    ['dpm++2m', StableDiffusion.SampleMethod.DPMPP2M],
    ['dpm++2mv2', StableDiffusion.SampleMethod.DPMPP2Mv2],
    // @ts-ignore
    ['ipndm', StableDiffusion.SampleMethod.IPNDM],
    // @ts-ignore
    ['ipndm_v', StableDiffusion.SampleMethod.IPNDMV],
  ])

  // Decide on the name, not the mapped value, so a known name never triggers the warning.
  if (method !== undefined && samplersByName.has(method)) {
    return samplersByName.get(method)
  }
  console.warn('Unknown sampling method', method)
  return undefined
}