inference-server 1.0.0-beta.19
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +216 -0
- package/dist/api/openai/enums.d.ts +4 -0
- package/dist/api/openai/enums.js +17 -0
- package/dist/api/openai/enums.js.map +1 -0
- package/dist/api/openai/handlers/chat.d.ts +3 -0
- package/dist/api/openai/handlers/chat.js +358 -0
- package/dist/api/openai/handlers/chat.js.map +1 -0
- package/dist/api/openai/handlers/completions.d.ts +3 -0
- package/dist/api/openai/handlers/completions.js +169 -0
- package/dist/api/openai/handlers/completions.js.map +1 -0
- package/dist/api/openai/handlers/embeddings.d.ts +3 -0
- package/dist/api/openai/handlers/embeddings.js +74 -0
- package/dist/api/openai/handlers/embeddings.js.map +1 -0
- package/dist/api/openai/handlers/images.d.ts +0 -0
- package/dist/api/openai/handlers/images.js +4 -0
- package/dist/api/openai/handlers/images.js.map +1 -0
- package/dist/api/openai/handlers/models.d.ts +3 -0
- package/dist/api/openai/handlers/models.js +23 -0
- package/dist/api/openai/handlers/models.js.map +1 -0
- package/dist/api/openai/handlers/transcription.d.ts +0 -0
- package/dist/api/openai/handlers/transcription.js +4 -0
- package/dist/api/openai/handlers/transcription.js.map +1 -0
- package/dist/api/openai/index.d.ts +7 -0
- package/dist/api/openai/index.js +14 -0
- package/dist/api/openai/index.js.map +1 -0
- package/dist/api/parseJSONRequestBody.d.ts +2 -0
- package/dist/api/parseJSONRequestBody.js +24 -0
- package/dist/api/parseJSONRequestBody.js.map +1 -0
- package/dist/api/v1/index.d.ts +2 -0
- package/dist/api/v1/index.js +29 -0
- package/dist/api/v1/index.js.map +1 -0
- package/dist/cli.d.ts +1 -0
- package/dist/cli.js +10 -0
- package/dist/cli.js.map +1 -0
- package/dist/engines/gpt4all/engine.d.ts +34 -0
- package/dist/engines/gpt4all/engine.js +357 -0
- package/dist/engines/gpt4all/engine.js.map +1 -0
- package/dist/engines/gpt4all/util.d.ts +3 -0
- package/dist/engines/gpt4all/util.js +29 -0
- package/dist/engines/gpt4all/util.js.map +1 -0
- package/dist/engines/index.d.ts +19 -0
- package/dist/engines/index.js +21 -0
- package/dist/engines/index.js.map +1 -0
- package/dist/engines/node-llama-cpp/engine.d.ts +49 -0
- package/dist/engines/node-llama-cpp/engine.js +666 -0
- package/dist/engines/node-llama-cpp/engine.js.map +1 -0
- package/dist/engines/node-llama-cpp/types.d.ts +13 -0
- package/dist/engines/node-llama-cpp/types.js +2 -0
- package/dist/engines/node-llama-cpp/types.js.map +1 -0
- package/dist/engines/node-llama-cpp/util.d.ts +15 -0
- package/dist/engines/node-llama-cpp/util.js +84 -0
- package/dist/engines/node-llama-cpp/util.js.map +1 -0
- package/dist/engines/node-llama-cpp/validateModelFile.d.ts +8 -0
- package/dist/engines/node-llama-cpp/validateModelFile.js +36 -0
- package/dist/engines/node-llama-cpp/validateModelFile.js.map +1 -0
- package/dist/engines/stable-diffusion-cpp/engine.d.ts +90 -0
- package/dist/engines/stable-diffusion-cpp/engine.js +294 -0
- package/dist/engines/stable-diffusion-cpp/engine.js.map +1 -0
- package/dist/engines/stable-diffusion-cpp/types.d.ts +3 -0
- package/dist/engines/stable-diffusion-cpp/types.js +2 -0
- package/dist/engines/stable-diffusion-cpp/types.js.map +1 -0
- package/dist/engines/stable-diffusion-cpp/util.d.ts +4 -0
- package/dist/engines/stable-diffusion-cpp/util.js +55 -0
- package/dist/engines/stable-diffusion-cpp/util.js.map +1 -0
- package/dist/engines/stable-diffusion-cpp/validateModelFiles.d.ts +19 -0
- package/dist/engines/stable-diffusion-cpp/validateModelFiles.js +91 -0
- package/dist/engines/stable-diffusion-cpp/validateModelFiles.js.map +1 -0
- package/dist/engines/transformers-js/engine.d.ts +37 -0
- package/dist/engines/transformers-js/engine.js +538 -0
- package/dist/engines/transformers-js/engine.js.map +1 -0
- package/dist/engines/transformers-js/types.d.ts +7 -0
- package/dist/engines/transformers-js/types.js +2 -0
- package/dist/engines/transformers-js/types.js.map +1 -0
- package/dist/engines/transformers-js/util.d.ts +7 -0
- package/dist/engines/transformers-js/util.js +36 -0
- package/dist/engines/transformers-js/util.js.map +1 -0
- package/dist/engines/transformers-js/validateModelFiles.d.ts +17 -0
- package/dist/engines/transformers-js/validateModelFiles.js +133 -0
- package/dist/engines/transformers-js/validateModelFiles.js.map +1 -0
- package/dist/experiments/ChatWithVision.d.ts +11 -0
- package/dist/experiments/ChatWithVision.js +91 -0
- package/dist/experiments/ChatWithVision.js.map +1 -0
- package/dist/experiments/StableDiffPromptGenerator.d.ts +0 -0
- package/dist/experiments/StableDiffPromptGenerator.js +4 -0
- package/dist/experiments/StableDiffPromptGenerator.js.map +1 -0
- package/dist/experiments/VoiceFunctionCall.d.ts +18 -0
- package/dist/experiments/VoiceFunctionCall.js +51 -0
- package/dist/experiments/VoiceFunctionCall.js.map +1 -0
- package/dist/http.d.ts +19 -0
- package/dist/http.js +54 -0
- package/dist/http.js.map +1 -0
- package/dist/index.d.ts +7 -0
- package/dist/index.js +8 -0
- package/dist/index.js.map +1 -0
- package/dist/instance.d.ts +88 -0
- package/dist/instance.js +594 -0
- package/dist/instance.js.map +1 -0
- package/dist/lib/acquireFileLock.d.ts +7 -0
- package/dist/lib/acquireFileLock.js +38 -0
- package/dist/lib/acquireFileLock.js.map +1 -0
- package/dist/lib/calculateContextIdentity.d.ts +7 -0
- package/dist/lib/calculateContextIdentity.js +39 -0
- package/dist/lib/calculateContextIdentity.js.map +1 -0
- package/dist/lib/calculateFileChecksum.d.ts +1 -0
- package/dist/lib/calculateFileChecksum.js +16 -0
- package/dist/lib/calculateFileChecksum.js.map +1 -0
- package/dist/lib/copyDirectory.d.ts +6 -0
- package/dist/lib/copyDirectory.js +27 -0
- package/dist/lib/copyDirectory.js.map +1 -0
- package/dist/lib/decodeAudio.d.ts +1 -0
- package/dist/lib/decodeAudio.js +26 -0
- package/dist/lib/decodeAudio.js.map +1 -0
- package/dist/lib/downloadModelFile.d.ts +10 -0
- package/dist/lib/downloadModelFile.js +58 -0
- package/dist/lib/downloadModelFile.js.map +1 -0
- package/dist/lib/flattenMessageTextContent.d.ts +2 -0
- package/dist/lib/flattenMessageTextContent.js +11 -0
- package/dist/lib/flattenMessageTextContent.js.map +1 -0
- package/dist/lib/getCacheDirPath.d.ts +12 -0
- package/dist/lib/getCacheDirPath.js +31 -0
- package/dist/lib/getCacheDirPath.js.map +1 -0
- package/dist/lib/loadImage.d.ts +12 -0
- package/dist/lib/loadImage.js +30 -0
- package/dist/lib/loadImage.js.map +1 -0
- package/dist/lib/logger.d.ts +12 -0
- package/dist/lib/logger.js +98 -0
- package/dist/lib/logger.js.map +1 -0
- package/dist/lib/math.d.ts +7 -0
- package/dist/lib/math.js +30 -0
- package/dist/lib/math.js.map +1 -0
- package/dist/lib/resolveModelFileLocation.d.ts +15 -0
- package/dist/lib/resolveModelFileLocation.js +41 -0
- package/dist/lib/resolveModelFileLocation.js.map +1 -0
- package/dist/lib/util.d.ts +7 -0
- package/dist/lib/util.js +61 -0
- package/dist/lib/util.js.map +1 -0
- package/dist/lib/validateModelFile.d.ts +9 -0
- package/dist/lib/validateModelFile.js +62 -0
- package/dist/lib/validateModelFile.js.map +1 -0
- package/dist/lib/validateModelOptions.d.ts +3 -0
- package/dist/lib/validateModelOptions.js +23 -0
- package/dist/lib/validateModelOptions.js.map +1 -0
- package/dist/pool.d.ts +61 -0
- package/dist/pool.js +512 -0
- package/dist/pool.js.map +1 -0
- package/dist/server.d.ts +59 -0
- package/dist/server.js +221 -0
- package/dist/server.js.map +1 -0
- package/dist/standalone.d.ts +1 -0
- package/dist/standalone.js +306 -0
- package/dist/standalone.js.map +1 -0
- package/dist/store.d.ts +60 -0
- package/dist/store.js +203 -0
- package/dist/store.js.map +1 -0
- package/dist/types/completions.d.ts +57 -0
- package/dist/types/completions.js +2 -0
- package/dist/types/completions.js.map +1 -0
- package/dist/types/index.d.ts +326 -0
- package/dist/types/index.js +2 -0
- package/dist/types/index.js.map +1 -0
- package/docs/engines.md +28 -0
- package/docs/gpu.md +72 -0
- package/docs/http-api.md +147 -0
- package/examples/all-options.js +108 -0
- package/examples/chat-cli.js +56 -0
- package/examples/chat-server.js +65 -0
- package/examples/concurrency.js +70 -0
- package/examples/express.js +70 -0
- package/examples/pool.js +91 -0
- package/package.json +113 -0
- package/src/api/openai/enums.ts +20 -0
- package/src/api/openai/handlers/chat.ts +408 -0
- package/src/api/openai/handlers/completions.ts +196 -0
- package/src/api/openai/handlers/embeddings.ts +92 -0
- package/src/api/openai/handlers/images.ts +3 -0
- package/src/api/openai/handlers/models.ts +33 -0
- package/src/api/openai/handlers/transcription.ts +2 -0
- package/src/api/openai/index.ts +16 -0
- package/src/api/parseJSONRequestBody.ts +26 -0
- package/src/api/v1/DRAFT.md +16 -0
- package/src/api/v1/index.ts +37 -0
- package/src/cli.ts +9 -0
- package/src/engines/gpt4all/engine.ts +441 -0
- package/src/engines/gpt4all/util.ts +31 -0
- package/src/engines/index.ts +28 -0
- package/src/engines/node-llama-cpp/engine.ts +811 -0
- package/src/engines/node-llama-cpp/types.ts +17 -0
- package/src/engines/node-llama-cpp/util.ts +126 -0
- package/src/engines/node-llama-cpp/validateModelFile.ts +46 -0
- package/src/engines/stable-diffusion-cpp/engine.ts +369 -0
- package/src/engines/stable-diffusion-cpp/types.ts +54 -0
- package/src/engines/stable-diffusion-cpp/util.ts +58 -0
- package/src/engines/stable-diffusion-cpp/validateModelFiles.ts +119 -0
- package/src/engines/transformers-js/engine.ts +659 -0
- package/src/engines/transformers-js/types.ts +25 -0
- package/src/engines/transformers-js/util.ts +40 -0
- package/src/engines/transformers-js/validateModelFiles.ts +168 -0
- package/src/experiments/ChatWithVision.ts +103 -0
- package/src/experiments/StableDiffPromptGenerator.ts +2 -0
- package/src/experiments/VoiceFunctionCall.ts +71 -0
- package/src/http.ts +72 -0
- package/src/index.ts +7 -0
- package/src/instance.ts +723 -0
- package/src/lib/acquireFileLock.ts +38 -0
- package/src/lib/calculateContextIdentity.ts +53 -0
- package/src/lib/calculateFileChecksum.ts +18 -0
- package/src/lib/copyDirectory.ts +29 -0
- package/src/lib/decodeAudio.ts +39 -0
- package/src/lib/downloadModelFile.ts +70 -0
- package/src/lib/flattenMessageTextContent.ts +19 -0
- package/src/lib/getCacheDirPath.ts +34 -0
- package/src/lib/loadImage.ts +46 -0
- package/src/lib/logger.ts +112 -0
- package/src/lib/math.ts +31 -0
- package/src/lib/resolveModelFileLocation.ts +49 -0
- package/src/lib/util.ts +75 -0
- package/src/lib/validateModelFile.ts +71 -0
- package/src/lib/validateModelOptions.ts +31 -0
- package/src/pool.ts +651 -0
- package/src/server.ts +270 -0
- package/src/standalone.ts +320 -0
- package/src/store.ts +278 -0
- package/src/types/completions.ts +86 -0
- package/src/types/index.ts +488 -0
- package/tsconfig.json +29 -0
- package/tsconfig.release.json +11 -0
- package/vitest.config.ts +18 -0
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import {
|
|
2
|
+
ChatHistoryItem,
|
|
3
|
+
ChatModelFunctions,
|
|
4
|
+
LlamaChatResponse,
|
|
5
|
+
LlamaChatResponseFunctionCall,
|
|
6
|
+
} from 'node-llama-cpp'
|
|
7
|
+
|
|
8
|
+
/**
 * Normalized outcome of a node-llama-cpp chat generation.
 *
 * `T` is the function definition map the generation ran with; it types the
 * entries of `functionCalls`.
 */
