inference-server 1.0.0-beta.19

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (227) hide show
  1. package/README.md +216 -0
  2. package/dist/api/openai/enums.d.ts +4 -0
  3. package/dist/api/openai/enums.js +17 -0
  4. package/dist/api/openai/enums.js.map +1 -0
  5. package/dist/api/openai/handlers/chat.d.ts +3 -0
  6. package/dist/api/openai/handlers/chat.js +358 -0
  7. package/dist/api/openai/handlers/chat.js.map +1 -0
  8. package/dist/api/openai/handlers/completions.d.ts +3 -0
  9. package/dist/api/openai/handlers/completions.js +169 -0
  10. package/dist/api/openai/handlers/completions.js.map +1 -0
  11. package/dist/api/openai/handlers/embeddings.d.ts +3 -0
  12. package/dist/api/openai/handlers/embeddings.js +74 -0
  13. package/dist/api/openai/handlers/embeddings.js.map +1 -0
  14. package/dist/api/openai/handlers/images.d.ts +0 -0
  15. package/dist/api/openai/handlers/images.js +4 -0
  16. package/dist/api/openai/handlers/images.js.map +1 -0
  17. package/dist/api/openai/handlers/models.d.ts +3 -0
  18. package/dist/api/openai/handlers/models.js +23 -0
  19. package/dist/api/openai/handlers/models.js.map +1 -0
  20. package/dist/api/openai/handlers/transcription.d.ts +0 -0
  21. package/dist/api/openai/handlers/transcription.js +4 -0
  22. package/dist/api/openai/handlers/transcription.js.map +1 -0
  23. package/dist/api/openai/index.d.ts +7 -0
  24. package/dist/api/openai/index.js +14 -0
  25. package/dist/api/openai/index.js.map +1 -0
  26. package/dist/api/parseJSONRequestBody.d.ts +2 -0
  27. package/dist/api/parseJSONRequestBody.js +24 -0
  28. package/dist/api/parseJSONRequestBody.js.map +1 -0
  29. package/dist/api/v1/index.d.ts +2 -0
  30. package/dist/api/v1/index.js +29 -0
  31. package/dist/api/v1/index.js.map +1 -0
  32. package/dist/cli.d.ts +1 -0
  33. package/dist/cli.js +10 -0
  34. package/dist/cli.js.map +1 -0
  35. package/dist/engines/gpt4all/engine.d.ts +34 -0
  36. package/dist/engines/gpt4all/engine.js +357 -0
  37. package/dist/engines/gpt4all/engine.js.map +1 -0
  38. package/dist/engines/gpt4all/util.d.ts +3 -0
  39. package/dist/engines/gpt4all/util.js +29 -0
  40. package/dist/engines/gpt4all/util.js.map +1 -0
  41. package/dist/engines/index.d.ts +19 -0
  42. package/dist/engines/index.js +21 -0
  43. package/dist/engines/index.js.map +1 -0
  44. package/dist/engines/node-llama-cpp/engine.d.ts +49 -0
  45. package/dist/engines/node-llama-cpp/engine.js +666 -0
  46. package/dist/engines/node-llama-cpp/engine.js.map +1 -0
  47. package/dist/engines/node-llama-cpp/types.d.ts +13 -0
  48. package/dist/engines/node-llama-cpp/types.js +2 -0
  49. package/dist/engines/node-llama-cpp/types.js.map +1 -0
  50. package/dist/engines/node-llama-cpp/util.d.ts +15 -0
  51. package/dist/engines/node-llama-cpp/util.js +84 -0
  52. package/dist/engines/node-llama-cpp/util.js.map +1 -0
  53. package/dist/engines/node-llama-cpp/validateModelFile.d.ts +8 -0
  54. package/dist/engines/node-llama-cpp/validateModelFile.js +36 -0
  55. package/dist/engines/node-llama-cpp/validateModelFile.js.map +1 -0
  56. package/dist/engines/stable-diffusion-cpp/engine.d.ts +90 -0
  57. package/dist/engines/stable-diffusion-cpp/engine.js +294 -0
  58. package/dist/engines/stable-diffusion-cpp/engine.js.map +1 -0
  59. package/dist/engines/stable-diffusion-cpp/types.d.ts +3 -0
  60. package/dist/engines/stable-diffusion-cpp/types.js +2 -0
  61. package/dist/engines/stable-diffusion-cpp/types.js.map +1 -0
  62. package/dist/engines/stable-diffusion-cpp/util.d.ts +4 -0
  63. package/dist/engines/stable-diffusion-cpp/util.js +55 -0
  64. package/dist/engines/stable-diffusion-cpp/util.js.map +1 -0
  65. package/dist/engines/stable-diffusion-cpp/validateModelFiles.d.ts +19 -0
  66. package/dist/engines/stable-diffusion-cpp/validateModelFiles.js +91 -0
  67. package/dist/engines/stable-diffusion-cpp/validateModelFiles.js.map +1 -0
  68. package/dist/engines/transformers-js/engine.d.ts +37 -0
  69. package/dist/engines/transformers-js/engine.js +538 -0
  70. package/dist/engines/transformers-js/engine.js.map +1 -0
  71. package/dist/engines/transformers-js/types.d.ts +7 -0
  72. package/dist/engines/transformers-js/types.js +2 -0
  73. package/dist/engines/transformers-js/types.js.map +1 -0
  74. package/dist/engines/transformers-js/util.d.ts +7 -0
  75. package/dist/engines/transformers-js/util.js +36 -0
  76. package/dist/engines/transformers-js/util.js.map +1 -0
  77. package/dist/engines/transformers-js/validateModelFiles.d.ts +17 -0
  78. package/dist/engines/transformers-js/validateModelFiles.js +133 -0
  79. package/dist/engines/transformers-js/validateModelFiles.js.map +1 -0
  80. package/dist/experiments/ChatWithVision.d.ts +11 -0
  81. package/dist/experiments/ChatWithVision.js +91 -0
  82. package/dist/experiments/ChatWithVision.js.map +1 -0
  83. package/dist/experiments/StableDiffPromptGenerator.d.ts +0 -0
  84. package/dist/experiments/StableDiffPromptGenerator.js +4 -0
  85. package/dist/experiments/StableDiffPromptGenerator.js.map +1 -0
  86. package/dist/experiments/VoiceFunctionCall.d.ts +18 -0
  87. package/dist/experiments/VoiceFunctionCall.js +51 -0
  88. package/dist/experiments/VoiceFunctionCall.js.map +1 -0
  89. package/dist/http.d.ts +19 -0
  90. package/dist/http.js +54 -0
  91. package/dist/http.js.map +1 -0
  92. package/dist/index.d.ts +7 -0
  93. package/dist/index.js +8 -0
  94. package/dist/index.js.map +1 -0
  95. package/dist/instance.d.ts +88 -0
  96. package/dist/instance.js +594 -0
  97. package/dist/instance.js.map +1 -0
  98. package/dist/lib/acquireFileLock.d.ts +7 -0
  99. package/dist/lib/acquireFileLock.js +38 -0
  100. package/dist/lib/acquireFileLock.js.map +1 -0
  101. package/dist/lib/calculateContextIdentity.d.ts +7 -0
  102. package/dist/lib/calculateContextIdentity.js +39 -0
  103. package/dist/lib/calculateContextIdentity.js.map +1 -0
  104. package/dist/lib/calculateFileChecksum.d.ts +1 -0
  105. package/dist/lib/calculateFileChecksum.js +16 -0
  106. package/dist/lib/calculateFileChecksum.js.map +1 -0
  107. package/dist/lib/copyDirectory.d.ts +6 -0
  108. package/dist/lib/copyDirectory.js +27 -0
  109. package/dist/lib/copyDirectory.js.map +1 -0
  110. package/dist/lib/decodeAudio.d.ts +1 -0
  111. package/dist/lib/decodeAudio.js +26 -0
  112. package/dist/lib/decodeAudio.js.map +1 -0
  113. package/dist/lib/downloadModelFile.d.ts +10 -0
  114. package/dist/lib/downloadModelFile.js +58 -0
  115. package/dist/lib/downloadModelFile.js.map +1 -0
  116. package/dist/lib/flattenMessageTextContent.d.ts +2 -0
  117. package/dist/lib/flattenMessageTextContent.js +11 -0
  118. package/dist/lib/flattenMessageTextContent.js.map +1 -0
  119. package/dist/lib/getCacheDirPath.d.ts +12 -0
  120. package/dist/lib/getCacheDirPath.js +31 -0
  121. package/dist/lib/getCacheDirPath.js.map +1 -0
  122. package/dist/lib/loadImage.d.ts +12 -0
  123. package/dist/lib/loadImage.js +30 -0
  124. package/dist/lib/loadImage.js.map +1 -0
  125. package/dist/lib/logger.d.ts +12 -0
  126. package/dist/lib/logger.js +98 -0
  127. package/dist/lib/logger.js.map +1 -0
  128. package/dist/lib/math.d.ts +7 -0
  129. package/dist/lib/math.js +30 -0
  130. package/dist/lib/math.js.map +1 -0
  131. package/dist/lib/resolveModelFileLocation.d.ts +15 -0
  132. package/dist/lib/resolveModelFileLocation.js +41 -0
  133. package/dist/lib/resolveModelFileLocation.js.map +1 -0
  134. package/dist/lib/util.d.ts +7 -0
  135. package/dist/lib/util.js +61 -0
  136. package/dist/lib/util.js.map +1 -0
  137. package/dist/lib/validateModelFile.d.ts +9 -0
  138. package/dist/lib/validateModelFile.js +62 -0
  139. package/dist/lib/validateModelFile.js.map +1 -0
  140. package/dist/lib/validateModelOptions.d.ts +3 -0
  141. package/dist/lib/validateModelOptions.js +23 -0
  142. package/dist/lib/validateModelOptions.js.map +1 -0
  143. package/dist/pool.d.ts +61 -0
  144. package/dist/pool.js +512 -0
  145. package/dist/pool.js.map +1 -0
  146. package/dist/server.d.ts +59 -0
  147. package/dist/server.js +221 -0
  148. package/dist/server.js.map +1 -0
  149. package/dist/standalone.d.ts +1 -0
  150. package/dist/standalone.js +306 -0
  151. package/dist/standalone.js.map +1 -0
  152. package/dist/store.d.ts +60 -0
  153. package/dist/store.js +203 -0
  154. package/dist/store.js.map +1 -0
  155. package/dist/types/completions.d.ts +57 -0
  156. package/dist/types/completions.js +2 -0
  157. package/dist/types/completions.js.map +1 -0
  158. package/dist/types/index.d.ts +326 -0
  159. package/dist/types/index.js +2 -0
  160. package/dist/types/index.js.map +1 -0
  161. package/docs/engines.md +28 -0
  162. package/docs/gpu.md +72 -0
  163. package/docs/http-api.md +147 -0
  164. package/examples/all-options.js +108 -0
  165. package/examples/chat-cli.js +56 -0
  166. package/examples/chat-server.js +65 -0
  167. package/examples/concurrency.js +70 -0
  168. package/examples/express.js +70 -0
  169. package/examples/pool.js +91 -0
  170. package/package.json +113 -0
  171. package/src/api/openai/enums.ts +20 -0
  172. package/src/api/openai/handlers/chat.ts +408 -0
  173. package/src/api/openai/handlers/completions.ts +196 -0
  174. package/src/api/openai/handlers/embeddings.ts +92 -0
  175. package/src/api/openai/handlers/images.ts +3 -0
  176. package/src/api/openai/handlers/models.ts +33 -0
  177. package/src/api/openai/handlers/transcription.ts +2 -0
  178. package/src/api/openai/index.ts +16 -0
  179. package/src/api/parseJSONRequestBody.ts +26 -0
  180. package/src/api/v1/DRAFT.md +16 -0
  181. package/src/api/v1/index.ts +37 -0
  182. package/src/cli.ts +9 -0
  183. package/src/engines/gpt4all/engine.ts +441 -0
  184. package/src/engines/gpt4all/util.ts +31 -0
  185. package/src/engines/index.ts +28 -0
  186. package/src/engines/node-llama-cpp/engine.ts +811 -0
  187. package/src/engines/node-llama-cpp/types.ts +17 -0
  188. package/src/engines/node-llama-cpp/util.ts +126 -0
  189. package/src/engines/node-llama-cpp/validateModelFile.ts +46 -0
  190. package/src/engines/stable-diffusion-cpp/engine.ts +369 -0
  191. package/src/engines/stable-diffusion-cpp/types.ts +54 -0
  192. package/src/engines/stable-diffusion-cpp/util.ts +58 -0
  193. package/src/engines/stable-diffusion-cpp/validateModelFiles.ts +119 -0
  194. package/src/engines/transformers-js/engine.ts +659 -0
  195. package/src/engines/transformers-js/types.ts +25 -0
  196. package/src/engines/transformers-js/util.ts +40 -0
  197. package/src/engines/transformers-js/validateModelFiles.ts +168 -0
  198. package/src/experiments/ChatWithVision.ts +103 -0
  199. package/src/experiments/StableDiffPromptGenerator.ts +2 -0
  200. package/src/experiments/VoiceFunctionCall.ts +71 -0
  201. package/src/http.ts +72 -0
  202. package/src/index.ts +7 -0
  203. package/src/instance.ts +723 -0
  204. package/src/lib/acquireFileLock.ts +38 -0
  205. package/src/lib/calculateContextIdentity.ts +53 -0
  206. package/src/lib/calculateFileChecksum.ts +18 -0
  207. package/src/lib/copyDirectory.ts +29 -0
  208. package/src/lib/decodeAudio.ts +39 -0
  209. package/src/lib/downloadModelFile.ts +70 -0
  210. package/src/lib/flattenMessageTextContent.ts +19 -0
  211. package/src/lib/getCacheDirPath.ts +34 -0
  212. package/src/lib/loadImage.ts +46 -0
  213. package/src/lib/logger.ts +112 -0
  214. package/src/lib/math.ts +31 -0
  215. package/src/lib/resolveModelFileLocation.ts +49 -0
  216. package/src/lib/util.ts +75 -0
  217. package/src/lib/validateModelFile.ts +71 -0
  218. package/src/lib/validateModelOptions.ts +31 -0
  219. package/src/pool.ts +651 -0
  220. package/src/server.ts +270 -0
  221. package/src/standalone.ts +320 -0
  222. package/src/store.ts +278 -0
  223. package/src/types/completions.ts +86 -0
  224. package/src/types/index.ts +488 -0
  225. package/tsconfig.json +29 -0
  226. package/tsconfig.release.json +11 -0
  227. package/vitest.config.ts +18 -0
@@ -0,0 +1,659 @@
1
+ import path from 'node:path'
2
+ import fs from 'node:fs'
3
+ import {
4
+ EngineContext,
5
+ FileDownloadProgress,
6
+ ModelConfig,
7
+ EngineImageToTextArgs,
8
+ EngineSpeechToTextArgs,
9
+ EngineTextCompletionResult,
10
+ EngineTextCompletionArgs,
11
+ EngineEmbeddingArgs,
12
+ EngineEmbeddingResult,
13
+ ImageEmbeddingInput,
14
+ TransformersJsModel,
15
+ TextEmbeddingInput,
16
+ } from '#package/types/index.js'
17
+ import {
18
+ env,
19
+ AutoModel,
20
+ AutoProcessor,
21
+ AutoTokenizer,
22
+ RawImage,
23
+ TextStreamer,
24
+ mean_pooling,
25
+ Processor,
26
+ PreTrainedModel,
27
+ PreTrainedTokenizer,
28
+ } from '@huggingface/transformers'
29
+ import { LogLevels } from '#package/lib/logger.js'
30
+ import { acquireFileLock } from '#package/lib/acquireFileLock.js'
31
+ import { decodeAudio } from '#package/lib/decodeAudio.js'
32
+ import { resolveModelFileLocation } from '#package/lib/resolveModelFileLocation.js'
33
+ import { parseHuggingfaceModelIdAndBranch, remoteFileExists } from './util.js'
34
+ import { validateModelFiles, ModelValidationResult } from './validateModelFiles.js'
35
+ import { copyDirectory } from '#package/lib/copyDirectory.js'
36
+
37
// The loaded transformers.js components for a single sub-model.
// Any of the three may be absent depending on what the repo ships.
interface TransformersJsModelComponents {
	model?: PreTrainedModel
	processor?: Processor
	tokenizer?: PreTrainedTokenizer
}