export interface LlamaChatResult<T extends ChatModelFunctions = any> {
  /** Why generation ended, as reported in node-llama-cpp's response metadata. */
  stopReason: LlamaChatResponse['metadata']['stopReason']
  /** Generated text, or `null`. */
  responseText: string | null
  /** Function calls emitted by the model, if any. */
  functionCalls?: LlamaChatResponseFunctionCall<T>[]
}

/** State passed into, and returned from, a context shift strategy. */
interface ContextShiftState {
  chatHistory: ChatHistoryItem[]
  metadata: any
}

/**
 * Custom context-shift hook: receives the current chat history plus metadata
 * and returns the replacement history plus metadata.
 * `null` means no custom strategy is supplied (TODO confirm how the engine
 * treats it).
 */
export type ContextShiftStrategy = ((options: ContextShiftState) => ContextShiftState) | null
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
import fs from 'node:fs'
|
|
2
|
+
import path from 'node:path'
|
|
3
|
+
import {
|
|
4
|
+
ChatHistoryItem,
|
|
5
|
+
ChatModelResponse,
|
|
6
|
+
LlamaTextJSON,
|
|
7
|
+
ChatModelFunctionCall,
|
|
8
|
+
} from 'node-llama-cpp'
|
|
9
|
+
import { CompletionFinishReason, ChatMessage } from '#package/types/index.js'
|
|
10
|
+
import { flattenMessageTextContent } from '#package/lib/flattenMessageTextContent.js'
|
|
11
|
+
import { LlamaChatResult } from './types.js'
|
|
12
|
+
|
|
13
|
+
/**
 * Translate a node-llama-cpp stop reason into this package's
 * `CompletionFinishReason` vocabulary.
 *
 * - `functionCalls` becomes `toolCalls`.
 * - Both trigger-based stops (`stopGenerationTrigger`, `customStopTrigger`)
 *   collapse into `stopTrigger`.
 * - Any other value is passed through unchanged.
 */