// A running engine instance bundles up to three independently loaded
// sub-models (text, vision, speech).
interface TransformersJsInstance {
	textModel?: TransformersJsModelComponents
	visionModel?: TransformersJsModelComponents
	speechModel?: TransformersJsModelComponents
}

// One local model file: absolute path plus size in bytes (from fs.statSync).
interface ModelFile {
	file: string
	size: number
}
53
+
54
// interface TransformersJsModelMeta {
// 	modelType: string
// 	files: ModelFile[]
// }

/**
 * Configuration for a transformers.js model. Any of textModel / visionModel /
 * speechModel may be set; when none is, a default text model is loaded
 * (see createInstance / prepareModel).
 */
export interface TransformersJsModelConfig extends ModelConfig {
	// Local directory holding (or receiving) the model files.
	location: string
	// Huggingface repo URL missing/invalid files are downloaded from.
	url: string
	textModel?: TransformersJsModel
	visionModel?: TransformersJsModel
	speechModel?: TransformersJsModel
	device?: {
		// Any truthy value selects the 'gpu' device when loading components,
		// falsy selects 'cpu'; string values are NOT passed through as-is.
		gpu?: boolean | 'auto' | (string & {})
	}
}

// Exported flag read by the engine host — presumably opts this engine into
// automatic GPU selection; TODO confirm against the engines registry.
export const autoGpu = true
71
+
72
// Guards configureEnvironment so the global transformers.js env is only
// mutated once per process.
let didConfigureEnvironment = false
// One-time global setup of the transformers.js environment.
// NOTE(review): modelsPath is currently unused — only env.localModelPath is
// cleared; the commented-out custom-cache wiring below is what would have
// consumed it. Confirm whether env.cacheDir should be pointed at modelsPath.
function configureEnvironment(modelsPath: string) {
	// console.debug({
	// 	cacheDir: env.cacheDir,
	// 	localModelPaths: env.localModelPath,
	// })
	// env.useFSCache = false
	// env.useCustomCache = true
	// env.customCache = new TransformersFileCache(modelsPath)
	env.localModelPath = ''
	didConfigureEnvironment = true
}
84
+
85
+ async function loadModelComponents(
86
+ modelOpts: TransformersJsModel,
87
+ config: TransformersJsModelConfig,
88
+ ): Promise<TransformersJsModelComponents> {
89
+ const device = config.device?.gpu ? 'gpu' : 'cpu'
90
+ const modelClass = modelOpts.modelClass ?? AutoModel
91
+ let modelPath = config.location
92
+ if (!modelPath.endsWith('/')) {
93
+ modelPath += '/'
94
+ }
95
+ const loadPromises = []
96
+ const modelPromise = modelClass.from_pretrained(modelPath, {
97
+ local_files_only: true,
98
+ device: device,
99
+ dtype: modelOpts.dtype || 'fp32',
100
+ })
101
+ loadPromises.push(modelPromise)
102
+ const tokenizerClass = modelOpts.tokenizerClass ?? AutoTokenizer
103
+ const tokenizerPromise = tokenizerClass.from_pretrained(modelPath, {
104
+ local_files_only: true,
105
+ })
106
+ loadPromises.push(tokenizerPromise)
107
+
108
+ const hasPreprocessor = fs.existsSync(modelPath + 'preprocessor_config.json')
109
+ const hasProcessor = fs.existsSync(modelPath + 'processor_config.json')
110
+
111
+ if (hasProcessor || hasPreprocessor || modelOpts.processor) {
112
+ const processorClass = modelOpts.processorClass ?? AutoProcessor
113
+ if (modelOpts.processor) {
114
+ const processorPath = resolveModelFileLocation({
115
+ url: modelOpts.processor.url,
116
+ filePath: modelOpts.processor.file,
117
+ modelsCachePath: config.modelsCachePath,
118
+ })
119
+ const processorPromise = processorClass.from_pretrained(processorPath, {
120
+ local_files_only: true,
121
+ })
122
+ loadPromises.push(processorPromise)
123
+ } else {
124
+ const processorPromise = processorClass.from_pretrained(modelPath, {
125
+ local_files_only: true,
126
+ })
127
+ loadPromises.push(processorPromise)
128
+ }
129
+ }
130
+
131
+ const loadedComponents = await Promise.all(loadPromises)
132
+ const modelComponents: TransformersJsModelComponents = {}
133
+ if (loadedComponents[0]) {
134
+ modelComponents.model = loadedComponents[0] as PreTrainedModel
135
+ }
136
+ if (loadedComponents[1]) {
137
+ modelComponents.tokenizer = loadedComponents[1] as PreTrainedTokenizer
138
+ }
139
+ if (loadedComponents[2]) {
140
+ modelComponents.processor = loadedComponents[2] as Processor
141
+ }
142
+ return modelComponents
143
+ }
144
+
145
+ async function disposeModelComponents(modelComponents: TransformersJsModelComponents) {
146
+ if (modelComponents.model && 'dispose' in modelComponents.model) {
147
+ await modelComponents.model.dispose()
148
+ }
149
+ }
150
+
151
// Shape of the events transformers.js passes to `progress_callback` during
// from_pretrained downloads. Only 'progress' events are consumed below.
interface TransformersJsDownloadProgress {
	status: 'progress' | 'done' | 'initiate'
	// Repo/model id the file belongs to.
	name: string
	// File path within the repo.
	file: string
	// Progress value — presumably a 0-100 percentage; TODO confirm scale.
	progress: number
	// Bytes downloaded so far / total bytes, per the library's reporting.
	loaded: number
	total: number
}
159
+
160
+ async function acquireModelFileLocks(config: TransformersJsModelConfig, signal?: AbortSignal) {
161
+ const requestedLocks: Array<Promise<() => void>> = []
162
+ const modelId = config.id
163
+ const modelCacheDir = path.join(env.cacheDir, modelId)
164
+ fs.mkdirSync(modelCacheDir, { recursive: true })
165
+ requestedLocks.push(acquireFileLock(modelCacheDir, signal))
166
+ if (config.visionModel?.processor?.url) {
167
+ const { modelId } = parseHuggingfaceModelIdAndBranch(config.visionModel.processor.url)
168
+ const processorCacheDir = path.join(env.cacheDir, modelId)
169
+ fs.mkdirSync(processorCacheDir, { recursive: true })
170
+ requestedLocks.push(acquireFileLock(processorCacheDir, signal))
171
+ }
172
+ const acquiredLocks = await Promise.all(requestedLocks)
173
+ return () => {
174
+ for (const releaseLock of acquiredLocks) {
175
+ releaseLock()
176
+ }
177
+ }
178
+ }
179
+
180
+ export async function prepareModel(
181
+ { config, log }: EngineContext<TransformersJsModelConfig>,
182
+ onProgress?: (progress: FileDownloadProgress) => void,
183
+ signal?: AbortSignal,
184
+ ) {
185
+ if (!didConfigureEnvironment) {
186
+ configureEnvironment(config.modelsCachePath)
187
+ }
188
+ fs.mkdirSync(config.location, { recursive: true })
189
+ const releaseFileLocks = await acquireModelFileLocks(config, signal)
190
+ if (signal?.aborted) {
191
+ releaseFileLocks()
192
+ return
193
+ }
194
+ log(LogLevels.info, `Preparing transformers.js model at ${config.location}`, {
195
+ model: config.id,
196
+ })
197
+
198
+ const downloadModelFiles = async (
199
+ modelOpts: TransformersJsModel,
200
+ { modelId, branch }: { modelId: string; branch: string },
201
+ requiredComponents: string[] = ['model', 'tokenizer', 'processor'],
202
+ ) => {
203
+ const modelClass = modelOpts.modelClass ?? AutoModel
204
+ const downloadPromises: Record<string, Promise<any> | undefined> = {}
205
+ const progressCallback = (progress: TransformersJsDownloadProgress) => {
206
+ if (onProgress && progress.status === 'progress') {
207
+ onProgress({
208
+ file: env.cacheDir + progress.name + '/' + progress.file,
209
+ loadedBytes: progress.loaded,
210
+ totalBytes: progress.total,
211
+ })
212
+ }
213
+ }
214
+ if (requiredComponents.includes('model')) {
215
+ const modelDownloadPromise = modelClass.from_pretrained(modelId, {
216
+ revision: branch,
217
+ dtype: modelOpts.dtype || 'fp32',
218
+ progress_callback: progressCallback,
219
+ // use_external_data_format: true, // https://github.com/xenova/transformers.js/blob/38a3bf6dab2265d9f0c2f613064535863194e6b9/src/models.js#L205-L207
220
+ })
221
+ downloadPromises.model = modelDownloadPromise
222
+ }
223
+ if (requiredComponents.includes('tokenizer')) {
224
+ const hasTokenizer = await remoteFileExists(`${config.url}/blob/${branch}/tokenizer.json`)
225
+ if (hasTokenizer) {
226
+ const tokenizerClass = modelOpts.tokenizerClass ?? AutoTokenizer
227
+ const tokenizerDownload = tokenizerClass.from_pretrained(modelId, {
228
+ revision: branch,
229
+ progress_callback: progressCallback,
230
+ // use_external_data_format: true,
231
+ })
232
+ downloadPromises.tokenizer = tokenizerDownload
233
+ }
234
+ }
235
+
236
+ if (requiredComponents.includes('processor')) {
237
+ if (modelOpts.processor?.url) {
238
+ const { modelId, branch } = parseHuggingfaceModelIdAndBranch(modelOpts.processor.url)
239
+ const processorDownload = AutoProcessor.from_pretrained(modelId, {
240
+ revision: branch,
241
+ progress_callback: progressCallback,
242
+ })
243
+ downloadPromises.processor = processorDownload
244
+ } else {
245
+ const [hasProcessor, hasPreprocessor] = await Promise.all([
246
+ remoteFileExists(`${config.url}/blob/${branch}/processor_config.json`),
247
+ remoteFileExists(`${config.url}/blob/${branch}/preprocessor_config.json`),
248
+ ])
249
+ if (hasProcessor || hasPreprocessor) {
250
+ const processorDownload = AutoProcessor.from_pretrained(modelId, {
251
+ revision: branch,
252
+ progress_callback: progressCallback,
253
+ // use_external_data_format: true,
254
+ })
255
+ downloadPromises.processor = processorDownload
256
+ }
257
+ }
258
+ }
259
+ await Promise.all(Object.values(downloadPromises))
260
+ const modelComponents: TransformersJsModelComponents = {}
261
+ if (downloadPromises.model) {
262
+ modelComponents.model = (await downloadPromises.model) as PreTrainedModel
263
+ }
264
+ if (downloadPromises.tokenizer) {
265
+ modelComponents.tokenizer = (await downloadPromises.tokenizer) as PreTrainedTokenizer
266
+ }
267
+ if (downloadPromises.processor) {
268
+ modelComponents.processor = (await downloadPromises.processor) as Processor
269
+ }
270
+ return modelComponents
271
+ }
272
+
273
+ const downloadModel = async (validationResult: ModelValidationResult) => {
274
+ log(LogLevels.info, `${validationResult.message} - Downloading files`, {
275
+ model: config.id,
276
+ url: config.url,
277
+ location: config.location,
278
+ errors: validationResult.errors,
279
+ })
280
+ const modelDownloadPromises = []
281
+ if (!config.url) {
282
+ throw new Error(`Missing URL for model ${config.id}`)
283
+ }
284
+ const { modelId, branch } = parseHuggingfaceModelIdAndBranch(config.url)
285
+ const directoriesToCopy: Record<string, string> = {}
286
+ const modelCacheDir = path.join(env.cacheDir, modelId)
287
+ directoriesToCopy[modelCacheDir] = config.location
288
+ const noModelConfigured = !config.textModel && !config.visionModel && !config.speechModel
289
+ if (config.textModel || noModelConfigured) {
290
+ const requiredComponents = validationResult.errors?.textModel
291
+ ? Object.keys(validationResult.errors.textModel)
292
+ : undefined
293
+ modelDownloadPromises.push(downloadModelFiles(config.textModel || {}, { modelId, branch }, requiredComponents))
294
+ }
295
+ if (config.visionModel) {
296
+ const requiredComponents = validationResult.errors?.visionModel
297
+ ? Object.keys(validationResult.errors.visionModel)
298
+ : undefined
299
+ modelDownloadPromises.push(downloadModelFiles(config.visionModel, { modelId, branch }, requiredComponents))
300
+ if (config.visionModel.processor?.url) {
301
+ const processorPath = resolveModelFileLocation({
302
+ url: config.visionModel.processor.url,
303
+ filePath: config.visionModel.processor.file,
304
+ modelsCachePath: config.modelsCachePath,
305
+ })
306
+ const { modelId } = parseHuggingfaceModelIdAndBranch(config.visionModel.processor.url)
307
+ const processorCacheDir = path.join(env.cacheDir, modelId)
308
+ directoriesToCopy[processorCacheDir] = processorPath
309
+ }
310
+ }
311
+ if (config.speechModel) {
312
+ const requiredComponents = validationResult.errors?.speechModel
313
+ ? Object.keys(validationResult.errors.speechModel)
314
+ : undefined
315
+ modelDownloadPromises.push(downloadModelFiles(config.speechModel, { modelId, branch }, requiredComponents))
316
+ }
317
+ const models = await Promise.all(modelDownloadPromises)
318
+ for (const modelComponents of models) {
319
+ disposeModelComponents(modelComponents)
320
+ }
321
+ if (signal?.aborted) {
322
+ return
323
+ }
324
+ // copy all downloads to their actual location, then remove the cache so we dont duplicate
325
+ await Promise.all(Object.entries(directoriesToCopy).map(async ([from, to]) => {
326
+ await copyDirectory(from, to)
327
+ await fs.promises.rmdir(from, { recursive: true })
328
+ }))
329
+ }
330
+
331
+ try {
332
+ const validationResults = await validateModelFiles(config)
333
+ if (signal?.aborted) {
334
+ releaseFileLocks()
335
+ return
336
+ }
337
+ if (validationResults) {
338
+ if (config.url) {
339
+ await downloadModel(validationResults)
340
+ } else {
341
+ throw new Error(`Model files are invalid: ${validationResults.message}`)
342
+ }
343
+ }
344
+ } catch (error) {
345
+ throw error
346
+ } finally {
347
+ releaseFileLocks()
348
+ }
349
+ const configMeta: Record<string, any> = {}
350
+ const fileList: ModelFile[] = []
351
+ const modelFiles = fs.readdirSync(config.location, { recursive: true })
352
+
353
+ const pushFile = (file: string) => {
354
+ const targetFile = path.join(config.location, file)
355
+ const targetStat = fs.statSync(targetFile)
356
+ fileList.push({
357
+ file: targetFile,
358
+ size: targetStat.size,
359
+ })
360
+ if (targetFile.endsWith('.json')) {
361
+ const key = path.basename(targetFile).replace('.json', '')
362
+ configMeta[key] = JSON.parse(fs.readFileSync(targetFile, 'utf8'))
363
+ }
364
+ }
365
+ // add model files to the list
366
+ for (const file of modelFiles) {
367
+ pushFile(file.toString())
368
+ }
369
+
370
+ // add extra stuff from external repos
371
+ if (config.visionModel?.processor) {
372
+ const processorPath = resolveModelFileLocation({
373
+ url: config.visionModel.processor.url,
374
+ filePath: config.visionModel.processor.file,
375
+ modelsCachePath: config.modelsCachePath,
376
+ })
377
+ const processorFiles = fs.readdirSync(processorPath, { recursive: true })
378
+ for (const file of processorFiles) {
379
+ pushFile(file.toString())
380
+ }
381
+ }
382
+ return {
383
+ files: modelFiles,
384
+ ...configMeta,
385
+ }
386
+ }
387
+
388
+ export async function createInstance({ config, log }: EngineContext<TransformersJsModelConfig>, signal?: AbortSignal) {
389
+ const modelLoadPromises = []
390
+ const noModelConfigured = !config.textModel && !config.visionModel && !config.speechModel
391
+
392
+ if (config.textModel || noModelConfigured) {
393
+ modelLoadPromises.push(loadModelComponents(config.textModel || {}, config))
394
+ } else {
395
+ modelLoadPromises.push(Promise.resolve(undefined))
396
+ }
397
+ if (config.visionModel) {
398
+ modelLoadPromises.push(loadModelComponents(config.visionModel, config))
399
+ } else {
400
+ modelLoadPromises.push(Promise.resolve(undefined))
401
+ }
402
+ if (config.speechModel) {
403
+ modelLoadPromises.push(loadModelComponents(config.speechModel, config))
404
+ } else {
405
+ modelLoadPromises.push(Promise.resolve(undefined))
406
+ }
407
+
408
+ const models = await Promise.all(modelLoadPromises)
409
+ const instance: TransformersJsInstance = {
410
+ textModel: models[0],
411
+ visionModel: models[1],
412
+ speechModel: models[2],
413
+ }
414
+
415
+ // TODO preload whisper / any speech to text?
416
+ // await model.generate({
417
+ // input_features: full([1, 80, 3000], 0.0),
418
+ // max_new_tokens: 1,
419
+ // });
420
+
421
+ return instance
422
+ }
423
+
424
+ export async function disposeInstance(instance: TransformersJsInstance) {
425
+ const disposePromises = []
426
+ if (instance.textModel) {
427
+ disposePromises.push(disposeModelComponents(instance.textModel))
428
+ }
429
+ if (instance.visionModel) {
430
+ disposePromises.push(disposeModelComponents(instance.visionModel))
431
+ }
432
+ if (instance.speechModel) {
433
+ disposePromises.push(disposeModelComponents(instance.speechModel))
434
+ }
435
+ await Promise.all(disposePromises)
436
+ }
437
+
438
/**
 * Generates a text completion for request.prompt using the loaded text model.
 * NOTE(review): onChunk and signal are accepted but never used — streaming
 * and cancellation are not implemented here.
 */