export function mapFinishReason(
  nodeLlamaCppFinishReason: LlamaChatResult['stopReason'],
): CompletionFinishReason {
  if (nodeLlamaCppFinishReason === 'functionCalls') {
    return 'toolCalls'
  }
  if (
    nodeLlamaCppFinishReason === 'stopGenerationTrigger' ||
    nodeLlamaCppFinishReason === 'customStopTrigger'
  ) {
    return 'stopTrigger'
  }
  return nodeLlamaCppFinishReason
}
|
|
27
|
+
|
|
28
|
+
/**
 * Append a function call (with its result) to the model turn at the end of a
 * chat history, without mutating the input.
 *
 * If the history is empty, or its last item is not a model turn, an empty
 * model turn is appended first. The last model turn and its `response` array
 * are shallow-copied before the call is added, so `chatHistory` and the turn
 * objects it holds are left untouched.
 *
 * @returns A new history array whose last item is the updated model turn.
 */
export function addFunctionCallToChatHistory({
  chatHistory,
  functionName,
  functionDescription,
  callParams,
  callResult,
  rawCall,
  startsNewChunk,
}: {
  chatHistory: ChatHistoryItem[]
  functionName: string
  functionDescription?: string
  callParams: any
  callResult: any
  rawCall?: LlamaTextJSON
  startsNewChunk?: boolean
}) {
  const history = chatHistory.slice()
  if (history.length === 0 || history[history.length - 1]!.type !== 'model') {
    history.push({ type: 'model', response: [] })
  }

  const lastIndex = history.length - 1
  const previousTurn = history[lastIndex] as ChatModelResponse

  const call: ChatModelFunctionCall = {
    type: 'functionCall',
    name: functionName,
    description: functionDescription,
    params: callParams,
    result: callResult,
    rawCall,
  }
  // Only set the flag when requested so the key is absent otherwise.
  if (startsNewChunk) {
    call.startsNewChunk = true
  }

  history[lastIndex] = {
    ...previousTurn,
    response: [...previousTurn.response, call],
  }
  return history
}
|
|
75
|
+
|
|
76
|
+
/**
 * Convert chat messages into node-llama-cpp chat history items.
 *
 * - `user` messages become `{ type: 'user', text }`.
 * - `assistant` messages become `{ type: 'model', response: [text] }`.
 * - All `system` messages are merged, separated by a blank line, into a single
 *   system item placed first in the history.
 * - Messages with any other role are dropped.
 *
 * Content for every role goes through `flattenMessageTextContent`, so all roles
 * accept the same content shapes.
 */
export function createChatMessageArray(
  messages: ChatMessage[],
): ChatHistoryItem[] {
  const items: ChatHistoryItem[] = []
  let systemPrompt: string | undefined
  for (const message of messages) {
    if (message.role === 'user') {
      items.push({
        type: 'user',
        text: flattenMessageTextContent(message.content),
      })
    } else if (message.role === 'assistant') {
      // Flatten like user/system content. Pushing the raw content would place a
      // content-part array inside the model response when a client sends one.
      items.push({
        type: 'model',
        response: [flattenMessageTextContent(message.content)],
      })
    } else if (message.role === 'system') {
      const text = flattenMessageTextContent(message.content)
      systemPrompt = systemPrompt ? `${systemPrompt}\n\n${text}` : text
    }
  }

  if (systemPrompt) {
    items.unshift({
      type: 'system',
      text: systemPrompt,
    })
  }

  return items
}
|
|
110
|
+
|
|
111
|
+
/**
 * Load every `*.gbnf` grammar file located directly inside a directory.
 *
 * @param directoryPath Directory to scan (subdirectories are not descended into).
 * @returns Object mapping each grammar name (the file name with its trailing
 *   `.gbnf` removed) to the file's UTF-8 contents.
 */
export async function readGBNFFiles(directoryPath: string): Promise<Record<string, string>> {
  const extension = '.gbnf'
  const fileNames = await fs.promises.readdir(directoryPath)
  const grammarEntries = await Promise.all(
    fileNames
      .filter((fileName) => fileName.endsWith(extension))
      .map(async (fileName) => {
        const contents = await fs.promises.readFile(path.join(directoryPath, fileName), 'utf-8')
        // Strip only the trailing extension. String#replace('.gbnf', '') removes
        // the first occurrence instead, turning `v1.gbnf-old.gbnf` into `v1-old.gbnf`.
        return [fileName.slice(0, -extension.length), contents] as const
      }),
  )
  return Object.fromEntries(grammarEntries)
}
|
|
125
|
+
|
|
126
|
+
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
|
|
2
|
+
|
|
3
|
+
import fs from 'node:fs'
|
|
4
|
+
import { calculateFileChecksum } from '#package/lib/calculateFileChecksum.js'
|
|
5
|
+
import { resolveModelFileLocation } from '#package/lib/resolveModelFileLocation.js'
|
|
6
|
+
|
|
7
|
+
/** Minimal config needed to locate and verify a single model file. */
interface ValidatableModelConfig {
  /** Remote URL of the model file, if any (forwarded to resolveModelFileLocation). */
  url?: string
  /** Explicit local path of the model file, if any (forwarded as `filePath`). */
  location?: string
  /** Cache root forwarded to resolveModelFileLocation. */
  modelsCachePath: string
  /** Expected sha256 checksum; when set, the file contents are verified against it. */
  sha256?: string
}

/**
 * Check that a model file exists and, when a checksum is configured, that its
 * contents match it.
 *
 * A `<file>.ipull` marker next to the model file signals a download that may
 * not have finished. The file is accepted only if the configured checksum
 * matches, in which case the stale marker is deleted.
 *
 * @returns `undefined` when the file is usable, otherwise a message explaining
 *   why it is not.
 */