export async function processTextCompletionTask(
	{ request, config, log, onChunk }: EngineTextCompletionArgs<TransformersJsModelConfig>,
	instance: TransformersJsInstance,
	signal?: AbortSignal,
): Promise<EngineTextCompletionResult> {
	if (!request.prompt) {
		throw new Error('Prompt is required for text completion.')
	}
	if (!(instance.textModel?.tokenizer && instance.textModel?.model)) {
		throw new Error('Text model is not loaded.')
	}
	const inputTokens = instance.textModel.tokenizer(request.prompt)
	const outputTokens = await instance.textModel.model.generate({
		...inputTokens,
		max_new_tokens: request.maxTokens ?? 128,
	})
	// @ts-ignore
	const outputText = instance.textModel.tokenizer.batch_decode(outputTokens, {
		skip_special_tokens: true,
	})

	// NOTE(review): `inputTokens` is the tokenizer's encoding object (not an
	// array), so `.length` is likely undefined here — elsewhere in this file
	// the token count is read via `input_ids.size`. Likewise `outputTokens`
	// is the generate() output tensor; verify all three token counts below.
	return {
		finishReason: 'eogToken',
		text: outputText[0],
		promptTokens: inputTokens.length,
		// @ts-ignore
		completionTokens: outputTokens.length,
		// @ts-ignore
		contextTokens: inputTokens.length + outputTokens.length,
	}
}
469
+
470
+ // see https://github.com/xenova/transformers.js/blob/v3/src/utils/tensor.js
471
+ // https://github.com/xenova/transformers.js/blob/v3/src/pipelines.js#L1284
472
/**
 * Computes embeddings for text and/or image inputs.
 * Text inputs run through the text model's tokenizer+model; image inputs
 * through the vision model's processor+model. Optional pooling
 * ('mean' | 'cls') and dimension truncation are applied per input.
 * Aborting via signal stops processing further inputs; embeddings already
 * computed are still returned.
 */