export async function validateModelFile(config: ValidatableModelConfig): Promise<string | undefined> {
  const fileLocation = resolveModelFileLocation({
    url: config.url,
    filePath: config.location,
    modelsCachePath: config.modelsCachePath,
  })
  if (!fs.existsSync(fileLocation)) {
    return `Model file missing at ${fileLocation}`
  }

  const expectedHash = config.sha256
  const downloadMarker = `${fileLocation}.ipull`

  if (fs.existsSync(downloadMarker)) {
    // Without a checksum there is no way to prove the download completed.
    if (!expectedHash) {
      return `Model file with incomplete download`
    }
    const markedFileHash = await calculateFileChecksum(fileLocation, 'sha256')
    if (markedFileHash !== expectedHash) {
      return `Model file with incomplete download`
    }
    // The file is complete after all; drop the leftover marker.
    fs.unlinkSync(downloadMarker)
    return undefined
  }

  if (expectedHash) {
    const actualHash = await calculateFileChecksum(fileLocation, 'sha256')
    if (actualHash !== expectedHash) {
      return `File sha256 checksum mismatch: expected ${expectedHash} got ${actualHash} at ${fileLocation}`
    }
  }
  return undefined
}
|
|
@@ -0,0 +1,369 @@
|
|
|
1
|
+
import StableDiffusion from '@lmagder/node-stable-diffusion-cpp'
|
|
2
|
+
import { gguf } from '@huggingface/gguf'
|
|
3
|
+
import sharp from 'sharp'
|
|
4
|
+
import fs from 'node:fs'
|
|
5
|
+
import path from 'node:path'
|
|
6
|
+
import {
|
|
7
|
+
EngineContext,
|
|
8
|
+
FileDownloadProgress,
|
|
9
|
+
ModelConfig,
|
|
10
|
+
EngineTextToImageResult,
|
|
11
|
+
ModelFileSource,
|
|
12
|
+
EngineTextToImageArgs,
|
|
13
|
+
Image,
|
|
14
|
+
EngineImageToImageArgs,
|
|
15
|
+
} from '#package/types/index.js'
|
|
16
|
+
import { LogLevel, LogLevels } from '#package/lib/logger.js'
|
|
17
|
+
import { downloadModelFile } from '#package/lib/downloadModelFile.js'
|
|
18
|
+
import { resolveModelFileLocation } from '#package/lib/resolveModelFileLocation.js'
|
|
19
|
+
import { acquireFileLock } from '#package/lib/acquireFileLock.js'
|
|
20
|
+
import { getRandomNumber } from '#package/lib/util.js'
|
|
21
|
+
import { StableDiffusionSamplingMethod, StableDiffusionSchedule, StableDiffusionWeightType } from './types.js'
|
|
22
|
+
import { validateModelFiles, ModelValidationResult } from './validateModelFiles.js'
|
|
23
|
+
import { parseQuantization, getWeightType, getSamplingMethod } from './util.js'
|
|
24
|
+
|
|
25
|
+
/** A loaded stable-diffusion model: wraps the native context used for generation. */
export interface StableDiffusionInstance {
  context: StableDiffusion.Context
}

/**
 * Model configuration for the stable-diffusion-cpp engine.
 *
 * The optional component sources (`clipL`, `clipG`, `vae`, `t5xxl`,
 * `controlNet`, `taesd`) are resolved to local paths by createInstance. They
 * are handed to the native context under the same names and downloaded by
 * prepareModel when validation flags them.
 */
export interface StableDiffusionModelConfig extends ModelConfig {
  /** Local path of the main model file. */
  location: string
  /** Optional expected sha256 checksum of the main model file. */
  sha256?: string
  /**
   * When true, `location` is passed to the native context as `diffusionModel`
   * instead of `model`.
   */
  diffusionModel?: boolean
  /** Not read by the functions in this file. */
  model?: ModelFileSource

  clipL?: ModelFileSource
  clipG?: ModelFileSource
  vae?: ModelFileSource
  t5xxl?: ModelFileSource
  controlNet?: ModelFileSource
  taesd?: ModelFileSource
  /** LoRA sources; each one with a `url` is downloaded during preparation. */
  loras?: ModelFileSource[]

  /** Sampler used when a request does not specify one. */
  samplingMethod?: StableDiffusionSamplingMethod
  /** Explicit weight type; otherwise it is parsed from the model file name. */
  weightType?: StableDiffusionWeightType
  /** Not read by the functions in this file. */
  schedule?: StableDiffusionSchedule

  device?: {
    /** Not read by createInstance. */
    gpu?: boolean | 'auto' | (string & {})
    /** Passed to the native context as `numThreads`. */
    cpuThreads?: number
  }
}

/**
 * Metadata produced by prepareModel. `gguf` holds the GGUF header metadata when
 * the model file ends in `.gguf`.
 */
interface StableDiffusionModelMeta {
  gguf: any
}

/**
 * Engine capability flag. Presumably signals that this engine picks the GPU
 * automatically; confirm against the engine registry.
 */
export const autoGpu = true
|
|
55
|
+
|
|
56
|
+
/**
 * Make sure every file a stable-diffusion model needs is present and valid,
 * downloading missing or invalid files where a URL is available.
 *
 * Runs while holding a file lock on `config.location`; the lock is released on
 * every exit path.
 *
 * @param onProgress Forwarded to each `downloadModelFile` call.
 * @param signal Checked between steps. Once aborted, the function resolves to
 *   `undefined` without doing further work.
 * @returns `{ gguf }` (GGUF header metadata) when the model file ends in
 *   `.gguf`, `{}` for other model files, or `undefined` when aborted.
 * @throws When the main model file is invalid and `config.url` is not set, or
 *   when files are still invalid after downloading.
 */
export async function prepareModel(
  { config, log }: EngineContext<StableDiffusionModelConfig, StableDiffusionModelMeta>,
  onProgress?: (progress: FileDownloadProgress) => void,
  signal?: AbortSignal,
) {
  fs.mkdirSync(path.dirname(config.location), { recursive: true })
  const releaseFileLock = await acquireFileLock(config.location)
  if (signal?.aborted) {
    releaseFileLock()
    return
  }
  log(LogLevels.info, `Preparing stable-diffusion model at ${config.location}`, {
    model: config.id,
  })

  // Start downloads for every file the validation result flags as invalid.
  // The main model file needs `url`. Components and LoRAs use their own `url`
  // and are skipped when they have none.
  const downloadModel = (url: string | undefined, validationResult: ModelValidationResult) => {
    log(LogLevels.info, `${validationResult.message} - Downloading model files`, {
      model: config.id,
      url,
      location: config.location,
      errors: validationResult.errors,
    })
    const downloadPromises: Promise<unknown>[] = []
    const pushDownload = (src: ModelFileSource) => {
      if (!src.url) {
        return
      }
      downloadPromises.push(
        downloadModelFile({
          url: src.url,
          filePath: src.file,
          modelsCachePath: config.modelsCachePath,
          onProgress,
          signal,
        }),
      )
    }
    if (validationResult.errors.model && url && config.location) {
      downloadPromises.push(
        downloadModelFile({
          url,
          filePath: config.location,
          modelsCachePath: config.modelsCachePath,
          onProgress,
          signal,
        }),
      )
    }
    if (validationResult.errors.clipG && config.clipG) {
      pushDownload(config.clipG)
    }
    if (validationResult.errors.clipL && config.clipL) {
      pushDownload(config.clipL)
    }
    if (validationResult.errors.vae && config.vae) {
      pushDownload(config.vae)
    }
    if (validationResult.errors.t5xxl && config.t5xxl) {
      pushDownload(config.t5xxl)
    }
    if (validationResult.errors.controlNet && config.controlNet) {
      pushDownload(config.controlNet)
    }
    if (validationResult.errors.taesd && config.taesd) {
      pushDownload(config.taesd)
    }
    // No per-LoRA validation flag is consulted here; every LoRA with a `url`
    // is handed to downloadModelFile.
    for (const lora of config.loras ?? []) {
      pushDownload(lora)
    }
    return Promise.all(downloadPromises)
  }

  try {
    if (signal?.aborted) {
      return
    }

    const validationResults = await validateModelFiles(config)
    if (signal?.aborted) {
      return
    }
    if (validationResults) {
      // Only the main model file depends on config.url. Invalid components can
      // still be fetched from their own URLs.
      if (validationResults.errors.model && !config.url) {
        throw new Error(`${validationResults.message} - No URL provided`)
      }
      await downloadModel(config.url, validationResults)
      if (signal?.aborted) {
        return
      }
    }

    const finalValidationError = await validateModelFiles(config)
    if (finalValidationError) {
      // Use the message; interpolating the result object itself would print "[object Object]".
      throw new Error(`Downloaded files are invalid: ${finalValidationError.message}`)
    }

    const result: any = {}
    if (config.location.endsWith('.gguf')) {
      const { metadata } = await gguf(config.location, {
        allowLocalFile: true,
      })
      result.gguf = metadata
    }
    return result
  } finally {
    releaseFileLock()
  }
}
|
|
168
|
+
|
|
169
|
+
/**
 * Create the native stable-diffusion context for a prepared model.
 *
 * Optional components (vae, clipL, clipG, t5xxl, controlNet, taesd) are
 * resolved to local paths. The weight type comes from `config.weightType`; when
 * that is absent or unrecognized, a quantization tag parsed from the model file
 * name is tried, and a warning is logged if neither works. LoRAs are looked up
 * in a `loras` directory next to the model file. Native log lines and progress
 * updates are forwarded to `log`.
 *
 * `signal` is accepted for interface parity but is not consulted here.
 */
export async function createInstance({ config, log }: EngineContext<StableDiffusionModelConfig>, signal?: AbortSignal) {
  log(LogLevels.debug, 'Load Stable Diffusion model', config)

  const localPathOf = (source?: ModelFileSource) => {
    if (!source) {
      return undefined
    }
    return resolveModelFileLocation({
      url: source.url,
      filePath: source.file,
      modelsCachePath: config.modelsCachePath,
    })
  }
  const vae = localPathOf(config.vae)
  const clipL = localPathOf(config.clipL)
  const clipG = localPathOf(config.clipG)
  const t5xxl = localPathOf(config.t5xxl)
  const controlNet = localPathOf(config.controlNet)
  const taesd = localPathOf(config.taesd)

  let weightType: ReturnType<typeof getWeightType> = undefined
  if (config.weightType) {
    weightType = getWeightType(config.weightType)
  }
  if (weightType === undefined) {
    const quantization = parseQuantization(config.location)
    if (quantization) {
      weightType = getWeightType(quantization)
    }
  }
  if (weightType === undefined) {
    log(LogLevels.warn, 'Failed to parse model weight type (quantization) from file name, falling back to f32', {
      file: config.location,
    })
  }

  const contextParams = {
    model: config.diffusionModel ? undefined : config.location,
    diffusionModel: config.diffusionModel ? config.location : undefined,
    numThreads: config.device?.cpuThreads,
    vae,
    clipL,
    clipG,
    t5xxl,
    controlNet,
    taesd,
    weightType,
    loraDir: path.join(path.dirname(config.location), 'loras'),
    // TODO how to expose?
    // keepClipOnCpu: true,
    // keepControlNetOnCpu: true,
    // keepVaeOnCpu: true,
  }
  log(LogLevels.debug, 'Creating context with', contextParams)

  const context = await StableDiffusion.createContext(
    // @ts-ignore
    contextParams,
    (level: string, message: string) => {
      log(level as LogLevel, message)
    },
    (step: number, steps: number, time: number) => {
      log(LogLevels.debug, `Progress: ${step}/${steps} (${time}ms)`)
    },
  )

  return { context }
}
|
|
240
|
+
|
|
241
|
+
/**
 * Run a text-to-image generation on a loaded stable-diffusion context.
 *
 * Width and height fall back to 512 when missing or 0. The sampler comes from
 * the request, else from the model config. Without a request seed, one is drawn
 * via `getRandomNumber(0, 1000000)` and returned alongside the images.
 *
 * @returns The generated images (raw pixel buffers wrapped in sharp handles)
 *   and the seed that was used.
 * @throws When the native call yields no images.
 */