export async function processEmbeddingTask(
	{ request, config }: EngineEmbeddingArgs<TransformersJsModelConfig>,
	instance: TransformersJsInstance,
	signal?: AbortSignal,
): Promise<EngineEmbeddingResult> {
	if (!request.input) {
		throw new Error('Input is required for embedding.')
	}
	// Normalize to an array of typed inputs; bare strings become text inputs.
	const inputs = Array.isArray(request.input) ? request.input : [request.input]
	const normalizedInputs: Array<TextEmbeddingInput | ImageEmbeddingInput> = inputs.map((input) => {
		if (typeof input === 'string') {
			return {
				type: 'text',
				content: input,
			}
		} else if (input.type) {
			return input
		} else {
			throw new Error('Invalid input type')
		}
	})

	const embeddings: Float32Array[] = []
	let inputTokens = 0

	// 'mean' averages token embeddings using the attention mask; 'cls' slices
	// out the first (CLS) token. Any other value is rejected.
	const applyPooling = (result: any, pooling: string, modelInputs: any) => {
		if (pooling === 'mean') {
			return mean_pooling(result, modelInputs.attention_mask)
		} else if (pooling === 'cls') {
			return result.slice(null, 0)
		} else {
			throw Error(`Pooling method '${pooling}' not supported.`)
		}
	}

	// Copies only the first `dimensions` values of the embedding data.
	const truncateDimensions = (result: any, dimensions: number) => {
		const truncatedData = new Float32Array(dimensions)
		truncatedData.set(result.data.slice(0, dimensions))
		return truncatedData
	}

	for (const embeddingInput of normalizedInputs) {
		if (signal?.aborted) {
			break
		}
		let result
		let modelInputs
		if (embeddingInput.type === 'text') {
			if (!instance.textModel?.tokenizer || !instance.textModel?.model) {
				throw new Error('Text model is not loaded.')
			}
			modelInputs = instance.textModel.tokenizer(embeddingInput.content, {
				padding: true, // pads input if it is shorter than context window
				truncation: true, // truncates input if it exceeds context window
			})
			inputTokens += modelInputs.input_ids.size
			const modelOutputs = await instance.textModel.model(modelInputs)
			// Different architectures expose the hidden state under different keys.
			result =
				modelOutputs.last_hidden_state ??
				modelOutputs.logits ??
				modelOutputs.token_embeddings ??
				modelOutputs.text_embeds
		} else if (embeddingInput.type === 'image') {
			if (!instance.visionModel?.processor || !instance.visionModel?.model) {
				throw new Error('Vision model is not loaded.')
			}
			// Convert the image handle's raw pixels into a RawImage for the
			// processor. NOTE(review): assumes content.handle is a sharp-style
			// instance (raw().toBuffer with resolveWithObject) — confirm.
			const { data, info } = await embeddingInput.content.handle.raw().toBuffer({ resolveWithObject: true })
			const image = new RawImage(new Uint8ClampedArray(data), info.width, info.height, info.channels)
			modelInputs = await instance.visionModel.processor!(image)
			const modelOutputs = await instance.visionModel.model(modelInputs)
			result = modelOutputs.last_hidden_state ?? modelOutputs.logits ?? modelOutputs.image_embeds
		}

		if (request.pooling) {
			result = applyPooling(result, request.pooling, modelInputs)
		}
		if (request.dimensions && result.data.length > request.dimensions) {
			embeddings.push(truncateDimensions(result, request.dimensions))
		} else {
			embeddings.push(result.data)
		}
	}

	return {
		embeddings,
		inputTokens,
	}
}
560
+
561
+ export async function processImageToTextTask(
562
+ { request, config, log }: EngineImageToTextArgs,
563
+ instance: TransformersJsInstance,
564
+ signal?: AbortSignal,
565
+ ) {
566
+ if (!request.image) {
567
+ throw new Error('No image provided')
568
+ }
569
+ const { data, info } = await request.image.handle.raw().toBuffer({ resolveWithObject: true })
570
+ const image = new RawImage(new Uint8ClampedArray(data), info.width, info.height, info.channels)
571
+
572
+ if (signal?.aborted) {
573
+ return
574
+ }
575
+
576
+ const model = instance.visionModel || instance.textModel
577
+ if (!(model && model.tokenizer && model.processor && model.model)) {
578
+ throw new Error('No model loaded')
579
+ }
580
+ let textInputs = {}
581
+ if (request.prompt) {
582
+ textInputs = model!.tokenizer(request.prompt)
583
+ }
584
+ const imageInputs = await model.processor(image)
585
+ const outputTokens = await model.model.generate({
586
+ ...textInputs,
587
+ ...imageInputs,
588
+ max_new_tokens: request.maxTokens ?? 128,
589
+ })
590
+ // @ts-ignore
591
+ const outputText = model.tokenizer.batch_decode(outputTokens, {
592
+ skip_special_tokens: true,
593
+ })
594
+
595
+ return {
596
+ text: outputText[0],
597
+ }
598
+ }
599
+
600
+ async function readAudioFile(filePath: string) {
601
+ const WHISPER_SAMPLING_RATE = 16_000
602
+ const MAX_AUDIO_LENGTH = 30 // seconds
603
+ const MAX_SAMPLES = WHISPER_SAMPLING_RATE * MAX_AUDIO_LENGTH
604
+ // Read the file into a buffer
605
+ const fileBuffer = fs.readFileSync(filePath)
606
+
607
+ // Decode the audio data
608
+ let decodedAudio = await decodeAudio(fileBuffer, WHISPER_SAMPLING_RATE)
609
+
610
+ // Trim the audio data if it exceeds MAX_SAMPLES
611
+ if (decodedAudio.length > MAX_SAMPLES) {
612
+ decodedAudio = decodedAudio.slice(-MAX_SAMPLES)
613
+ }
614
+
615
+ return decodedAudio
616
+ }
617
+
618
+ // see examples
619
+ // https://huggingface.co/docs/transformers.js/guides/node-audio-processing
620
+ // https://github.com/xenova/transformers.js/tree/v3/examples/node-audio-processing
621
+ export async function processSpeechToTextTask(
622
+ { request, onChunk }: EngineSpeechToTextArgs,
623
+ instance: TransformersJsInstance,
624
+ signal?: AbortSignal,
625
+ ) {
626
+ if (!(instance.speechModel?.tokenizer && instance.speechModel?.model)) {
627
+ throw new Error('No speech model loaded')
628
+ }
629
+ const streamer = new TextStreamer(instance.speechModel.tokenizer, {
630
+ skip_prompt: true,
631
+ // skip_special_tokens: true,
632
+ callback_function: (output: any) => {
633
+ if (onChunk) {
634
+ onChunk({ text: output })
635
+ }
636
+ },
637
+ })
638
+ let inputs
639
+ if (request.file) {
640
+ const audio = await readAudioFile(request.file)
641
+ inputs = await instance.speechModel.processor!(audio)
642
+ }
643
+
644
+ const outputs = await instance.speechModel.model.generate({
645
+ ...inputs,
646
+ max_new_tokens: request.maxTokens ?? 128,
647
+ language: request.language ?? 'en',
648
+ streamer,
649
+ })
650
+
651
+ // @ts-ignore
652
+ const outputText = instance.speechModel.tokenizer.batch_decode(outputs, {
653
+ skip_special_tokens: true,
654
+ })
655
+
656
+ return {
657
+ text: outputText[0],
658
+ }
659
+ }
@@ -0,0 +1,25 @@
1
+ import type {
2
+ Processor,
3
+ PreTrainedModel,
4
+ PreTrainedTokenizer,
5
+ PretrainedMixin,
6
+ AutoProcessor,
7
+ // DataType, // this is the tensor.js DataType type, not the one I'm looking for
8
+ } from '@huggingface/transformers'
9
+ // import type { DataType } from '@huggingface/transformers/src/utils/dtypes.js' // this is the dtypes.js DataType type; can't import it from there
10
+
11
/** Static side of a transformers.js PreTrainedModel class (used for loading). */
export type TransformersJsModelClass = typeof PreTrainedModel
/** Static side of a transformers.js PreTrainedTokenizer class. */
export type TransformersJsTokenizerClass = typeof PreTrainedTokenizer
/** Anything exposing AutoProcessor's `from_pretrained` factory signature. */
export interface TransformersJsProcessorClass {
  from_pretrained: (typeof AutoProcessor)['from_pretrained']
}
16
+
17
+ export type TransformersJsDataType =
18
+ | 'fp32'
19
+ | 'fp16'
20
+ | 'q8'
21
+ | 'int8'
22
+ | 'uint8'
23
+ | 'q4'
24
+ | 'bnb4'
25
+ | 'q4f16'