export async function processTextToImageTask(
  { request, config, log }: EngineTextToImageArgs<StableDiffusionModelConfig>,
  instance: StableDiffusionInstance,
  signal?: AbortSignal,
): Promise<EngineTextToImageResult> {
  const seed = request.seed ?? getRandomNumber(0, 1000000)
  const generated = await instance.context.txt2img({
    prompt: request.prompt,
    negativePrompt: request.negativePrompt,
    width: request.width || 512,
    height: request.height || 512,
    batchCount: request.batchCount,
    sampleMethod: getSamplingMethod(request.samplingMethod || config.samplingMethod),
    sampleSteps: request.sampleSteps,
    cfgScale: request.cfgScale,
    // @ts-ignore
    guidance: request.guidance,
    styleRatio: request.styleRatio,
    controlStrength: request.controlStrength,
    normalizeInput: false,
    seed,
  })

  const images: Image[] = generated.map((img) => {
    const { width, height, channel: channels } = img
    return {
      handle: sharp(img.data, { raw: { width, height, channels } }),
      width,
      height,
      channels,
    }
  })

  if (images.length === 0) {
    throw new Error('No images generated')
  }
  return { images, seed }
}
|
|
287
|
+
|
|
288
|
+
/**
 * Run an image-to-image generation on a loaded stable-diffusion context.
 *
 * The request image is read as raw pixel data and passed as the native init
 * image. Width and height fall back to 512 when missing or 0. The sampler comes
 * from the request, else from the model config. Without a request seed, one is
 * drawn via `getRandomNumber(0, 1000000)`. Diagnostics go through the injected
 * `log` at debug level rather than the console.
 *
 * @returns The generated images (wrapped in sharp handles) and the seed used.
 * @throws When the native call yields no images.
 */
export async function processImageToImageTask(
  { request, config, log }: EngineImageToImageArgs<StableDiffusionModelConfig>,
  instance: StableDiffusionInstance,
  signal?: AbortSignal,
): Promise<EngineTextToImageResult> {
  const seed = request.seed ?? getRandomNumber(0, 1000000)
  // The native init image is typed to accept 3 or 4 channels only.
  const channel = request.image.channels as 3 | 4
  log(LogLevels.debug, 'processImageToImageTask', {
    width: request.image.width,
    height: request.image.height,
    channel,
  })
  const initImage = {
    data: await request.image.handle.raw().toBuffer(),
    width: request.image.width,
    height: request.image.height,
    channel,
  }
  const results = await instance.context.img2img({
    initImage,
    prompt: request.prompt,
    width: request.width || 512,
    height: request.height || 512,
    batchCount: request.batchCount,
    sampleMethod: getSamplingMethod(request.samplingMethod || config.samplingMethod),
    cfgScale: request.cfgScale,
    sampleSteps: request.sampleSteps,
    // @ts-ignore
    guidance: request.guidance,
    strength: request.strength,
    styleRatio: request.styleRatio,
    controlStrength: request.controlStrength,
    seed,
  })

  const images: Image[] = []
  for (const [index, img] of results.entries()) {
    log(LogLevels.debug, 'img', {
      id: index,
      width: img.width,
      height: img.height,
      channels: img.channel,
    })
    images.push({
      handle: sharp(img.data, {
        raw: {
          width: img.width,
          height: img.height,
          channels: img.channel,
        },
      }),
      width: img.width,
      height: img.height,
      channels: img.channel,
    })
  }

  if (!images.length) {
    throw new Error('No images generated')
  }
  return {
    images,
    seed,
  }
}
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
/**
 * Weight (tensor data) type names accepted in the stable-diffusion config.
 *
 * Resolved against the native type table by upper-casing the name. The
 * trailing `(string & {})` keeps editor completion for the listed names while
 * still accepting any other string.
 */
export type StableDiffusionWeightType =
  // floating point
  | 'f32'
  | 'f16'
  | 'f64'
  | 'bf16'
  // integer
  | 'i8'
  | 'i16'
  | 'i32'
  | 'i64'
  // q<bits>_0 / q<bits>_1
  | 'q4_0'
  | 'q4_1'
  | 'q5_0'
  | 'q5_1'
  | 'q8_0'
  // q<bits>_k
  | 'q2_k'
  | 'q3_k'
  | 'q4_k'
  | 'q5_k'
  | 'q6_k'
  | 'q8_k'
  // iq<bits>_*
  | 'iq1_s'
  | 'iq1_m'
  | 'iq2_xxs'
  | 'iq2_xs'
  | 'iq2_s'
  | 'iq3_xxs'
  | 'iq3_s'
  | 'iq4_nl'
  | 'iq4_xs'
  // q4_0 variants
  | 'q4_0_4_4'
  | 'q4_0_4_8'
  | 'q4_0_8_8'
  | (string & {})

/** Noise schedule names; any other string is also accepted. */
export type StableDiffusionSchedule =
  | 'discrete'
  | 'karras'
  | 'exponential'
  | 'ays'
  | 'gits'
  | (string & {})

/**
 * Sampler names mapped by `getSamplingMethod`; any other string is also
 * accepted by the type.
 */
export type StableDiffusionSamplingMethod =
  | 'euler'
  | 'euler_a'
  | 'lcm'
  | 'heun'
  | 'dpm2'
  | 'dpm++2s_a'
  | 'dpm++2m'
  | 'dpm++2mv2'
  | 'ipndm'
  | 'ipndm_v'
  | (string & {})
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
import StableDiffusion from '@lmagder/node-stable-diffusion-cpp'
|
|
2
|
+
|
|
3
|
+
/**
 * Extract a weight-type (quantization) tag from a model file name.
 *
 * Only the final path segment is inspected, so quantization-like directory
 * names (e.g. `/models/q8_0/`) cannot shadow the file's own tag. Patterns are
 * tried in order and the first match wins:
 * 1. `q<digits>_<digits>` anywhere in the name, e.g. `q4_0`, `Q8_0`
 * 2. `f16`, `f32`, `int8`, `int4` preceded by `-`, `_` or `.`
 * 3. `fp16`, `fp32` preceded by `-`, `_` or `.` (normalized to `f16` / `f32`)
 *
 * @param filename Model file name or path (`/` or `\` separated).
 * @returns The lowercased tag without its leading separator (e.g. `q4_0`,
 *   `f16`), or `null` when no pattern matches.
 */
export function parseQuantization(filename: string): string | null {
  const baseName = filename.replace(/^.*[\\/]/, '')
  const regexPatterns = [
    /q(\d+)_(\d+)/i, // q4_0
    /[-_.](f16|f32|int8|int4)/i, // f16, f32, int8, int4
    /[-_.](fp16|fp32)/i, // fp16, fp32
  ]

  for (const regex of regexPatterns) {
    const match = baseName.match(regex)
    if (match) {
      // Drop the leading separator (including '.', which the patterns above
      // accept), normalize fp16/fp32 to f16/f32, and lowercase.
      return match[0].replace(/^[-_.]/, '').replace(/fp/i, 'f').toLowerCase()
    }
  }
  return null
}
|
|
21
|
+
|
|
22
|
+
/**
 * Map a weight-type name (case-insensitive, e.g. `q4_0`, `f16`) to the native
 * `StableDiffusion.Type` value.
 *
 * Only own members holding numeric values are accepted. If `Type` is a numeric
 * TypeScript enum, it also carries reverse mappings (e.g. `Type['1'] === 'F16'`).
 * A bare `in` check would return those strings as if they were weight types.
 *
 * @returns The numeric type, or `undefined` (after a console warning) for
 *   unknown names.
 */
export function getWeightType(key: string): number | undefined {
  const weightKey = key.toUpperCase() as keyof typeof StableDiffusion.Type
  const types = StableDiffusion.Type as unknown as Record<string, unknown>
  if (Object.prototype.hasOwnProperty.call(types, weightKey)) {
    const value = types[weightKey]
    if (typeof value === 'number') {
      return value
    }
  }
  console.warn('Unknown weight type', weightKey)
  return undefined
}
|
|
30
|
+
|
|
31
|
+
/**
 * Map a sampler name (see `StableDiffusionSamplingMethod`) to the native
 * `StableDiffusion.SampleMethod` value.
 *
 * Names are matched exactly. When no method is given (`undefined`), the
 * function returns `undefined` without warning. Callers pass
 * `request.samplingMethod || config.samplingMethod`, which is undefined whenever
 * neither is set, so warning there would fire on every such request. An
 * unrecognized name logs a console warning and also yields `undefined`.
 */
export function getSamplingMethod(method?: string): StableDiffusion.SampleMethod | undefined {
  if (method === undefined) {
    return undefined
  }
  switch (method) {
    case 'euler':
      return StableDiffusion.SampleMethod.Euler
    case 'euler_a':
      return StableDiffusion.SampleMethod.EulerA
    case 'lcm':
      return StableDiffusion.SampleMethod.LCM
    case 'heun':
      return StableDiffusion.SampleMethod.Heun
    case 'dpm2':
      return StableDiffusion.SampleMethod.DPM2
    case 'dpm++2s_a':
      return StableDiffusion.SampleMethod.DPMPP2SA
    case 'dpm++2m':
      return StableDiffusion.SampleMethod.DPMPP2M
    case 'dpm++2mv2':
      return StableDiffusion.SampleMethod.DPMPP2Mv2
    case 'ipndm':
      // @ts-ignore
      return StableDiffusion.SampleMethod.IPNDM
    case 'ipndm_v':
      // @ts-ignore
      return StableDiffusion.SampleMethod.IPNDMV
  }
  console.warn('Unknown sampling method', method)
  return undefined
}
|