@workglow/huggingface-transformers 0.2.28
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/ai-provider/HuggingFaceTransformersProvider.d.ts +28 -0
- package/dist/ai-provider/HuggingFaceTransformersProvider.d.ts.map +1 -0
- package/dist/ai-provider/HuggingFaceTransformersQueuedProvider.d.ts +29 -0
- package/dist/ai-provider/HuggingFaceTransformersQueuedProvider.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_BackgroundRemoval.d.ts +12 -0
- package/dist/ai-provider/common/HFT_BackgroundRemoval.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_Chat.d.ts +10 -0
- package/dist/ai-provider/common/HFT_Chat.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_Constants.d.ts +95 -0
- package/dist/ai-provider/common/HFT_Constants.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_CountTokens.d.ts +10 -0
- package/dist/ai-provider/common/HFT_CountTokens.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_Download.d.ts +13 -0
- package/dist/ai-provider/common/HFT_Download.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_ImageClassification.d.ts +13 -0
- package/dist/ai-provider/common/HFT_ImageClassification.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_ImageEmbedding.d.ts +12 -0
- package/dist/ai-provider/common/HFT_ImageEmbedding.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_ImageSegmentation.d.ts +12 -0
- package/dist/ai-provider/common/HFT_ImageSegmentation.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_ImageToText.d.ts +12 -0
- package/dist/ai-provider/common/HFT_ImageToText.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_InlineLifecycle.d.ts +7 -0
- package/dist/ai-provider/common/HFT_InlineLifecycle.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_JobRunFns.d.ts +19 -0
- package/dist/ai-provider/common/HFT_JobRunFns.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_ModelInfo.d.ts +9 -0
- package/dist/ai-provider/common/HFT_ModelInfo.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_ModelSchema.d.ts +368 -0
- package/dist/ai-provider/common/HFT_ModelSchema.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_ModelSearch.d.ts +8 -0
- package/dist/ai-provider/common/HFT_ModelSearch.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_ObjectDetection.d.ts +13 -0
- package/dist/ai-provider/common/HFT_ObjectDetection.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_OnnxDtypes.d.ts +27 -0
- package/dist/ai-provider/common/HFT_OnnxDtypes.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_Pipeline.d.ts +62 -0
- package/dist/ai-provider/common/HFT_Pipeline.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_Streaming.d.ts +25 -0
- package/dist/ai-provider/common/HFT_Streaming.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_StructuredGeneration.d.ts +10 -0
- package/dist/ai-provider/common/HFT_StructuredGeneration.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_TextClassification.d.ts +9 -0
- package/dist/ai-provider/common/HFT_TextClassification.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_TextEmbedding.d.ts +13 -0
- package/dist/ai-provider/common/HFT_TextEmbedding.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_TextFillMask.d.ts +9 -0
- package/dist/ai-provider/common/HFT_TextFillMask.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_TextGeneration.d.ts +14 -0
- package/dist/ai-provider/common/HFT_TextGeneration.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_TextLanguageDetection.d.ts +9 -0
- package/dist/ai-provider/common/HFT_TextLanguageDetection.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_TextNamedEntityRecognition.d.ts +9 -0
- package/dist/ai-provider/common/HFT_TextNamedEntityRecognition.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_TextOutput.d.ts +8 -0
- package/dist/ai-provider/common/HFT_TextOutput.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_TextQuestionAnswer.d.ts +14 -0
- package/dist/ai-provider/common/HFT_TextQuestionAnswer.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_TextRewriter.d.ts +14 -0
- package/dist/ai-provider/common/HFT_TextRewriter.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_TextSummary.d.ts +14 -0
- package/dist/ai-provider/common/HFT_TextSummary.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_TextTranslation.d.ts +14 -0
- package/dist/ai-provider/common/HFT_TextTranslation.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_ToolCalling.d.ts +19 -0
- package/dist/ai-provider/common/HFT_ToolCalling.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_ToolMarkup.d.ts +20 -0
- package/dist/ai-provider/common/HFT_ToolMarkup.d.ts.map +1 -0
- package/dist/ai-provider/common/HFT_Unload.d.ts +13 -0
- package/dist/ai-provider/common/HFT_Unload.d.ts.map +1 -0
- package/dist/ai-provider/index.d.ts +13 -0
- package/dist/ai-provider/index.d.ts.map +1 -0
- package/dist/ai-provider/registerHuggingFaceTransformers.d.ts +14 -0
- package/dist/ai-provider/registerHuggingFaceTransformers.d.ts.map +1 -0
- package/dist/ai-provider/registerHuggingFaceTransformersInline.d.ts +15 -0
- package/dist/ai-provider/registerHuggingFaceTransformersInline.d.ts.map +1 -0
- package/dist/ai-provider/registerHuggingFaceTransformersWorker.d.ts +7 -0
- package/dist/ai-provider/registerHuggingFaceTransformersWorker.d.ts.map +1 -0
- package/dist/ai-provider/runtime.d.ts +21 -0
- package/dist/ai-provider/runtime.d.ts.map +1 -0
- package/dist/ai-provider-runtime.d.ts +7 -0
- package/dist/ai-provider-runtime.d.ts.map +1 -0
- package/dist/ai-provider-runtime.js +2367 -0
- package/dist/ai-provider-runtime.js.map +46 -0
- package/dist/ai-provider.d.ts +7 -0
- package/dist/ai-provider.d.ts.map +1 -0
- package/dist/ai-provider.js +879 -0
- package/dist/ai-provider.js.map +17 -0
- package/package.json +60 -0
|
@@ -0,0 +1,2367 @@
|
|
|
1
|
+
// --- Bundler (Bun/esbuild-style) CommonJS/ESM interop runtime helpers. ---
var __defProp = Object.defineProperty;
// Identity function; bound with a fixed value it becomes a constant getter.
var __returnValue = (v) => v;
// Setter installed on each export: replaces the live getter in the getter map
// (`this` is the `all` map, bound via __exportSetter.bind(all, name) below)
// so later writes to an exported binding are observed by readers.
function __exportSetter(name, newValue) {
  this[name] = __returnValue.bind(null, newValue);
}
// Defines an enumerable, configurable live getter/setter pair on `target`
// for every entry of `all` (name -> getter function).
var __export = (target, all) => {
  for (var name in all)
    __defProp(target, name, {
      get: all[name],
      enumerable: true,
      configurable: true,
      set: __exportSetter.bind(all, name)
    });
};
// Runs a lazy module initializer exactly once (fn is zeroed after the first
// call) and returns the cached result on subsequent invocations.
var __esm = (fn, res) => () => (fn && (res = fn(fn = 0)), res);
// require() shim: prefers the host `require` when present; otherwise a Proxy
// that re-checks at access time, falling back to a thrower for dynamic requires.
var __require = /* @__PURE__ */ ((x) => typeof require !== "undefined" ? require : typeof Proxy !== "undefined" ? new Proxy(x, {
  get: (a, b) => (typeof require !== "undefined" ? require : a)[b]
}) : x)(function(x) {
  if (typeof require !== "undefined")
    return require.apply(this, arguments);
  throw Error('Dynamic require of "' + x + '" is not supported');
});
|
|
23
|
+
|
|
24
|
+
// src/ai-provider/common/HFT_Pipeline.ts
// Namespace object for the HFT_Pipeline module; populated below with live
// getters so the bindings stay current even if reassigned later.
var exports_HFT_Pipeline = {};
__export(exports_HFT_Pipeline, {
  setHftSession: () => setHftSession,
  setHftCacheDir: () => setHftCacheDir,
  removeCachedPipeline: () => removeCachedPipeline,
  loadTransformersSDK: () => loadTransformersSDK,
  hasCachedPipeline: () => hasCachedPipeline,
  getPipelineCacheKey: () => getPipelineCacheKey,
  getPipeline: () => getPipeline,
  getHftSession: () => getHftSession,
  disposeHftSessionsForModel: () => disposeHftSessionsForModel,
  deleteHftSession: () => deleteHftSession,
  clearPipelineCache: () => clearPipelineCache,
  HFT_NULL_PROCESSOR_PREFIX: () => HFT_NULL_PROCESSOR_PREFIX
});
import { getLogger } from "@workglow/util/worker";
|
|
41
|
+
/**
 * Records the cache directory for model downloads. If the transformers SDK
 * has already been loaded, the new directory is applied to it immediately;
 * otherwise loadTransformersSDK applies it on first load.
 */
function setHftCacheDir(dir) {
  _cacheDir = dir;
  const sdk = _transformersSdk;
  if (sdk) {
    sdk.env.cacheDir = dir;
  }
}
|
|
47
|
+
/**
 * Lazily imports @huggingface/transformers, memoizing the module in
 * _transformersSdk. On first load, installs abortableFetch as the SDK's fetch
 * and applies any cache directory set earlier via setHftCacheDir.
 *
 * Fix: the original `catch {}` discarded the underlying import/setup error,
 * hiding the real failure (e.g. a broken install vs. a missing package).
 * The thrown Error now carries it as `cause`, and the memoized module var is
 * reset so a later call retries from a clean state.
 *
 * @returns {Promise<object>} the loaded transformers.js module namespace.
 * @throws {Error} when the SDK cannot be imported or initialized.
 */
async function loadTransformersSDK() {
  if (!_transformersSdk) {
    try {
      _transformersSdk = await import("@huggingface/transformers");
      // Route all SDK downloads through our abort-aware fetch wrapper.
      _transformersSdk.env.fetch = abortableFetch;
      if (_cacheDir) {
        _transformersSdk.env.cacheDir = _cacheDir;
      }
    } catch (err) {
      // Do not leave a partially-initialized SDK cached.
      _transformersSdk = undefined;
      throw new Error("@huggingface/transformers is required for HuggingFace Transformers tasks. Install it with: bun add @huggingface/transformers", { cause: err });
    }
  }
  return _transformersSdk;
}
|
|
61
|
+
/**
 * Merges two optional AbortSignals into one signal that aborts as soon as
 * either input aborts, preserving the abort reason.
 * - If only one signal is provided, it is returned unchanged.
 * - If either is already aborted, an immediately-aborted signal is returned.
 * - Prefers AbortSignal.any when the runtime provides it.
 */
function combineAbortSignals(existingSignal, modelSignal) {
  if (!existingSignal || !modelSignal) {
    return existingSignal ?? modelSignal;
  }
  const alreadyAborted = existingSignal.aborted || modelSignal.aborted;
  if (alreadyAborted) {
    return AbortSignal.abort(existingSignal.reason ?? modelSignal.reason);
  }
  if (typeof AbortSignal.any === "function") {
    return AbortSignal.any([existingSignal, modelSignal]);
  }
  // Fallback for runtimes without AbortSignal.any: relay the first abort.
  const relay = new AbortController();
  const forwardAbort = (event) => {
    relay.abort(event.target.reason);
  };
  for (const source of [existingSignal, modelSignal]) {
    source.addEventListener("abort", forwardAbort, { once: true });
  }
  return relay.signal;
}
|
|
83
|
+
/**
 * Normalizes an aborted signal's `reason` into an Error instance: Error
 * reasons pass through untouched, anything else is stringified (with
 * "Fetch aborted" as the fallback for a nullish reason).
 */
function createAbortError(signal) {
  const { reason } = signal;
  return reason instanceof Error ? reason : new Error(String(reason ?? "Fetch aborted"));
}
|
|
90
|
+
/**
 * Wraps a fetch Response so its body stream errors promptly when `signal`
 * aborts, and fails if the stream ends short of the declared content-length.
 * Returns the original response unchanged when there is no signal or no body.
 */
function wrapAbortableResponse(response, signal) {
  if (!signal || !response.body) {
    return response;
  }
  // Only trust a purely numeric content-length header for the size check.
  const contentLengthHeader = response.headers.get("content-length");
  const expectedSize = contentLengthHeader && /^\d+$/.test(contentLengthHeader) ? Number.parseInt(contentLengthHeader, 10) : undefined;
  const sourceBody = response.body;
  let reader;
  let abortHandler;
  let loaded = 0; // bytes enqueued so far
  // Detach the abort listener and release the source reader's lock; safe to
  // call more than once (abortHandler is cleared, releaseLock is optional-chained).
  const cleanup = () => {
    if (abortHandler) {
      signal.removeEventListener("abort", abortHandler);
      abortHandler = undefined;
    }
    reader?.releaseLock();
  };
  const body = new ReadableStream({
    start(controller) {
      reader = sourceBody.getReader();
      // If already aborted, fail the stream before any pull happens.
      if (signal.aborted) {
        controller.error(createAbortError(signal));
        return;
      }
      abortHandler = () => controller.error(createAbortError(signal));
      signal.addEventListener("abort", abortHandler, { once: true });
    },
    async pull(controller) {
      try {
        // Check before and after the read: the abort may land mid-read.
        if (signal.aborted) {
          throw createAbortError(signal);
        }
        const { done, value } = await reader.read();
        if (done) {
          if (signal.aborted) {
            throw createAbortError(signal);
          }
          // Short body: the connection ended before content-length bytes arrived.
          if (expectedSize !== undefined && loaded < expectedSize) {
            throw new Error(`Fetch ended before reading the full response body (${loaded}/${expectedSize} bytes)`);
          }
          cleanup();
          controller.close();
          return;
        }
        loaded += value.length;
        controller.enqueue(value);
      } catch (error) {
        cleanup();
        controller.error(error);
      }
    },
    cancel(reason) {
      cleanup();
      // Propagate the cancellation to the underlying network stream.
      return sourceBody.cancel(reason);
    }
  });
  // Re-wrap with the guarded body, preserving status and headers.
  return new Response(body, {
    headers: new Headers(response.headers),
    status: response.status,
    statusText: response.statusText
  });
}
|
|
152
|
+
/**
 * fetch wrapper installed as transformers.js env.fetch. Looks up any active
 * per-model AbortController whose model path appears in the request URL,
 * combines its signal with the caller's, and guards the response body via
 * wrapAbortableResponse so downloads stop promptly on abort.
 */
function abortableFetch(url, options) {
  let modelSignal;
  try {
    const { pathname } = new URL(url);
    for (const [modelPath, controller] of modelAbortControllers) {
      if (pathname.includes(`/${modelPath}/`)) {
        modelSignal = controller.signal;
        break;
      }
    }
  } catch {
    // Non-absolute URLs simply skip model-signal matching.
  }
  const callerSignal = options?.signal;
  const combinedSignal = callerSignal ? combineAbortSignals(callerSignal, modelSignal) : modelSignal;
  const fetchOptions = combinedSignal ? { ...options, signal: combinedSignal } : { ...options };
  return fetch(url, fetchOptions).then((response) => wrapAbortableResponse(response, combinedSignal));
}
|
|
166
|
+
/** Returns the cached HFT session state for `sessionId`, or undefined. */
function getHftSession(sessionId) {
  return hftSessions.get(sessionId);
}
|
|
169
|
+
/** Stores (or replaces) the HFT session state for `sessionId`. */
function setHftSession(sessionId, state) {
  hftSessions.set(sessionId, state);
}
|
|
172
|
+
/**
 * Releases the native/GPU resources held by a session's KV cache.
 * - "progressive" sessions: calls the cache's dispose hook when present.
 * - otherwise: disposes each base-entry tensor that lives in a GPU buffer
 *   and exposes a dispose function.
 */
function disposeSessionResources(session) {
  if (session.mode === "progressive") {
    const disposeCache = session.cache?.dispose;
    if (disposeCache) {
      session.cache.dispose();
    }
    return;
  }
  for (const tensor of Object.values(session.baseEntries)) {
    const onGpu = tensor?.location === "gpu-buffer";
    if (onGpu && typeof tensor.dispose === "function") {
      tensor.dispose();
    }
  }
}
|
|
185
|
+
/**
 * Removes a session from the registry, releasing its resources first.
 * @returns {boolean} true when an entry was actually deleted.
 */
function deleteHftSession(sessionId) {
  const session = hftSessions.get(sessionId);
  if (session) {
    disposeSessionResources(session);
  }
  return hftSessions.delete(sessionId);
}
|
|
192
|
+
/**
 * Disposes and removes every registered session that belongs to `modelPath`.
 * (Deleting entries while iterating a Map is safe per the spec.)
 */
function disposeHftSessionsForModel(modelPath) {
  for (const [sessionId, session] of hftSessions) {
    if (session.modelPath !== modelPath) {
      continue;
    }
    disposeSessionResources(session);
    hftSessions.delete(sessionId);
  }
}
|
|
200
|
+
/** Empties the pipeline instance cache (does not dispose pipelines). */
function clearPipelineCache() {
  pipelines.clear();
}
|
|
203
|
+
/** Whether a pipeline instance is cached under `cacheKey` (see getPipelineCacheKey). */
function hasCachedPipeline(cacheKey) {
  return pipelines.has(cacheKey);
}
|
|
206
|
+
/** Evicts the pipeline cached under `cacheKey`; returns true if one was present. */
function removeCachedPipeline(cacheKey) {
  return pipelines.delete(cacheKey);
}
|
|
209
|
+
/**
 * Heuristic: true when running in a browser main thread (`window` exists) or
 * a web worker (`WorkerGlobalScope` exists); false in Node/Bun.
 */
function isBrowserEnv() {
  if (typeof globalThis === "undefined") {
    return false;
  }
  const hasWindow = typeof globalThis.window !== "undefined";
  const hasWorkerScope = typeof globalThis.WorkerGlobalScope !== "undefined";
  return hasWindow || hasWorkerScope;
}
|
|
218
|
+
/**
 * Builds the cache key identifying a pipeline instance:
 * "<model_path>:<pipeline>:<dtype>:<device>:<revision>", where dtype/device
 * default to "" and revision defaults to "main" when unset.
 */
function getPipelineCacheKey(model) {
  const cfg = model.provider_config;
  const parts = [
    cfg.model_path,
    cfg.pipeline,
    cfg.dtype || "",
    cfg.device || "",
    cfg.revision || "main"
  ];
  return parts.join(":");
}
|
|
224
|
+
/**
 * Resolves (and caches) a transformers.js pipeline for `model`.
 * Concurrent loads of the same cache key are deduplicated through
 * pipelineLoadPromises; progress is reported via `onProgress` scaled to
 * [0, progressScaleMax], and `signal` aborts an in-progress download.
 */
async function getPipeline(model, onProgress, options = {}, signal, progressScaleMax = 10) {
  const cacheKey = getPipelineCacheKey(model);
  // Fast path: already loaded.
  if (pipelines.has(cacheKey)) {
    getLogger().debug("HFT pipeline cache hit", { cacheKey });
    return pipelines.get(cacheKey);
  }
  // Another caller is loading the same key: wait for it, then re-check the
  // cache. Its failure is swallowed here on purpose — this caller falls
  // through and starts a fresh load of its own below.
  const inFlight = pipelineLoadPromises.get(cacheKey);
  if (inFlight) {
    try {
      await inFlight;
    } catch {}
    const cached = pipelines.get(cacheKey);
    if (cached)
      return cached;
  }
  // Publish the new load promise before returning so concurrent callers join it;
  // always unpublish when settled, success or failure.
  const loadPromise = doGetPipeline(model, onProgress, options, progressScaleMax, cacheKey, signal).finally(() => {
    pipelineLoadPromises.delete(cacheKey);
  });
  pipelineLoadPromises.set(cacheKey, loadPromise);
  return loadPromise;
}
|
|
245
|
+
// Module state (initialized in init_HFT_Pipeline unless noted):
//   _transformersSdk        — memoized @huggingface/transformers module
//   _cacheDir               — cache directory applied to the SDK env
//   modelAbortControllers   — modelPath -> AbortController for in-flight downloads
//   pipelines               — cacheKey -> loaded pipeline instance
//   hftSessions             — sessionId -> session state
//   pipelineLoadPromises    — cacheKey -> in-flight load promise (dedupe)
//   IMAGE_PIPELINE_TYPES    — pipeline types that require an image processor
//   HFT_NULL_PROCESSOR_PREFIX — error-message marker for a missing processor
// doGetPipeline performs the actual (uncached) pipeline load with throttled
// progress reporting and abort support; getPipeline is the cached entry point.
var _transformersSdk, _cacheDir, modelAbortControllers, pipelines, hftSessions, pipelineLoadPromises, IMAGE_PIPELINE_TYPES, HFT_NULL_PROCESSOR_PREFIX = "HFT_NULL_PROCESSOR:", doGetPipeline = async (model, onProgress, options, progressScaleMax, cacheKey, signal) => {
  // --- Progress throttling state: at most one event per THROTTLE_MS,
  // --- except the first and final events which always go out immediately.
  let lastProgressTime = 0;
  let pendingProgress = null;
  let throttleTimer = null;
  const THROTTLE_MS = 160;
  // Shapes the details payload; omits `files` when the map is empty.
  const buildProgressDetails = (file, fileProgress, filesMap) => {
    const details = {
      file,
      progress: fileProgress
    };
    if (filesMap && Object.keys(filesMap).length > 0) {
      details.files = filesMap;
    }
    return details;
  };
  // Emits a progress event, throttled: first/final events flush immediately;
  // events arriving within THROTTLE_MS are coalesced into a trailing timer.
  const sendProgress = (progress, file, fileProgress, filesMap) => {
    const now = Date.now();
    const timeSinceLastEvent = now - lastProgressTime;
    const isFirst = lastProgressTime === 0;
    const isFinal = progress >= progressScaleMax;
    if (isFirst || isFinal) {
      if (throttleTimer) {
        clearTimeout(throttleTimer);
        throttleTimer = null;
      }
      pendingProgress = null;
      onProgress(Math.round(progress), "Downloading model", buildProgressDetails(file, fileProgress, filesMap));
      lastProgressTime = now;
      return;
    }
    if (timeSinceLastEvent < THROTTLE_MS) {
      // Too soon: remember the latest values and arm a trailing-edge flush.
      pendingProgress = { progress, file, fileProgress, filesMap };
      if (!throttleTimer) {
        const timeRemaining = Math.max(1, THROTTLE_MS - timeSinceLastEvent);
        throttleTimer = setTimeout(() => {
          throttleTimer = null;
          if (pendingProgress) {
            const p = pendingProgress;
            onProgress(Math.round(p.progress), "Downloading model", buildProgressDetails(p.file, p.fileProgress, p.filesMap));
            lastProgressTime = Date.now();
            pendingProgress = null;
          }
        }, timeRemaining);
      }
      return;
    }
    onProgress(Math.round(progress), "Downloading model", buildProgressDetails(file, fileProgress, filesMap));
    lastProgressTime = now;
    pendingProgress = null;
  };
  // --- Abort plumbing: register a per-model controller so abortableFetch
  // --- can cancel this model's downloads; mirror the caller's signal into it.
  const abortSignal = signal;
  const modelPath = model.provider_config.model_path;
  const modelController = new AbortController;
  modelAbortControllers.set(modelPath, modelController);
  if (abortSignal) {
    if (abortSignal.aborted) {
      modelController.abort();
    } else {
      abortSignal.addEventListener("abort", () => modelController.abort(), { once: true });
    }
  }
  // SDK progress callback: only "progress_total" events are surfaced, scaled
  // from 0-100 to 0-progressScaleMax, annotated with the first still-loading file.
  const progressCallback = (status) => {
    if (abortSignal?.aborted)
      return;
    if (status.status === "progress_total") {
      const totalStatus = status;
      const scaledProgress = totalStatus.progress * progressScaleMax / 100;
      let activeFile = "";
      let activeFileProgress = 0;
      const files = totalStatus.files;
      if (files) {
        for (const [file, info] of Object.entries(files)) {
          if (info.loaded < info.total) {
            activeFile = file;
            activeFileProgress = info.total > 0 ? info.loaded / info.total * 100 : 0;
            break;
          }
        }
        // All files complete: report the last one at 100%.
        if (!activeFile) {
          const fileNames = Object.keys(files);
          if (fileNames.length > 0) {
            activeFile = fileNames[fileNames.length - 1];
            activeFileProgress = 100;
          }
        }
      }
      sendProgress(scaledProgress, activeFile, activeFileProgress, files);
    }
  };
  // --- Device normalization: browsers only support wasm/webgpu (gpu->webgpu,
  // --- cpu->wasm, anything else->wasm); in Node, browser-only devices are
  // --- dropped so the SDK picks its native default.
  let device = model.provider_config.device;
  if (isBrowserEnv()) {
    if (device === "gpu") {
      device = "webgpu";
    }
    if (device === "cpu") {
      device = "wasm";
    }
    if (device !== "wasm" && device !== "webgpu") {
      device = "wasm";
    }
  } else {
    if (device === "wasm" || device === "webgpu") {
      device = undefined;
    }
  }
  const dtype = model.provider_config.dtype || "";
  // Caller-supplied `options` intentionally override config-derived values;
  // progress_callback always wins last.
  const pipelineOptions = {
    revision: model.provider_config.revision || "main",
    ...model.provider_config.use_external_data_format ? { useExternalDataFormat: model.provider_config.use_external_data_format } : {},
    ...dtype ? { dtype } : {},
    ...device ? { device } : {},
    ...options,
    progress_callback: progressCallback
  };
  if (abortSignal?.aborted) {
    modelAbortControllers.delete(modelPath);
    throw new Error("Operation aborted before pipeline creation");
  }
  const pipelineType = model.provider_config.pipeline;
  const { pipeline } = await loadTransformersSDK();
  const logger = getLogger();
  const pipelineTimerLabel = `hft:pipeline:${cacheKey}`;
  logger.time(pipelineTimerLabel, { pipelineType, modelPath });
  try {
    const result = await pipeline(pipelineType, model.provider_config.model_path, pipelineOptions);
    // Flush any coalesced progress event before finishing.
    if (throttleTimer) {
      clearTimeout(throttleTimer);
      throttleTimer = null;
    }
    const finalPending = pendingProgress;
    if (finalPending) {
      onProgress(Math.round(finalPending.progress), "Downloading model", buildProgressDetails(finalPending.file, finalPending.fileProgress, finalPending.filesMap));
      pendingProgress = null;
    }
    if (abortSignal?.aborted) {
      logger.timeEnd(pipelineTimerLabel, { status: "aborted" });
      throw new Error("Operation aborted after pipeline creation");
    }
    // Image pipelines without a processor indicate a corrupt/partial model
    // cache; the prefixed message lets callers detect this case specifically.
    if (IMAGE_PIPELINE_TYPES.has(pipelineType) && result.processor == null) {
      throw new Error(`${HFT_NULL_PROCESSOR_PREFIX} Image processor not initialized for ${pipelineType}/${modelPath}. Model cache may be incomplete.`);
    }
    logger.timeEnd(pipelineTimerLabel, { status: "loaded" });
    pipelines.set(cacheKey, result);
    return result;
  } catch (error) {
    logger.timeEnd(pipelineTimerLabel, { status: "error", error: String(error) });
    // Translate abort-induced failures into a uniform message, except the
    // null-processor error which must keep its detectable prefix.
    if (!error?.message?.startsWith(HFT_NULL_PROCESSOR_PREFIX) && (abortSignal?.aborted || modelController.signal.aborted)) {
      throw new Error("Pipeline download aborted");
    }
    throw error;
  } finally {
    modelAbortControllers.delete(modelPath);
    // Seed the SDK RNG: fixed seed for reproducible generation when
    // configured, otherwise the SDK's default (time-based) seeding.
    const { random } = await loadTransformersSDK();
    random.seed(model.provider_config.seed ?? undefined);
  }
};
|
|
401
|
+
// One-shot module initializer (run lazily via __esm): allocates the module
// registries and the set of pipeline types that must carry an image processor.
var init_HFT_Pipeline = __esm(() => {
  modelAbortControllers = new Map;
  pipelines = new Map;
  hftSessions = new Map;
  pipelineLoadPromises = new Map;
  // Pipelines for which doGetPipeline rejects a null `processor` as a
  // corrupt/incomplete model cache.
  IMAGE_PIPELINE_TYPES = new Set([
    "image-classification",
    "image-segmentation",
    "object-detection",
    "image-to-text",
    "image-feature-extraction",
    "zero-shot-image-classification",
    "depth-estimation",
    "mask-generation"
  ]);
});
|
|
417
|
+
|
|
418
|
+
// src/ai-provider/common/HFT_Constants.ts
// Provider discriminators and device-specific variants.
var HF_TRANSFORMERS_ONNX = "HF_TRANSFORMERS_ONNX";
var HF_TRANSFORMERS_ONNX_GPU = `${HF_TRANSFORMERS_ONNX}_gpu`;
var HF_TRANSFORMERS_ONNX_CPU = `${HF_TRANSFORMERS_ONNX}_cpu`;
var HTF_CACHE_NAME = "transformers-cache";
// Supported quantization dtypes. NOTE: key order feeds Object.values() for
// the schema enum — keep it stable.
var QuantizationDataType = {
  auto: "auto",
  fp32: "fp32",
  fp16: "fp16",
  q8: "q8",
  int8: "int8",
  uint8: "uint8",
  q4: "q4",
  bnb4: "bnb4",
  q4f16: "q4f16",
  q2: "q2",
  q2f16: "q2f16",
  q1: "q1",
  q1f16: "q1f16"
};
// Pipeline use cases grouped by modality; merged into PipelineUseCase below.
var TextPipelineUseCase = {
  "fill-mask": "fill-mask",
  "token-classification": "token-classification",
  "text-generation": "text-generation",
  "text2text-generation": "text2text-generation",
  "text-classification": "text-classification",
  summarization: "summarization",
  translation: "translation",
  "feature-extraction": "feature-extraction",
  "zero-shot-classification": "zero-shot-classification",
  "question-answering": "question-answering"
};
var VisionPipelineUseCase = {
  "background-removal": "background-removal",
  "image-segmentation": "image-segmentation",
  "depth-estimation": "depth-estimation",
  "image-classification": "image-classification",
  "image-to-image": "image-to-image",
  "image-to-text": "image-to-text",
  "object-detection": "object-detection",
  "image-feature-extraction": "image-feature-extraction"
};
var AudioPipelineUseCase = {
  "audio-classification": "audio-classification",
  "automatic-speech-recognition": "automatic-speech-recognition",
  "text-to-speech": "text-to-speech"
};
var MultimodalPipelineUseCase = {
  "document-question-answering": "document-question-answering",
  "image-to-text": "image-to-text",
  "zero-shot-audio-classification": "zero-shot-audio-classification",
  "zero-shot-image-classification": "zero-shot-image-classification",
  "zero-shot-object-detection": "zero-shot-object-detection"
};
// Union of all use cases ("image-to-text" appears in two groups; duplicate
// keys collapse harmlessly since the values are identical).
var PipelineUseCase = {
  ...TextPipelineUseCase,
  ...VisionPipelineUseCase,
  ...AudioPipelineUseCase,
  ...MultimodalPipelineUseCase
};
|
|
478
|
+
|
|
479
|
+
// src/ai-provider/common/HFT_ModelSchema.ts
import { ModelConfigSchema, ModelRecordSchema } from "@workglow/ai/worker";
// JSON Schema fragment describing the ONNX provider block of a model record.
// Embedding models ("feature-extraction") must additionally declare
// native_dimensions (enforced by the if/then clause below).
var HfTransformersOnnxModelSchema = {
  type: "object",
  properties: {
    provider: {
      const: HF_TRANSFORMERS_ONNX,
      description: "Discriminator: ONNX runtime backend."
    },
    provider_config: {
      type: "object",
      description: "ONNX runtime-specific options.",
      properties: {
        pipeline: {
          type: "string",
          enum: Object.values(PipelineUseCase),
          description: "Pipeline type for the ONNX model.",
          default: "text-generation"
        },
        model_path: {
          type: "string",
          description: "Filesystem path or URI for the ONNX model."
        },
        revision: {
          type: "string",
          description: "Git revision (branch, tag, or commit hash) of the model repository.",
          default: "main"
        },
        dtype: {
          type: "string",
          enum: Object.values(QuantizationDataType),
          description: "Data type for the ONNX model.",
          default: "auto"
        },
        device: {
          type: "string",
          enum: ["cpu", "gpu", "webgpu", "wasm", "metal"],
          description: "High-level device selection.",
          default: "webgpu"
        },
        execution_providers: {
          type: "array",
          items: { type: "string" },
          description: "Raw ONNX Runtime execution provider identifiers.",
          "x-ui-hidden": true
        },
        intra_op_num_threads: {
          type: "integer",
          minimum: 1
        },
        inter_op_num_threads: {
          type: "integer",
          minimum: 1
        },
        use_external_data_format: {
          type: "boolean",
          description: "Whether the model uses external data format."
        },
        native_dimensions: {
          type: "integer",
          description: "The native dimensions of the model."
        },
        pooling: {
          type: "string",
          enum: ["mean", "last_token", "cls"],
          description: "The pooling strategy to use for the model.",
          default: "mean"
        },
        normalize: {
          type: "boolean",
          description: "Whether the model uses normalization.",
          default: true
        },
        language_style: {
          type: "string",
          description: "The language style of the model."
        },
        seed: {
          type: "integer",
          description: "RNG seed passed to transformers.js sampling. Omit for time-based seeding; set for reproducible generation.",
          minimum: 0
        },
        mrl: {
          type: "boolean",
          description: "Whether the model uses matryoshka.",
          default: false
        }
      },
      required: ["model_path", "pipeline"],
      additionalProperties: false,
      // feature-extraction pipelines must state their embedding width.
      if: {
        properties: {
          pipeline: {
            const: "feature-extraction"
          }
        }
      },
      then: {
        required: ["native_dimensions"]
      }
    }
  },
  required: ["provider", "provider_config"],
  additionalProperties: true
};
// Base record schema merged with the ONNX provider fragment (stored models).
var HfTransformersOnnxModelRecordSchema = {
  type: "object",
  properties: {
    ...ModelRecordSchema.properties,
    ...HfTransformersOnnxModelSchema.properties
  },
  required: [...ModelRecordSchema.required, ...HfTransformersOnnxModelSchema.required],
  additionalProperties: false
};
// Base config schema merged with the ONNX provider fragment (inline configs).
var HfTransformersOnnxModelConfigSchema = {
  type: "object",
  properties: {
    ...ModelConfigSchema.properties,
    ...HfTransformersOnnxModelSchema.properties
  },
  required: [...ModelConfigSchema.required, ...HfTransformersOnnxModelSchema.required],
  additionalProperties: false
};
|
|
602
|
+
|
|
603
|
+
// src/ai-provider/common/HFT_OnnxDtypes.ts
// Maps each quantization dtype to the filename suffix carried by its .onnx
// weight file; fp32 weights carry no suffix.
var ONNX_QUANTIZATION_SUFFIX_MAPPING = {
  fp32: "",
  fp16: "_fp16",
  int8: "_int8",
  uint8: "_uint8",
  q8: "_quantized",
  q4: "_q4",
  q4f16: "_q4f16",
  bnb4: "_bnb4",
  q2: "_q2",
  q2f16: "_q2f16",
  q1: "_q1",
  q1f16: "_q1f16"
};
// Non-empty suffixes sorted longest-first so endsWith matching picks the most
// specific suffix (e.g. "_q4f16" before "_q4") in parseOnnxQuantizations.
var SUFFIXES_LONGEST_FIRST = Object.entries(ONNX_QUANTIZATION_SUFFIX_MAPPING).filter(([, suffix]) => suffix !== "").sort((a, b) => b[1].length - a[1].length);
|
|
619
|
+
/**
 * Scans a repository file listing for .onnx weight files under
 * `params.subfolder` (default "onnx") and returns the quantization dtypes
 * that are available for EVERY base model file — a dtype present for only
 * some of the model's onnx files is not reported.
 * Dtypes are returned in ONNX_QUANTIZATION_SUFFIX_MAPPING key order.
 */
function parseOnnxQuantizations(params) {
  const prefix = `${params.subfolder ?? "onnx"}/`;
  // Collect filename stems (path and ".onnx" stripped) inside the subfolder.
  const stems = [];
  for (const filePath of params.filePaths) {
    const isOnnxFile = filePath.startsWith(prefix) && filePath.endsWith(".onnx") && !filePath.endsWith(".onnx_data");
    if (isOnnxFile) {
      stems.push(filePath.slice(prefix.length, -".onnx".length));
    }
  }
  if (stems.length === 0) {
    return [];
  }
  // Classify each stem by its (longest-match) quantization suffix; a stem
  // with no recognized suffix is the fp32 weight file.
  const classify = (stem) => {
    for (const [dtype, suffix] of SUFFIXES_LONGEST_FIRST) {
      if (stem.endsWith(suffix)) {
        return { baseName: stem.slice(0, -suffix.length), dtype };
      }
    }
    return { baseName: stem, dtype: "fp32" };
  };
  const parsed = stems.map(classify);
  const allBaseNames = new Set(parsed.map((entry) => entry.baseName));
  // Group base names by dtype.
  const byDtype = new Map();
  for (const { baseName, dtype } of parsed) {
    const bucket = byDtype.get(dtype) ?? new Set();
    bucket.add(baseName);
    byDtype.set(dtype, bucket);
  }
  // Keep only dtypes covering every base name.
  return Object.keys(ONNX_QUANTIZATION_SUFFIX_MAPPING).filter((dtype) => byDtype.get(dtype)?.size === allBaseNames.size);
}
|
|
664
|
+
|
|
665
|
+
// src/ai-provider/common/HFT_ToolMarkup.ts
|
|
666
|
+
/**
 * Streaming filter that strips <tool_call>…</tool_call> spans from a token
 * stream, forwarding only the surrounding plain text to `emit`.
 *
 * Text that could be the start of an open tag is held back until the next
 * token disambiguates it; flush() emits any held plain text and resets.
 * Tag contents are silently discarded (they are parsed elsewhere from the
 * accumulated full text).
 */
function createToolCallMarkupFilter(emit) {
  const OPEN_TAG = "<tool_call>";
  const CLOSE_TAG = "</tool_call>";
  let insideTag = false;
  let held = "";

  // Length of the longest suffix of `text` that is a proper prefix of the
  // open tag (0 when none) — that suffix must be withheld from emission.
  const partialOpenLength = (text) => {
    const maxLen = Math.min(text.length, OPEN_TAG.length - 1);
    for (let len = maxLen; len >= 1; len--) {
      if (text.endsWith(OPEN_TAG.slice(0, len))) {
        return len;
      }
    }
    return 0;
  };

  const feed = (token) => {
    if (insideTag) {
      held += token;
      const end = held.indexOf(CLOSE_TAG);
      if (end !== -1) {
        const remainder = held.slice(end + CLOSE_TAG.length);
        held = "";
        insideTag = false;
        if (remainder.length > 0) {
          feed(remainder); // re-process text after the close tag
        }
      }
      return;
    }
    const text = held + token;
    const start = text.indexOf(OPEN_TAG);
    if (start !== -1) {
      if (start > 0) {
        emit(text.slice(0, start));
      }
      held = "";
      insideTag = true;
      const remainder = text.slice(start + OPEN_TAG.length);
      if (remainder.length > 0) {
        feed(remainder); // tag content (and possibly more) follows
      }
      return;
    }
    const keep = partialOpenLength(text);
    if (keep > 0) {
      const safe = text.slice(0, text.length - keep);
      if (safe.length > 0) {
        emit(safe);
      }
      held = text.slice(text.length - keep);
    } else {
      if (text.length > 0) {
        emit(text);
      }
      held = "";
    }
  };

  const flush = () => {
    if (held.length > 0 && !insideTag) {
      emit(held);
    }
    held = "";
    insideTag = false;
  };

  return { feed, flush };
}
|
|
730
|
+
|
|
731
|
+
// src/ai-provider/registerHuggingFaceTransformersInline.ts
|
|
732
|
+
import { registerProviderInline } from "@workglow/ai-provider/common";
|
|
733
|
+
|
|
734
|
+
// src/ai-provider/common/HFT_InlineLifecycle.ts
|
|
735
|
+
// Lazily loads the HFT pipeline module and clears its cached pipeline
// instances (releasing loaded model resources).
async function clearHftInlinePipelineCache() {
  // Promise.resolve().then(...) is the bundler's emitted form of a dynamic
  // import of ./HFT_Pipeline; init_HFT_Pipeline runs the module body once.
  const { clearPipelineCache: clearPipelineCache2 } = await Promise.resolve().then(() => (init_HFT_Pipeline(), exports_HFT_Pipeline));
  clearPipelineCache2();
}
|
|
739
|
+
|
|
740
|
+
// src/ai-provider/common/HFT_ModelSearch.ts
|
|
741
|
+
import { searchHfModels, mapHfModelResult } from "@workglow/ai-provider/common";
|
|
742
|
+
// Searches the Hugging Face Hub for ONNX models and maps each hit to a
// provider model record, enriched with the quantizations that are fully
// present in the repo (derived from the repo's file listing).
var HFT_ModelSearch = async (input, _model, _onProgress, signal) => {
  // "siblings" asks the Hub API to include the repo file listing per model.
  const entries = await searchHfModels(input.query?.trim() ?? "", { filter: "onnx" }, ["siblings"], signal);
  const results = entries.map((entry) => {
    const item = mapHfModelResult(entry, HF_TRANSFORMERS_ONNX);
    if (entry.siblings && entry.siblings.length > 0) {
      // rfilename is the repo-relative path of each listed file.
      const filePaths = entry.siblings.map((s) => s.rfilename);
      const quantizations = parseOnnxQuantizations({ filePaths });
      if (quantizations.length > 0) {
        const record = item.record;
        const providerConfig = record.provider_config ?? {};
        providerConfig.quantizations = quantizations;
        record.provider_config = providerConfig;
      }
    }
    // Drop the potentially large file listing from the raw payload.
    const raw = item.raw;
    delete raw.siblings;
    return item;
  });
  return { results };
};
|
|
762
|
+
|
|
763
|
+
// src/ai-provider/common/HFT_BackgroundRemoval.ts
|
|
764
|
+
init_HFT_Pipeline();
|
|
765
|
+
import { dataUriToImageValue, imageValueToBlob } from "@workglow/ai-provider/common";
|
|
766
|
+
/**
 * Encode a transformers.js RawImage as a base64 PNG string via its
 * toBase64() method.
 * @throws Error when the installed transformers version lacks
 *         RawImage.toBase64.
 */
function rawImageToBase64Png(image) {
  const toBase64 = image.toBase64;
  if (typeof toBase64 === "function") {
    // Invoke with the image as receiver so instance state is available.
    return toBase64.call(image);
  }
  throw new Error("HFT_BackgroundRemoval: RawImage.toBase64 unavailable in this transformers version");
}
|
|
773
|
+
// Background-removal task: runs the model's image pipeline and returns the
// first result re-encoded as a PNG data-URI image value.
var HFT_BackgroundRemoval = async (input, model, onProgress, signal) => {
  const remover = await getPipeline(model, onProgress, {}, signal);
  const imageArg = await imageValueToBlob(input.image);
  const result = await remover(imageArg);
  // Pipelines may return a single RawImage or an array; take the first.
  const resultImage = Array.isArray(result) ? result[0] : result;
  const dataUri = `data:image/png;base64,${rawImageToBase64Png(resultImage)}`;
  return {
    image: await dataUriToImageValue(dataUri)
  };
};
|
|
783
|
+
|
|
784
|
+
// src/ai-provider/common/HFT_Chat.ts
|
|
785
|
+
init_HFT_Pipeline();
|
|
786
|
+
|
|
787
|
+
// src/ai-provider/common/HFT_ToolCalling.ts
|
|
788
|
+
init_HFT_Pipeline();
|
|
789
|
+
import {
|
|
790
|
+
buildToolDescription,
|
|
791
|
+
filterValidToolCalls,
|
|
792
|
+
toTextFlatMessages
|
|
793
|
+
} from "@workglow/ai/worker";
|
|
794
|
+
import {
|
|
795
|
+
adaptParserResult,
|
|
796
|
+
forcedToolSelection,
|
|
797
|
+
getAvailableParsers,
|
|
798
|
+
getGenerationPrefix,
|
|
799
|
+
parseToolCalls
|
|
800
|
+
} from "@workglow/ai-provider/common";
|
|
801
|
+
|
|
802
|
+
// src/ai-provider/common/HFT_Streaming.ts
|
|
803
|
+
/**
 * Single-consumer async event queue bridging push-style producers to an
 * async iterable.
 *
 * Events pushed while no consumer is waiting are buffered; done() ends
 * iteration once the buffer drains; error(e) causes the next next() call to
 * reject with e (a waiting consumer is released with done:true first, and
 * its subsequent next() rejects).
 */
function createStreamEventQueue() {
  const pending = [];
  let waiter = null;        // resolve fn of a consumer blocked in next()
  let closed = false;
  let failure = null;

  // Hand a result to the blocked consumer, clearing the slot first.
  const wake = (result) => {
    const resolveWaiter = waiter;
    waiter = null;
    resolveWaiter(result);
  };

  const push = (event) => {
    if (waiter) {
      wake({ value: event, done: false });
    } else {
      pending.push(event);
    }
  };

  const done = () => {
    closed = true;
    if (waiter) {
      wake({ value: undefined, done: true });
    }
  };

  const error = (e) => {
    failure = e;
    if (waiter) {
      wake({ value: undefined, done: true });
    }
  };

  const iterable = {
    [Symbol.asyncIterator]: () => ({
      next: () => {
        if (failure) {
          return Promise.reject(failure);
        }
        if (pending.length > 0) {
          return Promise.resolve({ value: pending.shift(), done: false });
        }
        if (closed) {
          return Promise.resolve({ value: undefined, done: true });
        }
        return new Promise((resolve) => {
          waiter = resolve;
        });
      }
    })
  };

  return { push, done, error, iterable };
}
|
|
854
|
+
/**
 * Build a transformers.js TextStreamer that forwards every decoded text
 * chunk to `queue` as a "text-delta" stream event.
 * @param textStreamer - the TextStreamer constructor from the SDK (injected
 *        so callers can load the SDK lazily).
 */
function createStreamingTextStreamer(tokenizer, queue, textStreamer) {
  const options = {
    skip_prompt: true,
    decode_kwargs: { skip_special_tokens: true },
    // Method lookup happens per call, so later replacement of queue.push
    // (as done by the streaming tool-calling task) is respected.
    callback_function: (chunk) => queue.push({ type: "text-delta", port: "text", textDelta: chunk })
  };
  return new textStreamer(tokenizer, options);
}
|
|
863
|
+
/**
 * Build a TextStreamer that reports generation progress. Progress follows
 * the asymptotic curve 100*(1 - e^(-0.05*n)) over the chunk count, rising
 * quickly at first and approaching (never exceeding) 100.
 */
function createTextStreamer(tokenizer, updateProgress, textStreamer) {
  let chunksSeen = 0;
  const reportChunk = (text) => {
    chunksSeen += 1;
    const curve = 100 * (1 - Math.exp(-0.05 * chunksSeen));
    const progress = Math.round(Math.min(curve, 100));
    updateProgress(progress, "Generating", { text, progress });
  };
  return new textStreamer(tokenizer, {
    skip_prompt: true,
    decode_kwargs: { skip_special_tokens: true },
    callback_function: reportChunk
  });
}
|
|
876
|
+
|
|
877
|
+
// src/ai-provider/common/HFT_ToolCalling.ts
|
|
878
|
+
/**
 * Collect the lowercase text fields of a model record that may identify its
 * model family: id, title, description, and configured model path.
 * Missing or empty fields are skipped.
 *
 * Fix: provider_config is optional elsewhere in this module (e.g.
 * `model?.provider_config?.pipeline` in HFT_ImageClassification), so access
 * model_path with optional chaining instead of throwing a TypeError when
 * provider_config is absent.
 */
function getModelTextCandidates(model) {
  return [model.model_id, model.title, model.description, model.provider_config?.model_path]
    .filter((value) => typeof value === "string" && value.length > 0)
    .map((value) => value.toLowerCase());
}
|
|
881
|
+
/**
 * Infer the tool-call parser family for a model by scanning its textual
 * metadata for the name of any available parser.
 * @returns the first matching family name, or null when none matches.
 */
function detectModelFamilyFromConfig(model) {
  const families = getAvailableParsers();
  for (const candidate of getModelTextCandidates(model)) {
    const match = families.find((family) => candidate.includes(family));
    if (match !== undefined) {
      return match;
    }
  }
  return null;
}
|
|
893
|
+
/**
 * Fill in a missing tool-call name using the tool forced by the request's
 * toolChoice (when one is forced); calls that already carry a name pass
 * through unchanged.
 */
function normalizeParsedToolCalls(input, toolCalls) {
  const forcedName = forcedToolSelection(input);
  return toolCalls.map((call) => {
    if (call.name) {
      return call;
    }
    return { ...call, name: forcedName ?? call.name };
  });
}
|
|
900
|
+
/**
 * Convert internal tool definitions into the OpenAI-style function-tool
 * shape that transformers.js chat templates expect.
 */
function mapHFTTools(tools) {
  return tools.map((tool) => {
    const fn = {
      name: tool.name,
      description: buildToolDescription(tool),
      parameters: tool.inputSchema
    };
    return { type: "function", function: fn };
  });
}
|
|
910
|
+
// Resolve which tools to pass to the chat template and, for
// toolChoice === "required", inject a system instruction demanding a tool
// call. NOTE: mutates the `messages` array in place — buildPromptAndPrefix
// relies on that side effect. Returns undefined (no tools) for "none".
function resolveHFTToolsAndMessages(input, messages) {
  if (input.toolChoice === "none") {
    return;
  }
  if (input.toolChoice === "required") {
    const requiredInstruction = "You must call at least one tool from the provided tool list when answering.";
    if (messages.length > 0 && messages[0].role === "system") {
      // Append the instruction to the existing system message.
      messages[0] = { ...messages[0], content: `${messages[0].content}

${requiredInstruction}` };
    } else {
      messages.unshift({ role: "system", content: requiredInstruction });
    }
    return mapHFTTools(input.tools);
  }
  if (typeof input.toolChoice === "string" && input.toolChoice !== "auto") {
    // A specific tool name: narrow to that tool, falling back to all tools
    // when the name matches nothing.
    const selectedTools = input.tools?.filter((tool) => tool.name === input.toolChoice);
    const toolsToMap = selectedTools && selectedTools.length > 0 ? selectedTools : input.tools;
    return mapHFTTools(toolsToMap);
  }
  return mapHFTTools(input.tools);
}
|
|
932
|
+
/**
 * Flatten provider-agnostic chat messages into the simple {role, content}
 * entries understood by transformers.js chat templates.
 *
 * - A systemPrompt becomes a leading system message.
 * - toolChoice === "required" adds a second system message demanding a tool
 *   call.
 * - With no message history, the prompt is flattened into one user message.
 * - Assistant tool_use blocks become tool_calls entries; tool_result blocks
 *   become role:"tool" messages with tool_call_id. Other roles are ignored.
 */
function buildHFTMessages(messages, systemPrompt, prompt, toolChoice) {
  const result = [];
  if (systemPrompt) {
    result.push({ role: "system", content: systemPrompt });
  }
  if (toolChoice === "required") {
    result.push({
      role: "system",
      content: "You MUST call one of the provided tools in this turn."
    });
  }
  if (!messages || messages.length === 0) {
    result.push({ role: "user", content: extractPromptText(prompt) });
    return result;
  }
  // Concatenate the text blocks of a content array.
  const textOf = (blocks) => blocks.filter((b) => b.type === "text").map((b) => b.text).join("");
  for (const message of messages) {
    switch (message.role) {
      case "user":
        result.push({ role: "user", content: textOf(message.content) });
        break;
      case "assistant": {
        const entry = { role: "assistant", content: textOf(message.content) };
        const calls = message.content
          .filter((b) => b.type === "tool_use")
          .map((b) => ({ id: b.id, name: b.name, arguments: b.input }));
        if (calls.length > 0) {
          entry.tool_calls = calls;
        }
        result.push(entry);
        break;
      }
      case "tool":
        for (const block of message.content) {
          if (block.type !== "tool_result") {
            continue;
          }
          result.push({
            role: "tool",
            content: block.content.filter((inner) => inner.type === "text").map((inner) => inner.text).join(""),
            tool_call_id: block.tool_use_id
          });
        }
        break;
      default:
        break;
    }
  }
  return result;
}
|
|
976
|
+
/**
 * Normalize a prompt value into a single string.
 * Strings pass through; arrays of strings and {type:"text"} blocks are
 * joined with newlines (empty pieces dropped); anything else is coerced
 * with String(), with null/undefined becoming "".
 */
function extractPromptText(prompt) {
  if (typeof prompt === "string") {
    return prompt;
  }
  if (!Array.isArray(prompt)) {
    return String(prompt ?? "");
  }
  const pieces = [];
  for (const item of prompt) {
    if (typeof item === "string") {
      pieces.push(item);
    } else if (item && typeof item === "object" && item.type === "text") {
      pieces.push(item.text);
    }
  }
  return pieces.filter((piece) => piece).join("\n");
}
|
|
991
|
+
/**
 * Choose which tools to expose to the chat template based on toolChoice:
 * "none" → no tools (undefined); a specific tool name → just that tool,
 * falling back to all tools when the name is unknown; "auto"/"required" →
 * all tools.
 */
function selectHFTTools(input) {
  const choice = input.toolChoice;
  if (choice === "none") {
    return;
  }
  const isNamedChoice = typeof choice === "string" && choice !== "auto" && choice !== "required";
  if (isNamedChoice) {
    const matching = input.tools.filter((tool) => tool.name === choice);
    return mapHFTTools(matching.length > 0 ? matching : input.tools);
  }
  return mapHFTTools(input.tools);
}
|
|
1000
|
+
/** True when the request's message history contains a tool-result message. */
function hasToolMessages(input) {
  const history = input.messages;
  if (!history) {
    return false;
  }
  return history.some((message) => message.role === "tool");
}
|
|
1003
|
+
// Build the final text prompt for generation plus an optional assistant
// "response prefix" that primes the model to emit a tool call. Turns that
// already contain tool results take the buildHFTMessages path (tool messages
// need the richer flattening); otherwise toTextFlatMessages plus tool
// injection is used.
function buildPromptAndPrefix(tokenizer, input, modelFamily) {
  let basePrompt;
  if (hasToolMessages(input)) {
    const messages = buildHFTMessages(input.messages, input.systemPrompt, input.prompt, input.toolChoice);
    const tools = selectHFTTools(input);
    basePrompt = tokenizer.apply_chat_template(messages, {
      tools,
      tokenize: false,
      add_generation_prompt: true
    });
  } else {
    const messages = toTextFlatMessages(input);
    // May mutate `messages` in place (injects a "required" instruction).
    const tools = resolveHFTToolsAndMessages(input, messages);
    basePrompt = tokenizer.apply_chat_template(messages, {
      tools,
      tokenize: false,
      add_generation_prompt: true
    });
  }
  // No prefix when tools are disabled, or when continuing after tool results.
  const responsePrefix = input.toolChoice === "none" || hasToolMessages(input) ? undefined : getGenerationPrefix(modelFamily, forcedToolSelection(input));
  return {
    prompt: responsePrefix ? `${basePrompt}${responsePrefix}` : basePrompt,
    responsePrefix
  };
}
|
|
1028
|
+
// Non-streaming tool-calling generation. Optionally maintains a per-session
// KV ("prefix-rewind") cache: the first call for a session prefills a
// DynamicCache with the prompt (max_new_tokens: 0), snapshots its entries,
// and later calls rehydrate a fresh cache from that snapshot.
var HFT_ToolCalling = async (input, model, onProgress, signal, _outputSchema, sessionId) => {
  const generateText = await getPipeline(model, onProgress, {}, signal);
  const { TextStreamer, InterruptableStoppingCriteria } = await loadTransformersSDK();
  const hfTokenizer = generateText.tokenizer;
  const hfModel = generateText.model;
  // Progress-reporting streamer (this variant does not stream text out).
  const streamer = createTextStreamer(hfTokenizer, onProgress, TextStreamer);
  const stopping_criteria = new InterruptableStoppingCriteria;
  if (signal) {
    // Abort maps to interrupting generation at the next step.
    signal.addEventListener("abort", () => stopping_criteria.interrupt(), { once: true });
  }
  const modelFamily = detectModelFamilyFromConfig(model);
  const { prompt, responsePrefix } = buildPromptAndPrefix(hfTokenizer, input, modelFamily);
  const inputs = hfTokenizer(prompt, { return_tensor: true });
  const modelPath = model.provider_config.model_path;
  let session = sessionId ? getHftSession(sessionId) : undefined;
  let past_key_values = undefined;
  if (sessionId && !session) {
    // First turn for this session: prefill-only pass to populate the cache.
    const { DynamicCache } = await loadTransformersSDK();
    const cache = new DynamicCache;
    await hfModel.generate({
      ...inputs,
      max_new_tokens: 0,
      past_key_values: cache
    });
    // Shallow-snapshot the cache's own enumerable entries.
    const baseEntries = {};
    for (const key of Object.keys(cache)) {
      baseEntries[key] = cache[key];
    }
    const newSession = {
      mode: "prefix-rewind",
      baseEntries,
      baseSeqLength: cache.get_seq_length(),
      modelPath
    };
    setHftSession(sessionId, newSession);
    session = newSession;
  }
  if (session?.mode === "prefix-rewind") {
    // Rehydrate a cache from the session snapshot for this generation.
    const { DynamicCache } = await loadTransformersSDK();
    past_key_values = new DynamicCache(session.baseEntries);
  }
  const output = await hfModel.generate({
    ...inputs,
    max_new_tokens: input.maxTokens ?? 1024,
    streamer,
    stopping_criteria: [stopping_criteria],
    ...past_key_values ? { past_key_values } : {}
  });
  // Decode only the newly generated tokens (skip the echoed prompt).
  const promptLen = inputs.input_ids.dims[1];
  const seqLen = output.dims[1];
  const newTokens = output.slice(0, [promptLen, seqLen], null);
  const decoded = hfTokenizer.decode(newTokens, {
    skip_special_tokens: false
  });
  // Re-attach the forced response prefix so the parser sees the full markup.
  const parseableText = responsePrefix ? `${responsePrefix}${decoded}` : decoded;
  const { text, toolCalls } = adaptParserResult(parseToolCalls(parseableText, { parser: modelFamily }));
  return {
    text,
    toolCalls: filterValidToolCalls(normalizeParsedToolCalls(input, toolCalls), input.tools)
  };
};
|
|
1089
|
+
// Streaming tool-calling generation. Raw token deltas flow into innerQueue;
// its push/done/error are patched so text deltas are (a) accumulated for
// post-hoc tool-call parsing and (b) passed through a markup filter that
// strips <tool_call> spans before reaching outerQueue, which the generator
// yields from. Session KV caching mirrors HFT_ToolCalling.
var HFT_ToolCalling_Stream = async function* (input, model, signal, _outputSchema, sessionId) {
  const noopProgress = () => {};
  const generateText = await getPipeline(model, noopProgress, {}, signal);
  const { TextStreamer, InterruptableStoppingCriteria } = await loadTransformersSDK();
  const modelFamily = detectModelFamilyFromConfig(model);
  const { prompt, responsePrefix } = buildPromptAndPrefix(generateText.tokenizer, input, modelFamily);
  const innerQueue = createStreamEventQueue();
  const outerQueue = createStreamEventQueue();
  const streamer = createStreamingTextStreamer(generateText.tokenizer, innerQueue, TextStreamer);
  const stopping_criteria = new InterruptableStoppingCriteria;
  if (signal) {
    signal.addEventListener("abort", () => stopping_criteria.interrupt(), { once: true });
  }
  const modelPath = model.provider_config.model_path;
  let session = sessionId ? getHftSession(sessionId) : undefined;
  let past_key_values = undefined;
  if (sessionId && !session) {
    // First turn for this session: prefill-only pass to populate the cache.
    const { DynamicCache } = await loadTransformersSDK();
    const hfModel = generateText.model;
    const hfTokenizer = generateText.tokenizer;
    const cache = new DynamicCache;
    const tokenized = hfTokenizer(prompt);
    await hfModel.generate({
      ...tokenized,
      max_new_tokens: 0,
      past_key_values: cache
    });
    const baseEntries = {};
    for (const key of Object.keys(cache)) {
      baseEntries[key] = cache[key];
    }
    const newSession = {
      mode: "prefix-rewind",
      baseEntries,
      baseSeqLength: cache.get_seq_length(),
      modelPath
    };
    setHftSession(sessionId, newSession);
    session = newSession;
  }
  if (session?.mode === "prefix-rewind") {
    const { DynamicCache } = await loadTransformersSDK();
    past_key_values = new DynamicCache(session.baseEntries);
  }
  let fullText = "";
  // Filtered (markup-free) text goes straight to the consumer-facing queue.
  const filter = createToolCallMarkupFilter((text) => {
    outerQueue.push({ type: "text-delta", port: "text", textDelta: text });
  });
  // Patch the inner queue so the streamer's pushes are intercepted.
  // NOTE(review): originalPush still buffers every event in innerQueue,
  // which is never iterated — events accumulate there for the duration of
  // the call; looks like it could be dropped. Verify before changing.
  const originalPush = innerQueue.push;
  innerQueue.push = (event) => {
    if (event.type === "text-delta" && "textDelta" in event) {
      fullText += event.textDelta;
      filter.feed(event.textDelta);
    } else {
      outerQueue.push(event);
    }
    originalPush(event);
  };
  const originalDone = innerQueue.done;
  innerQueue.done = () => {
    filter.flush();
    outerQueue.done();
    originalDone();
  };
  const originalError = innerQueue.error;
  innerQueue.error = (e) => {
    filter.flush();
    outerQueue.error(e);
    originalError(e);
  };
  // Kick off generation; completion/failure closes the outer queue via the
  // patched done/error handlers.
  const pipelinePromise = generateText(prompt, {
    max_new_tokens: input.maxTokens ?? 1024,
    temperature: input.temperature ?? undefined,
    return_full_text: false,
    streamer,
    stopping_criteria: [stopping_criteria],
    ...past_key_values ? { past_key_values } : {}
  }).then(() => innerQueue.done(), (err) => innerQueue.error(err));
  yield* outerQueue.iterable;
  await pipelinePromise;
  // Parse the accumulated raw text (with the forced prefix re-attached) for
  // tool calls once streaming has finished.
  const parseableFullText = responsePrefix ? `${responsePrefix}${fullText}` : fullText;
  const { text: cleanedText, toolCalls } = adaptParserResult(parseToolCalls(parseableFullText, { parser: modelFamily }));
  const validToolCalls = filterValidToolCalls(normalizeParsedToolCalls(input, toolCalls), input.tools);
  if (validToolCalls.length > 0) {
    yield { type: "object-delta", port: "toolCalls", objectDelta: [...validToolCalls] };
  }
  yield {
    type: "finish",
    data: { text: cleanedText, toolCalls: validToolCalls }
  };
};
|
|
1180
|
+
|
|
1181
|
+
// src/ai-provider/common/HFT_Chat.ts
|
|
1182
|
+
// Run one chat generation turn. With onDelta, text is streamed piecewise via
// a patched event queue; without it, the generated tokens are decoded at the
// end. When sessionId is given, the session's KV cache is reused (only if it
// belongs to the same model) and refreshed with this turn's cache afterwards.
// Returns the full generated text.
async function generateTurn(input, model, sessionId, onProgress, signal, onDelta) {
  const generateText = await getPipeline(model, onProgress, {}, signal);
  const { TextStreamer, InterruptableStoppingCriteria } = await loadTransformersSDK();
  const hfTokenizer = generateText.tokenizer;
  const hfModel = generateText.model;
  const stopping_criteria = new InterruptableStoppingCriteria;
  if (signal) {
    signal.addEventListener("abort", () => stopping_criteria.interrupt(), { once: true });
  }
  const messages = buildHFTMessages(input.messages, input.systemPrompt, input.prompt, undefined);
  const prompt = hfTokenizer.apply_chat_template(messages, {
    tokenize: false,
    add_generation_prompt: true
  });
  const inputs = hfTokenizer(prompt);
  const promptLen = inputs.input_ids.dims[1];
  const modelPath = model.provider_config.model_path;
  let session = sessionId ? getHftSession(sessionId) : undefined;
  let past_key_values = undefined;
  if (session?.mode === "prefix-rewind" && session.modelPath === modelPath) {
    // Rehydrate the session's KV cache; skipped when the model changed.
    const { DynamicCache } = await loadTransformersSDK();
    past_key_values = new DynamicCache(session.baseEntries);
  }
  let accumulated = "";
  let streamer;
  if (onDelta) {
    const queue = createStreamEventQueue();
    streamer = createStreamingTextStreamer(hfTokenizer, queue, TextStreamer);
    // Replace push so deltas go straight to the caller; nothing is buffered
    // and non-text events are dropped.
    queue.push = (event) => {
      if (event.type === "text-delta" && "textDelta" in event) {
        accumulated += event.textDelta;
        onDelta(event.textDelta);
      }
    };
  } else {
    // Progress-only streamer; the final text is decoded from the output.
    streamer = createTextStreamer(hfTokenizer, onProgress, TextStreamer);
  }
  const output = await hfModel.generate({
    ...inputs,
    max_new_tokens: input.maxTokens ?? 1024,
    temperature: input.temperature ?? undefined,
    streamer,
    stopping_criteria: [stopping_criteria],
    ...past_key_values ? { past_key_values } : {}
  });
  if (!onDelta) {
    // Decode only the newly generated tokens (skip the echoed prompt).
    const seqLen = output.dims[1];
    const newTokens = output.slice(0, [promptLen, seqLen], null);
    accumulated = hfTokenizer.decode(newTokens, { skip_special_tokens: true });
  }
  if (sessionId) {
    // Persist the post-turn KV cache so the next turn can resume from it.
    let outputCache;
    if (past_key_values) {
      outputCache = past_key_values;
    } else if (output.past_key_values) {
      outputCache = output.past_key_values;
    }
    if (outputCache) {
      const baseEntries = {};
      for (const key of Object.keys(outputCache)) {
        baseEntries[key] = outputCache[key];
      }
      const newSession = {
        mode: "prefix-rewind",
        baseEntries,
        // get_seq_length may be absent depending on the cache object shape.
        baseSeqLength: outputCache.get_seq_length ? outputCache.get_seq_length() : 0,
        modelPath
      };
      setHftSession(sessionId, newSession);
    }
  }
  return accumulated;
}
|
|
1255
|
+
/**
 * Non-streaming chat task: runs a single generation turn and reports coarse
 * progress around it (fine-grained progress comes from generateTurn itself).
 */
var HFT_Chat = async (input, model, update_progress, signal, _outputSchema, sessionId) => {
  update_progress(0, "HFT chat turn");
  const generated = await generateTurn(input, model, sessionId, update_progress, signal, undefined);
  update_progress(100, "Turn complete");
  return { text: generated };
};
|
|
1261
|
+
// Streaming chat task: runs generateTurn with an onDelta callback and
// re-emits each text piece as a "text-delta" event via a simple
// queue + resolver handshake (single consumer).
var HFT_Chat_Stream = async function* (input, model, signal, _outputSchema, sessionId) {
  const noopProgress = () => {};
  const queue = [];
  let done = false;
  let resolver;
  const task = (async () => {
    try {
      await generateTurn(input, model, sessionId, noopProgress, signal, (piece) => {
        queue.push(piece);
        resolver?.(); // wake the consumer if it is waiting
      });
    } finally {
      // Always unblock the consumer loop, even if generateTurn threw; the
      // error itself surfaces at the `await task` below.
      done = true;
      resolver?.();
    }
  })();
  while (!done || queue.length > 0) {
    if (queue.length === 0 && !done) {
      // Park until the producer pushes a piece or finishes.
      await new Promise((res) => resolver = res);
      resolver = undefined;
    }
    while (queue.length > 0) {
      yield { type: "text-delta", port: "text", textDelta: queue.shift() };
    }
  }
  // Propagate any generateTurn failure to the stream consumer.
  await task;
  yield { type: "finish", data: {} };
};
|
|
1289
|
+
|
|
1290
|
+
// src/ai-provider/common/HFT_CountTokens.ts
|
|
1291
|
+
init_HFT_Pipeline();
|
|
1292
|
+
// Token-count task: loads only the tokenizer (not the full model) and
// returns the number of token ids produced for the input text.
var HFT_CountTokens = async (input, model, onProgress, _signal) => {
  const { AutoTokenizer } = await loadTransformersSDK();
  const tokenizer = await AutoTokenizer.from_pretrained(model.provider_config.model_path, {
    // progress is a download-progress event; report 0 when absent.
    progress_callback: (progress) => onProgress(progress?.progress ?? 0)
  });
  const tokenIds = tokenizer.encode(input.text);
  return { count: tokenIds.length };
};
|
|
1300
|
+
/**
 * Preview variant of HFT_CountTokens: counts tokens without progress
 * reporting and with a throwaway (never-aborted) signal.
 */
var HFT_CountTokens_Preview = async (input, model) => {
  const noProgress = () => {};
  const controller = new AbortController();
  return HFT_CountTokens(input, model, noProgress, controller.signal);
};
|
|
1303
|
+
|
|
1304
|
+
// src/ai-provider/common/HFT_Download.ts
|
|
1305
|
+
init_HFT_Pipeline();
|
|
1306
|
+
import { getLogger as getLogger2 } from "@workglow/util/worker";
|
|
1307
|
+
// Download task: warming the pipeline cache forces the model weights to be
// downloaded; timing is logged around the warm-up.
var HFT_Download = async (input, model, onProgress, signal) => {
  const logger = getLogger2();
  const timerLabel = `hft:Download:${model?.provider_config.model_path}`;
  logger.time(timerLabel, { model: model?.provider_config.model_path });
  // Trailing 100 — presumably a progress scale/cap passed to getPipeline;
  // TODO confirm against getPipeline's signature.
  await getPipeline(model, onProgress, {}, signal, 100);
  logger.timeEnd(timerLabel, { model: model?.provider_config.model_path });
  return {
    model: input.model
  };
};
|
|
1317
|
+
|
|
1318
|
+
// src/ai-provider/common/HFT_ImageClassification.ts
|
|
1319
|
+
init_HFT_Pipeline();
|
|
1320
|
+
import { imageValueToBlob as imageValueToBlob2 } from "@workglow/ai-provider/common";
|
|
1321
|
+
// Normalize a classification pipeline result (single object or array) into
// the task's [{label, score}] output shape. Shared by both branches below.
function hftToCategoryScores(result) {
  const entries = Array.isArray(result) ? result : [result];
  return entries.map((r) => ({
    label: r.label,
    score: r.score
  }));
}

/**
 * Image-classification task. Supports plain image-classification pipelines
 * and zero-shot classification (selected via provider_config.pipeline),
 * which requires candidate categories on the input.
 *
 * Fix: the result→{label, score} mapping was duplicated verbatim in both
 * branches; it is now the shared hftToCategoryScores helper.
 * @throws Error when zero-shot is selected but no categories were provided.
 */
var HFT_ImageClassification = async (input, model, onProgress, signal) => {
  if (model?.provider_config?.pipeline === "zero-shot-image-classification") {
    if (!input.categories || !Array.isArray(input.categories) || input.categories.length === 0) {
      console.warn("Zero-shot image classification requires categories", input);
      throw new Error("Zero-shot image classification requires categories");
    }
    const zeroShotClassifier = await getPipeline(model, onProgress, {}, signal);
    const zeroShotImage = await imageValueToBlob2(input.image);
    const zeroShotResult = await zeroShotClassifier(zeroShotImage, input.categories, {});
    return { categories: hftToCategoryScores(zeroShotResult) };
  }
  const classifier = await getPipeline(model, onProgress, {}, signal);
  const imageArg = await imageValueToBlob2(input.image);
  // top_k limits how many labels come back; undefined = pipeline default.
  const result = await classifier(imageArg, {
    top_k: input.maxCategories
  });
  return { categories: hftToCategoryScores(result) };
};
|
|
1351
|
+
|
|
1352
|
+
// src/ai-provider/common/HFT_ImageEmbedding.ts
|
|
1353
|
+
init_HFT_Pipeline();
|
|
1354
|
+
import { getLogger as getLogger3 } from "@workglow/util/worker";
|
|
1355
|
+
import { imageValueToBlob as imageValueToBlob3 } from "@workglow/ai-provider/common";
|
|
1356
|
+
// Image-embedding task. Accepts a single image value or an array; arrays
// are embedded sequentially and returned as an array of vectors.
var HFT_ImageEmbedding = async (input, model, onProgress, signal) => {
  const logger = getLogger3();
  const timerLabel = `hft:ImageEmbedding:${model?.provider_config.model_path}`;
  logger.time(timerLabel, { model: model?.provider_config.model_path });
  const embedder = await getPipeline(model, onProgress, {}, signal);
  logger.debug("HFT ImageEmbedding: pipeline ready, generating embedding", {
    model: model?.provider_config.model_path
  });
  if (Array.isArray(input.image)) {
    const vectors = [];
    for (const image of input.image) {
      const imageArg2 = await imageValueToBlob3(image);
      const result2 = await embedder(imageArg2);
      // result.data holds the raw embedding values.
      vectors.push(result2.data);
    }
    logger.timeEnd(timerLabel, { count: vectors.length });
    return { vector: vectors };
  }
  const imageArg = await imageValueToBlob3(input.image);
  const result = await embedder(imageArg);
  logger.timeEnd(timerLabel, { dimensions: result?.data?.length });
  return {
    vector: result.data
  };
};
|
|
1381
|
+
|
|
1382
|
+
// src/ai-provider/common/HFT_ImageSegmentation.ts
|
|
1383
|
+
init_HFT_Pipeline();
|
|
1384
|
+
import { imageValueToBlob as imageValueToBlob4 } from "@workglow/ai-provider/common";
|
|
1385
|
+
var HFT_ImageSegmentation = async (input, model, onProgress, signal) => {
  // Segment an image into labeled regions.
  //
  // input:  { image, threshold?, maskThreshold? }
  // Returns { masks: [{ label, score, mask }] }.
  const segmenter = await getPipeline(model, onProgress, {}, signal);
  const imageArg = await imageValueToBlob4(input.image);
  const result = await segmenter(imageArg, {
    threshold: input.threshold,
    mask_threshold: input.maskThreshold
  });
  // The pipeline may return a single mask object or an array; normalize.
  const masks = Array.isArray(result) ? result : [result];
  // Fix: the original wrapped this fully synchronous mapping in an async
  // callback plus Promise.all — needless promise machinery. A plain map
  // produces the identical objects.
  // NOTE(review): `mask: {}` discards the pipeline's raw mask payload; it
  // looks like a placeholder for a serializable mask conversion — confirm
  // intended behavior before relying on the `mask` field.
  const processedMasks = masks.map((mask) => ({
    label: mask.label || "",
    score: mask.score || 0,
    mask: {}
  }));
  return {
    masks: processedMasks
  };
};
|
|
1402
|
+
|
|
1403
|
+
// src/ai-provider/common/HFT_ImageToText.ts
|
|
1404
|
+
init_HFT_Pipeline();
|
|
1405
|
+
import { imageValueToBlob as imageValueToBlob5 } from "@workglow/ai-provider/common";
|
|
1406
|
+
var HFT_ImageToText = async (input, model, onProgress, signal) => {
  // Generate a caption for a single image (image-to-text pipeline).
  const captioner = await getPipeline(model, onProgress, {}, signal);
  const blob = await imageValueToBlob5(input.image);
  const output = await captioner(blob, {
    max_new_tokens: input.maxTokens
  });
  // Output shape varies by model: either [[{generated_text}]] or [{generated_text}].
  const first = output[0];
  let caption;
  if (Array.isArray(first)) {
    caption = first[0]?.generated_text;
  } else {
    caption = first?.generated_text;
  }
  return { text: caption || "" };
};
|
|
1417
|
+
|
|
1418
|
+
// src/ai-provider/common/HFT_ModelInfo.ts
|
|
1419
|
+
import { getLogger as getLogger4 } from "@workglow/util/worker";
|
|
1420
|
+
init_HFT_Pipeline();
|
|
1421
|
+
// Report metadata about a local HF Transformers model: embedding dimensions,
// cache/load status, cached file listing, and detected ONNX quantizations.
// The "dimensions" detail level short-circuits without touching the SDK.
var HFT_ModelInfo = async (input, model) => {
  if (input.detail === "dimensions") {
    if (!model)
      throw new Error("Model config is required for ModelInfoTask.");
    const pc = model.provider_config;
    // Prefer explicitly configured dimensions; otherwise fall back to the
    // remote config.json lookup below.
    let native_dimensions = typeof pc.native_dimensions === "number" ? pc.native_dimensions : undefined;
    // mrl: whether the model supports Matryoshka-style truncation — TODO confirm.
    const mrl = typeof pc.mrl === "boolean" ? pc.mrl : false;
    if (native_dimensions === undefined && typeof pc.model_path === "string") {
      try {
        // Best-effort: read hidden_size from the model's config.json on the Hub.
        const response = await fetch(`https://huggingface.co/${pc.model_path}/resolve/${pc.revision ?? "main"}/config.json`);
        if (response.ok) {
          const config = await response.json();
          if (typeof config.hidden_size === "number") {
            native_dimensions = config.hidden_size;
          }
        }
      } catch {}
      // Network failure is deliberately non-fatal: dimensions are simply omitted.
    }
    return {
      model: input.model,
      is_local: true,
      is_remote: false,
      supports_browser: true,
      supports_node: true,
      // Dimensions-only path does not probe caches; these are reported as false.
      is_cached: false,
      is_loaded: false,
      file_sizes: null,
      // Spread-in optional fields only when known, so absent means "unknown".
      ...native_dimensions !== undefined ? { native_dimensions } : {},
      ...mrl ? { mrl } : {}
    };
  }
  const logger = getLogger4();
  const { ModelRegistry } = await loadTransformersSDK();
  const timerLabel = `hft:ModelInfo:${model?.provider_config.model_path}`;
  logger.time(timerLabel, { model: model?.provider_config.model_path });
  const detail = input.detail;
  // Loaded = a live pipeline instance exists in this process.
  const is_loaded = hasCachedPipeline(getPipelineCacheKey(model));
  // NOTE(review): this path dereferences model.provider_config without a null
  // check, while earlier lines use model?. — confirm callers always pass model
  // for non-"dimensions" detail levels.
  const { pipeline: pipelineType, model_path, dtype, device } = model.provider_config;
  const cacheOptions = {
    ...dtype ? { dtype } : {},
    ...device ? { device } : {}
  };
  const cacheStatus = await ModelRegistry.is_pipeline_cached_files(pipelineType, model_path, cacheOptions);
  logger.debug("is_pipeline_cached", {
    input: [pipelineType, model_path, cacheOptions],
    result: cacheStatus
  });
  // A loaded pipeline implies its files were fetched, so count it as cached.
  const is_cached = is_loaded || cacheStatus.allCached;
  let file_sizes = null;
  if (detail === "files" && cacheStatus.files.length > 0) {
    // "files": names only; sizes reported as 0 without metadata lookups.
    const sizes = {};
    for (const { file } of cacheStatus.files) {
      sizes[file] = 0;
    }
    file_sizes = sizes;
  } else if (detail === "files_with_metadata" && cacheStatus.files.length > 0) {
    // "files_with_metadata": fetch per-file sizes concurrently.
    const sizes = {};
    await Promise.all(cacheStatus.files.map(async ({ file }) => {
      const metadata = await ModelRegistry.get_file_metadata(model_path, file);
      if (metadata.exists && metadata.size !== undefined) {
        sizes[file] = metadata.size;
      }
    }));
    if (Object.keys(sizes).length > 0) {
      file_sizes = sizes;
    }
  }
  // Derive available quantization variants from the cached ONNX file names.
  let quantizations;
  if (cacheStatus.files.length > 0) {
    const filePaths = cacheStatus.files.map((f) => f.file);
    const quantizations_parsed = parseOnnxQuantizations({ filePaths });
    if (quantizations_parsed.length > 0) {
      quantizations = quantizations_parsed;
    }
  }
  logger.timeEnd(timerLabel, { model: model?.provider_config.model_path });
  return {
    model: input.model,
    is_local: true,
    is_remote: false,
    supports_browser: true,
    supports_node: true,
    is_cached,
    is_loaded,
    file_sizes,
    ...quantizations ? { quantizations } : {}
  };
};
|
|
1509
|
+
|
|
1510
|
+
// src/ai-provider/common/HFT_ObjectDetection.ts
|
|
1511
|
+
init_HFT_Pipeline();
|
|
1512
|
+
import { imageValueToBlob as imageValueToBlob6 } from "@workglow/ai-provider/common";
|
|
1513
|
+
var HFT_ObjectDetection = async (input, model, onProgress, signal) => {
  // Detect objects in an image; zero-shot models additionally take caller labels.
  // Normalize each raw detection to the { label, score, box } shape.
  const toDetection = ({ label, score, box }) => ({ label, score, box });
  const isZeroShot = model?.provider_config?.pipeline === "zero-shot-object-detection";
  if (isZeroShot) {
    const labels = input.labels;
    if (!labels || !Array.isArray(labels) || labels.length === 0) {
      throw new Error("Zero-shot object detection requires labels");
    }
    const zeroShotDetector = await getPipeline(model, onProgress, {}, signal);
    const blob = await imageValueToBlob6(input.image);
    const raw = await zeroShotDetector(blob, Array.from(labels), {
      threshold: input.threshold
    });
    return { detections: raw.map(toDetection) };
  }
  const detector = await getPipeline(model, onProgress, {}, signal);
  const blob = await imageValueToBlob6(input.image);
  const raw = await detector(blob, {
    threshold: input.threshold
  });
  return { detections: raw.map(toDetection) };
};
|
|
1544
|
+
|
|
1545
|
+
// src/ai-provider/common/HFT_StructuredGeneration.ts
|
|
1546
|
+
init_HFT_Pipeline();
|
|
1547
|
+
import { parsePartialJson } from "@workglow/util/worker";
|
|
1548
|
+
|
|
1549
|
+
// src/ai-provider/common/HFT_TextOutput.ts
|
|
1550
|
+
function extractGeneratedText(generatedText) {
  // Pull the final reply text out of a pipeline "generated_text" value, which
  // may be a plain string or an array of chat messages whose last entry holds
  // either string content or an array of typed content parts.
  if (generatedText == null) {
    return "";
  }
  if (typeof generatedText === "string") {
    return generatedText;
  }
  const message = generatedText[generatedText.length - 1];
  if (!message) {
    return "";
  }
  const { content } = message;
  if (typeof content === "string") {
    return content;
  }
  // Multi-part content: return the first text part found.
  for (const piece of content) {
    if (piece.type === "text" && "text" in piece) {
      return piece.text;
    }
  }
  return "";
}
|
|
1568
|
+
|
|
1569
|
+
// src/ai-provider/common/HFT_StructuredGeneration.ts
|
|
1570
|
+
function buildStructuredGenerationPrompt(input) {
  // Wrap the user prompt with strict JSON-only output instructions and the
  // pretty-printed target schema.
  const schemaText = JSON.stringify(input.outputSchema, null, 2);
  const sections = [
    input.prompt,
    "You MUST respond with ONLY a valid JSON object conforming to this JSON schema:",
    schemaText,
    "Output ONLY the JSON object, no other text."
  ];
  return sections.join("\n");
}
|
|
1579
|
+
function stripThinkingAndSpecialTokens(text) {
  // Remove <think>…</think> reasoning blocks, then <|token|>-style special
  // markers, and trim the remainder.
  const withoutThinking = text.replace(/<think>[\s\S]*?<\/think>/g, "");
  const withoutSpecials = withoutThinking.replace(/<\|[a-z_]+\|>/g, "");
  return withoutSpecials.trim();
}
|
|
1582
|
+
function extractJsonFromText(text) {
  // Best-effort JSON recovery from model output: strip thinking/special
  // tokens, parse the whole thing, then fall back to the outermost {...}
  // span, then to partial-JSON repair, and finally to an empty object.
  const cleaned = stripThinkingAndSpecialTokens(text);
  try {
    return JSON.parse(cleaned);
  } catch {
    // Not valid JSON as a whole — try the braces heuristic below.
  }
  const braces = cleaned.match(/\{[\s\S]*\}/);
  if (!braces) {
    return {};
  }
  try {
    return JSON.parse(braces[0]);
  } catch {
    return parsePartialJson(braces[0]) ?? {};
  }
}
|
|
1598
|
+
var HFT_StructuredGeneration = async (input, model, onProgress, signal) => {
  // Run a text-generation pipeline with a JSON-schema-constrained prompt and
  // parse the model's reply into an object.
  const pipe = await getPipeline(model, onProgress, {}, signal);
  const { TextStreamer, InterruptableStoppingCriteria } = await loadTransformersSDK();
  const chat = [{ role: "user", content: buildStructuredGenerationPrompt(input) }];
  const rendered = pipe.tokenizer.apply_chat_template(chat, {
    tokenize: false,
    add_generation_prompt: true
  });
  const streamer = createTextStreamer(pipe.tokenizer, onProgress, TextStreamer);
  const stopper = new InterruptableStoppingCriteria();
  // Aborting the signal interrupts generation at the next criteria check.
  signal?.addEventListener("abort", () => stopper.interrupt(), { once: true });
  const raw = await pipe(rendered, {
    max_new_tokens: input.maxTokens ?? 1024,
    temperature: input.temperature ?? undefined,
    return_full_text: false,
    streamer,
    stopping_criteria: [stopper]
  });
  const outputs = Array.isArray(raw) ? raw : [raw];
  const reply = extractGeneratedText(outputs[0]?.generated_text).trim();
  return { object: extractJsonFromText(reply) };
};
|
|
1626
|
+
// Streaming variant of structured generation: intercepts the event queue to
// convert raw text deltas into incremental object deltas (partial JSON),
// then yields a final "finish" event with the fully parsed object.
var HFT_StructuredGeneration_Stream = async function* (input, model, signal) {
  const noopProgress = () => {};
  const generateText = await getPipeline(model, noopProgress, {}, signal);
  const { TextStreamer, InterruptableStoppingCriteria } = await loadTransformersSDK();
  const prompt = buildStructuredGenerationPrompt(input);
  const messages = [{ role: "user", content: prompt }];
  const formattedPrompt = generateText.tokenizer.apply_chat_template(messages, {
    tokenize: false,
    add_generation_prompt: true
  });
  const queue = createStreamEventQueue();
  const streamer = createStreamingTextStreamer(generateText.tokenizer, queue, TextStreamer);
  const stopping_criteria = new InterruptableStoppingCriteria;
  if (signal) {
    // Abort maps to an interrupt checked between generation steps.
    signal.addEventListener("abort", () => stopping_criteria.interrupt(), { once: true });
  }
  // fullText: everything the model emitted (used for the final parse).
  // cleanedText: fullText minus <think> blocks and <|...|> special tokens.
  // jsonStart: index in cleanedText where the JSON object begins, once seen.
  let fullText = "";
  let cleanedText = "";
  let inThinkBlock = false;
  let jsonStart = -1;
  // Monkey-patch queue.push so text deltas are rewritten into object deltas.
  // NOTE(review): originalPush is captured unbound — this works only if the
  // queue's push does not rely on `this`; confirm createStreamEventQueue.
  const originalPush = queue.push;
  queue.push = (event) => {
    if (event.type === "text-delta" && "textDelta" in event) {
      const delta = event.textDelta;
      fullText += delta;
      // Scan the delta, toggling in/out of <think> blocks; tags may span
      // deltas only if they arrive whole — NOTE(review): a tag split across
      // two deltas would not be detected; confirm tokenizer emits tags intact.
      let remaining = delta;
      while (remaining.length > 0) {
        if (inThinkBlock) {
          const closeIdx = remaining.indexOf("</think>");
          if (closeIdx !== -1) {
            inThinkBlock = false;
            remaining = remaining.slice(closeIdx + "</think>".length);
          } else {
            // Entire delta is inside the think block; drop it.
            remaining = "";
          }
        } else {
          const openIdx = remaining.indexOf("<think>");
          if (openIdx !== -1) {
            cleanedText += remaining.slice(0, openIdx).replace(/<\|[a-z_]+\|>/g, "");
            inThinkBlock = true;
            remaining = remaining.slice(openIdx + "<think>".length);
          } else {
            cleanedText += remaining.replace(/<\|[a-z_]+\|>/g, "");
            remaining = "";
          }
        }
      }
      // Lock onto the first "{" as the start of the JSON payload.
      if (jsonStart === -1) {
        jsonStart = cleanedText.indexOf("{");
      }
      if (jsonStart !== -1) {
        const partial = parsePartialJson(cleanedText.slice(jsonStart));
        if (partial !== undefined) {
          // Replace the raw text delta with the current partial object.
          originalPush({
            type: "object-delta",
            port: "object",
            objectDelta: partial
          });
          return;
        }
      }
    }
    // Non-text events (and text before any parseable JSON) pass through as-is.
    originalPush(event);
  };
  // Run generation concurrently with consumption of the queue; completion or
  // failure is signalled through the queue itself.
  const pipelinePromise = generateText(formattedPrompt, {
    max_new_tokens: input.maxTokens ?? 1024,
    temperature: input.temperature ?? undefined,
    return_full_text: false,
    streamer,
    stopping_criteria: [stopping_criteria]
  }).then(() => queue.done(), (err) => queue.error(err));
  yield* queue.iterable;
  await pipelinePromise;
  // Final authoritative parse over the complete raw output.
  const object = extractJsonFromText(fullText);
  yield { type: "finish", data: { object } };
};
|
|
1702
|
+
|
|
1703
|
+
// src/ai-provider/common/HFT_TextClassification.ts
|
|
1704
|
+
init_HFT_Pipeline();
|
|
1705
|
+
var HFT_TextClassification = async (input, model, onProgress, signal) => {
  // Classify text, with a dedicated path for zero-shot classification models.
  const isZeroShot = model?.provider_config?.pipeline === "zero-shot-classification";
  if (isZeroShot) {
    const candidates = input.candidateLabels;
    if (!candidates || !Array.isArray(candidates) || candidates.length === 0) {
      throw new Error("Zero-shot text classification requires candidate labels");
    }
    const zeroShotClassifier = await getPipeline(model, onProgress, {}, signal);
    const zeroShot = await zeroShotClassifier(input.text, candidates, {});
    // Zero-shot output pairs parallel label/score arrays.
    const categories = zeroShot.labels.map((label, i) => ({
      label,
      score: zeroShot.scores[i]
    }));
    return { categories };
  }
  const classify = await getPipeline(model, onProgress, {}, signal);
  const raw = await classify(input.text, {
    top_k: input.maxCategories || undefined
  });
  const categories = raw.map(({ label, score }) => ({ label, score }));
  return { categories };
};
|
|
1730
|
+
|
|
1731
|
+
// src/ai-provider/common/HFT_TextEmbedding.ts
|
|
1732
|
+
init_HFT_Pipeline();
|
|
1733
|
+
import { getLogger as getLogger5 } from "@workglow/util/worker";
|
|
1734
|
+
// Embed text (a string or an array of strings) into vectors, validating the
// returned tensor's shape against the model's configured dimensions.
var HFT_TextEmbedding = async (input, model, onProgress, signal) => {
  const logger = getLogger5();
  // Unique timer label so concurrent embedding calls don't collide.
  const uuid = crypto.randomUUID();
  const timerLabel = `hft:TextEmbedding:${model?.provider_config.model_path}:${uuid}`;
  logger.time(timerLabel, { model: model?.provider_config.model_path });
  const generateEmbedding = await getPipeline(model, onProgress, {}, signal);
  logger.debug("HFT TextEmbedding: pipeline ready, generating embedding", {
    model: model?.provider_config.model_path,
    // For arrays this is the batch size; for a string, its character length.
    inputLength: Array.isArray(input.text) ? input.text.length : input.text?.length
  });
  const hfVector = await generateEmbedding(input.text, {
    pooling: model?.provider_config.pooling || "mean",
    normalize: model?.provider_config.normalize
  });
  const isArrayInput = Array.isArray(input.text);
  const embeddingDim = model?.provider_config.native_dimensions;
  if (isArrayInput && hfVector.dims.length > 1) {
    // Batch case: tensor dims are [batch, dim] per the destructure below.
    const [numTexts, vectorDim] = hfVector.dims;
    if (numTexts !== input.text.length) {
      throw new Error(`HuggingFace Embedding tensor batch size does not match input array length: ${numTexts} != ${input.text.length}`);
    }
    if (vectorDim !== embeddingDim) {
      throw new Error(`HuggingFace Embedding vector dimension does not match model dimensions: ${vectorDim} != ${embeddingDim}`);
    }
    // Copy each row out of the tensor so results don't alias its buffer.
    const vectors = Array.from({ length: numTexts }, (_, i) => hfVector[i].data.slice());
    logger.timeEnd(timerLabel, { batchSize: numTexts, dimensions: vectorDim });
    return { vector: vectors };
  }
  // Single-text case: validate the flat size against configured dimensions.
  if (hfVector.size !== embeddingDim) {
    logger.timeEnd(timerLabel, { status: "error", reason: "dimension mismatch" });
    console.warn(`HuggingFace Embedding vector length does not match model dimensions v${hfVector.size} != m${embeddingDim}`, input, hfVector);
    throw new Error(`HuggingFace Embedding vector length does not match model dimensions v${hfVector.size} != m${embeddingDim}`);
  }
  logger.timeEnd(timerLabel, { dimensions: hfVector.size });
  return { vector: hfVector.data };
};
|
|
1770
|
+
|
|
1771
|
+
// src/ai-provider/common/HFT_TextFillMask.ts
|
|
1772
|
+
init_HFT_Pipeline();
|
|
1773
|
+
var HFT_TextFillMask = async (input, model, onProgress, signal) => {
  // Fill-mask: predict candidate tokens for the masked slot in input.text.
  const unmasker = await getPipeline(model, onProgress, {}, signal);
  const raw = await unmasker(input.text);
  const predictions = raw.map(({ token_str, score, sequence }) => ({
    entity: token_str,
    score,
    sequence
  }));
  return { predictions };
};
|
|
1784
|
+
|
|
1785
|
+
// src/ai-provider/common/HFT_TextGeneration.ts
|
|
1786
|
+
init_HFT_Pipeline();
|
|
1787
|
+
import { getLogger as getLogger6 } from "@workglow/util/worker";
|
|
1788
|
+
// Generate text from a prompt. When sessionId is given, a per-session KV
// cache (DynamicCache) is created/reused so successive calls can continue
// progressively instead of re-encoding from scratch.
var HFT_TextGeneration = async (input, model, onProgress, signal, _outputSchema, sessionId) => {
  const logger = getLogger6();
  const timerLabel = `hft:TextGeneration:${model?.provider_config.model_path}`;
  logger.time(timerLabel, { model: model?.provider_config.model_path });
  const generateText = await getPipeline(model, onProgress, {}, signal);
  const { TextStreamer, InterruptableStoppingCriteria } = await loadTransformersSDK();
  logger.debug("HFT TextGeneration: pipeline ready, generating text", {
    model: model?.provider_config.model_path,
    promptLength: input.prompt?.length
  });
  // Streamer forwards intermediate tokens to onProgress.
  const streamer = createTextStreamer(generateText.tokenizer, onProgress, TextStreamer);
  const stopping_criteria = new InterruptableStoppingCriteria;
  if (signal) {
    // Abort maps to an interrupt checked between generation steps.
    signal.addEventListener("abort", () => stopping_criteria.interrupt(), { once: true });
  }
  const modelPath = model.provider_config.model_path;
  let session = sessionId ? getHftSession(sessionId) : undefined;
  let past_key_values = undefined;
  if (sessionId && !session) {
    // First call for this session: create a fresh KV cache and register it.
    const sdk = await loadTransformersSDK();
    const cache = new sdk.DynamicCache;
    const newSession = {
      mode: "progressive",
      cache,
      modelPath
    };
    setHftSession(sessionId, newSession);
    session = newSession;
  }
  if (session?.mode === "progressive") {
    // Reuse the session cache across calls.
    // NOTE(review): modelPath stored on the session is not compared here —
    // a session reused with a different model would pass a foreign cache;
    // confirm callers keep sessionId model-stable.
    past_key_values = session.cache;
  }
  const messages = [{ role: "user", content: input.prompt }];
  let results = await generateText(messages, {
    streamer,
    do_sample: false,
    max_new_tokens: input.maxTokens ?? 4 * 1024,
    stopping_criteria: [stopping_criteria],
    // Only pass past_key_values when a progressive session is active.
    ...past_key_values ? { past_key_values } : {}
  });
  if (!Array.isArray(results)) {
    results = [results];
  }
  const text = extractGeneratedText(results[0]?.generated_text);
  logger.timeEnd(timerLabel, { outputLength: text?.length });
  return {
    text
  };
};
|
|
1837
|
+
// Streaming variant of HFT_TextGeneration: yields queue events as tokens
// arrive, then a final "finish" event. Shares the same per-session KV-cache
// behavior as the non-streaming version.
var HFT_TextGeneration_Stream = async function* (input, model, signal, _outputSchema, sessionId) {
  const noopProgress = () => {};
  const generateText = await getPipeline(model, noopProgress, {}, signal);
  const { TextStreamer, InterruptableStoppingCriteria } = await loadTransformersSDK();
  // Events flow through the queue; the streamer feeds it from the pipeline.
  const queue = createStreamEventQueue();
  const streamer = createStreamingTextStreamer(generateText.tokenizer, queue, TextStreamer);
  const stopping_criteria = new InterruptableStoppingCriteria;
  if (signal) {
    signal.addEventListener("abort", () => stopping_criteria.interrupt(), { once: true });
  }
  const modelPath = model.provider_config.model_path;
  let session = sessionId ? getHftSession(sessionId) : undefined;
  let past_key_values = undefined;
  if (sessionId && !session) {
    // First call for this session: create and register a fresh KV cache.
    const sdk = await loadTransformersSDK();
    const cache = new sdk.DynamicCache;
    const newSession = {
      mode: "progressive",
      cache,
      modelPath
    };
    setHftSession(sessionId, newSession);
    session = newSession;
  }
  if (session?.mode === "progressive") {
    past_key_values = session.cache;
  }
  const messages = [{ role: "user", content: input.prompt }];
  // Run generation concurrently; completion/failure is signalled via the queue.
  const pipelinePromise = generateText(messages, {
    streamer,
    do_sample: false,
    max_new_tokens: input.maxTokens ?? 4 * 1024,
    stopping_criteria: [stopping_criteria],
    ...past_key_values ? { past_key_values } : {}
  }).then(() => queue.done(), (err) => queue.error(err));
  yield* queue.iterable;
  // Re-await so a pipeline rejection propagates to the caller.
  await pipelinePromise;
  yield { type: "finish", data: {} };
};
|
|
1876
|
+
|
|
1877
|
+
// src/ai-provider/common/HFT_TextLanguageDetection.ts
|
|
1878
|
+
init_HFT_Pipeline();
|
|
1879
|
+
var HFT_TextLanguageDetection = async (input, model, onProgress, signal) => {
  // Detect the language(s) of input.text via a text-classification pipeline;
  // classifier labels are language codes.
  const classify = await getPipeline(model, onProgress, {}, signal);
  const raw = await classify(input.text, {
    top_k: input.maxLanguages || undefined
  });
  const languages = raw.map(({ label, score }) => ({
    language: label,
    score
  }));
  return { languages };
};
|
|
1891
|
+
|
|
1892
|
+
// src/ai-provider/common/HFT_TextNamedEntityRecognition.ts
|
|
1893
|
+
init_HFT_Pipeline();
|
|
1894
|
+
var HFT_TextNamedEntityRecognition = async (input, model, onProgress, signal) => {
  // Named-entity recognition over input.text; labels in input.blockList are
  // suppressed via the pipeline's ignore_labels option.
  const ner = await getPipeline(model, onProgress, {}, signal);
  const raw = await ner(input.text, {
    ignore_labels: input.blockList
  });
  const entities = raw.map(({ entity, score, word }) => ({
    entity,
    score,
    word
  }));
  return { entities };
};
|
|
1907
|
+
|
|
1908
|
+
// src/ai-provider/common/HFT_TextQuestionAnswer.ts
|
|
1909
|
+
init_HFT_Pipeline();
|
|
1910
|
+
var HFT_TextQuestionAnswer = async (input, model, onProgress, signal) => {
  // Extractive question answering: find the answer to input.question within
  // input.context.
  const qa = await getPipeline(model, onProgress, {}, signal);
  const { TextStreamer, InterruptableStoppingCriteria } = await loadTransformersSDK();
  const streamer = createTextStreamer(qa.tokenizer, onProgress, TextStreamer);
  const stopper = new InterruptableStoppingCriteria();
  signal?.addEventListener("abort", () => stopper.interrupt(), { once: true });
  const result = await qa(input.question, input.context, {
    streamer,
    stopping_criteria: [stopper]
  });
  return { text: result?.answer || "" };
};
|
|
1925
|
+
var HFT_TextQuestionAnswer_Stream = async function* (input, model, signal) {
  // Streaming QA: yields token events while the pipeline runs, then a final
  // "finish" event carrying the complete answer text.
  const noop = () => {};
  const qa = await getPipeline(model, noop, {}, signal);
  const { TextStreamer, InterruptableStoppingCriteria } = await loadTransformersSDK();
  const queue = createStreamEventQueue();
  const streamer = createStreamingTextStreamer(qa.tokenizer, queue, TextStreamer);
  const stopper = new InterruptableStoppingCriteria();
  signal?.addEventListener("abort", () => stopper.interrupt(), { once: true });
  // Capture the pipeline result on success so it can be reported after the
  // event stream drains; failures are routed through the queue.
  let finalResult;
  const run = qa(input.question, input.context, {
    streamer,
    stopping_criteria: [stopper]
  }).then(
    (result) => {
      finalResult = result;
      queue.done();
    },
    (err) => queue.error(err)
  );
  yield* queue.iterable;
  await run;
  let answerText = "";
  if (finalResult !== undefined) {
    const single = Array.isArray(finalResult) ? finalResult[0] : finalResult;
    answerText = single?.answer ?? "";
  }
  yield { type: "finish", data: { text: answerText } };
};
|
|
1955
|
+
|
|
1956
|
+
// src/ai-provider/common/HFT_TextRewriter.ts
|
|
1957
|
+
init_HFT_Pipeline();
|
|
1958
|
+
var HFT_TextRewriter = async (input, model, onProgress, signal) => {
  // Rewrite input.text, optionally steered by input.prompt prepended on its
  // own line. Throws if the model merely echoes the input back.
  const pipe = await getPipeline(model, onProgress, {}, signal);
  const { TextStreamer, InterruptableStoppingCriteria } = await loadTransformersSDK();
  const streamer = createTextStreamer(pipe.tokenizer, onProgress, TextStreamer);
  const stopper = new InterruptableStoppingCriteria();
  signal?.addEventListener("abort", () => stopper.interrupt(), { once: true });
  const prefix = input.prompt ? `${input.prompt}\n` : "";
  const promptedText = prefix + input.text;
  const raw = await pipe(promptedText, {
    streamer,
    stopping_criteria: [stopper]
  });
  const outputs = Array.isArray(raw) ? raw : [raw];
  const text = extractGeneratedText(outputs[0]?.generated_text);
  // An unchanged echo means the rewriter produced nothing new.
  if (text === promptedText) {
    throw new Error("Rewriter failed to generate new text");
  }
  return { text };
};
|
|
1983
|
+
var HFT_TextRewriter_Stream = async function* (input, model, signal) {
  // Streaming rewrite: yields token events, then a final "finish" event.
  const noop = () => {};
  const pipe = await getPipeline(model, noop, {}, signal);
  const { TextStreamer, InterruptableStoppingCriteria } = await loadTransformersSDK();
  const queue = createStreamEventQueue();
  const streamer = createStreamingTextStreamer(pipe.tokenizer, queue, TextStreamer);
  const stopper = new InterruptableStoppingCriteria();
  signal?.addEventListener("abort", () => stopper.interrupt(), { once: true });
  const prefix = input.prompt ? `${input.prompt}\n` : "";
  const promptedText = prefix + input.text;
  const run = pipe(promptedText, {
    streamer,
    stopping_criteria: [stopper]
  }).then(
    () => queue.done(),
    (err) => queue.error(err)
  );
  yield* queue.iterable;
  await run;
  yield { type: "finish", data: {} };
};
|
|
2003
|
+
|
|
2004
|
+
// src/ai-provider/common/HFT_TextSummary.ts
|
|
2005
|
+
init_HFT_Pipeline();
|
|
2006
|
+
var HFT_TextSummary = async (input, model, onProgress, signal) => {
  // Summarize input.text via a summarization pipeline.
  const pipe = await getPipeline(model, onProgress, {}, signal);
  const { TextStreamer, InterruptableStoppingCriteria } = await loadTransformersSDK();
  const streamer = createTextStreamer(pipe.tokenizer, onProgress, TextStreamer);
  const stopper = new InterruptableStoppingCriteria();
  signal?.addEventListener("abort", () => stopper.interrupt(), { once: true });
  const raw = await pipe(input.text, {
    streamer,
    stopping_criteria: [stopper]
  });
  // Output may be a single result or an array of results.
  const first = Array.isArray(raw) ? raw[0] : raw;
  return { text: first?.summary_text || "" };
};
|
|
2028
|
+
var HFT_TextSummary_Stream = async function* (input, model, signal) {
  // Streaming summarization: yields token events, then a final "finish" event.
  const noop = () => {};
  const pipe = await getPipeline(model, noop, {}, signal);
  const { TextStreamer, InterruptableStoppingCriteria } = await loadTransformersSDK();
  const queue = createStreamEventQueue();
  const streamer = createStreamingTextStreamer(pipe.tokenizer, queue, TextStreamer);
  const stopper = new InterruptableStoppingCriteria();
  signal?.addEventListener("abort", () => stopper.interrupt(), { once: true });
  const run = pipe(input.text, {
    streamer,
    stopping_criteria: [stopper]
  }).then(
    () => queue.done(),
    (err) => queue.error(err)
  );
  yield* queue.iterable;
  await run;
  yield { type: "finish", data: {} };
};
|
|
2046
|
+
|
|
2047
|
+
// src/ai-provider/common/HFT_TextTranslation.ts
|
|
2048
|
+
init_HFT_Pipeline();
|
|
2049
|
+
/**
 * Runs a translation pipeline to completion.
 *
 * @param {{text: string, source_lang: string, target_lang: string}} input - Text
 *   plus source/target language codes (forwarded as src_lang / tgt_lang).
 * @param {object} model - Model record; passed through to getPipeline.
 * @param {(pct: number, msg?: string) => void} onProgress - Progress callback; also
 *   receives streamed tokens via the text streamer.
 * @param {AbortSignal} [signal] - Optional abort signal; interrupts generation.
 * @returns {Promise<{text: string, target_lang: string}>} Translated text plus the
 *   requested target language (empty string if the pipeline produced nothing).
 */
var HFT_TextTranslation = async (input, model, onProgress, signal) => {
  const translate = await getPipeline(model, onProgress, {}, signal);
  const { TextStreamer, InterruptableStoppingCriteria } = await loadTransformersSDK();
  const streamer = createTextStreamer(translate.tokenizer, onProgress, TextStreamer);
  const stopper = new InterruptableStoppingCriteria();
  if (signal) {
    // Interrupt (rather than hard-abort) so the pipeline can stop cleanly.
    signal.addEventListener("abort", () => stopper.interrupt(), { once: true });
  }
  const output = await translate(input.text, {
    src_lang: input.source_lang,
    tgt_lang: input.target_lang,
    streamer,
    stopping_criteria: [stopper]
  });
  // Pipelines may return either a single result object or an array of them.
  const first = Array.isArray(output) ? output[0] : output;
  return {
    text: first?.translation_text || "",
    target_lang: input.target_lang
  };
};
|
|
2069
|
+
// Streaming variant of HFT_TextTranslation: runs the pipeline in the
// background while yielding token events from an async event queue, then
// yields a terminal "finish" event carrying the target language.
var HFT_TextTranslation_Stream = async function* (input, model, signal) {
  const noopProgress = () => {};
  const translate = await getPipeline(model, noopProgress, {}, signal);
  const { TextStreamer, InterruptableStoppingCriteria } = await loadTransformersSDK();
  const queue = createStreamEventQueue();
  const streamer = createStreamingTextStreamer(translate.tokenizer, queue, TextStreamer);
  const stopping_criteria = new InterruptableStoppingCriteria;
  if (signal) {
    // Interrupt generation (rather than hard-abort) so the pipeline can stop cleanly.
    signal.addEventListener("abort", () => stopping_criteria.interrupt(), { once: true });
  }
  // Deliberately not awaited yet: the pipeline pushes tokens into `queue` via
  // `streamer` while the loop below drains it. Because of `.then(done, error)`,
  // pipelinePromise itself never rejects; failures surface through
  // queue.error() and the iterable instead.
  const pipelinePromise = translate(input.text, {
    src_lang: input.source_lang,
    tgt_lang: input.target_lang,
    streamer,
    stopping_criteria: [stopping_criteria]
  }).then(() => queue.done(), (err) => queue.error(err));
  yield* queue.iterable;
  // Ensures the pipeline has fully settled before emitting the finish event.
  await pipelinePromise;
  yield { type: "finish", data: { target_lang: input.target_lang } };
};
|
|
2089
|
+
|
|
2090
|
+
// src/ai-provider/common/HFT_Unload.ts
|
|
2091
|
+
init_HFT_Pipeline();
|
|
2092
|
+
/**
 * Detects whether a usable CacheStorage API (`caches.open`) is available on
 * the current global object (true in browsers / some worker contexts).
 *
 * @returns {boolean} true when `globalThis.caches.open` is a function.
 */
function hasBrowserCacheStorage() {
  if (typeof globalThis === "undefined") {
    return false;
  }
  if (!("caches" in globalThis)) {
    return false;
  }
  return typeof globalThis.caches?.open === "function";
}
|
|
2095
|
+
/**
 * Deletes every CacheStorage entry belonging to one model from the shared
 * HTF cache (entries whose URL pathname starts with `/<model_path>/`).
 *
 * Deletion is best-effort: individual failures are logged and do not abort
 * the remaining deletions.
 *
 * @param {string} model_path - Model path segment used as the cache-key prefix.
 * @returns {Promise<void>}
 */
async function deleteModelCacheFromBrowser(model_path) {
  const cachesApi = globalThis.caches;
  const cache = await cachesApi.open(HTF_CACHE_NAME);
  const keys = await cache.keys();
  const prefix = `/${model_path}/`;
  // Snapshot matching requests first, then delete, so we never mutate the
  // cache while iterating its key list.
  const requestsToDelete = keys.filter((request) => new URL(request.url).pathname.startsWith(prefix));
  for (const request of requestsToDelete) {
    try {
      const deleted = await cache.delete(request);
      if (!deleted) {
        // Some entries may only match when addressed by URL string; retry.
        const deletedByUrl = await cache.delete(request.url);
        if (!deletedByUrl) {
          // Fix: the original silently ignored this case (`if (!deletedByUrl) {}`),
          // leaving stale cache entries invisible. Surface it for debugging.
          console.warn(`Cache entry could not be deleted: ${request.url}`);
        }
      }
    } catch (error) {
      console.error(`Failed to delete cache entry: ${request.url}`, error);
    }
  }
}
|
|
2119
|
+
/**
 * Removes a model's on-disk pipeline cache via the transformers SDK's
 * ModelRegistry (non-browser environments).
 *
 * @param {object} model - Model record; reads pipeline, model_path, and
 *   optional dtype from model.provider_config.
 * @returns {Promise<void>}
 */
async function deleteModelCacheFromFilesystem(model) {
  const { ModelRegistry } = await loadTransformersSDK();
  const config = model.provider_config;
  const options = {};
  // Only forward dtype when actually set, matching the registry cache key.
  if (config.dtype) {
    options.dtype = config.dtype;
  }
  await ModelRegistry.clear_pipeline_cache(config.pipeline, config.model_path, options);
}
|
|
2126
|
+
// Unloads a model end-to-end: drops the in-memory pipeline, disposes any HFT
// sessions bound to the model's path, then deletes the downloaded weights
// from either the browser CacheStorage or the filesystem cache.
var HFT_Unload = async (input, model, onProgress, _signal) => {
  const cacheKey = getPipelineCacheKey(model);
  // Only report this step when a pipeline was actually resident in memory.
  if (removeCachedPipeline(cacheKey)) {
    onProgress(50, "Pipeline removed from memory");
  }
  const model_path = model.provider_config.model_path;
  disposeHftSessionsForModel(model_path);
  // Browser CacheStorage when available; otherwise the SDK's on-disk cache.
  if (hasBrowserCacheStorage()) {
    await deleteModelCacheFromBrowser(model_path);
  } else {
    await deleteModelCacheFromFilesystem(model);
  }
  onProgress(100, "Model cache deleted");
  return {
    model: input.model
  };
};
|
|
2143
|
+
|
|
2144
|
+
// src/ai-provider/common/HFT_JobRunFns.ts
|
|
2145
|
+
// Maps task type names to their one-shot (non-streaming) run functions.
// Keys mirror the provider's taskTypes list.
var HFT_TASKS = {
  AiChatTask: HFT_Chat,
  DownloadModelTask: HFT_Download,
  UnloadModelTask: HFT_Unload,
  ModelInfoTask: HFT_ModelInfo,
  CountTokensTask: HFT_CountTokens,
  TextEmbeddingTask: HFT_TextEmbedding,
  TextGenerationTask: HFT_TextGeneration,
  TextQuestionAnswerTask: HFT_TextQuestionAnswer,
  TextLanguageDetectionTask: HFT_TextLanguageDetection,
  TextClassificationTask: HFT_TextClassification,
  TextFillMaskTask: HFT_TextFillMask,
  TextNamedEntityRecognitionTask: HFT_TextNamedEntityRecognition,
  TextRewriterTask: HFT_TextRewriter,
  TextSummaryTask: HFT_TextSummary,
  TextTranslationTask: HFT_TextTranslation,
  ImageSegmentationTask: HFT_ImageSegmentation,
  ImageToTextTask: HFT_ImageToText,
  BackgroundRemovalTask: HFT_BackgroundRemoval,
  ImageEmbeddingTask: HFT_ImageEmbedding,
  ImageClassificationTask: HFT_ImageClassification,
  ObjectDetectionTask: HFT_ObjectDetection,
  ToolCallingTask: HFT_ToolCalling,
  StructuredGenerationTask: HFT_StructuredGeneration,
  ModelSearchTask: HFT_ModelSearch
};
|
|
2171
|
+
// Maps task type names to their streaming (async-generator) run functions.
// Only a subset of tasks support streaming output.
var HFT_STREAM_TASKS = {
  AiChatTask: HFT_Chat_Stream,
  TextGenerationTask: HFT_TextGeneration_Stream,
  TextRewriterTask: HFT_TextRewriter_Stream,
  TextSummaryTask: HFT_TextSummary_Stream,
  TextQuestionAnswerTask: HFT_TextQuestionAnswer_Stream,
  TextTranslationTask: HFT_TextTranslation_Stream,
  ToolCallingTask: HFT_ToolCalling_Stream,
  StructuredGenerationTask: HFT_StructuredGeneration_Stream
};
|
|
2181
|
+
// Maps task type names to their preview run functions; currently only
// token counting has a preview variant.
var HFT_PREVIEW_TASKS = {
  CountTokensTask: HFT_CountTokens_Preview
};
|
|
2184
|
+
|
|
2185
|
+
// src/ai-provider/registerHuggingFaceTransformersInline.ts
|
|
2186
|
+
init_HFT_Pipeline();
|
|
2187
|
+
|
|
2188
|
+
// src/ai-provider/HuggingFaceTransformersQueuedProvider.ts
|
|
2189
|
+
import { QueuedAiProvider } from "@workglow/ai";
|
|
2190
|
+
init_HFT_Pipeline();
|
|
2191
|
+
// Device identifiers routed to the GPU (default queued) strategy; all other
// devices fall through to the dedicated CPU queue.
var GPU_DEVICES = new Set(["webgpu", "gpu", "metal"]);
// Default number of concurrent CPU-queue jobs outside automated tests.
var HFT_CPU_QUEUE_CONCURRENCY_PRODUCTION = 4;
|
|
2193
|
+
/**
 * Returns true when running under a known automated test runner
 * (vitest, jest, bun test, or NODE_ENV=test); false in environments
 * without a `process` global (e.g. browsers).
 *
 * @returns {boolean}
 */
function hftIsAutomatedTestEnvironment() {
  if (typeof process === "undefined") {
    return false;
  }
  const env = process.env;
  if (env.VITEST === "true") {
    return true;
  }
  if (env.NODE_ENV === "test") {
    return true;
  }
  if (env.BUN_TEST === "1") {
    return true;
  }
  return env.JEST_WORKER_ID !== undefined;
}
|
|
2200
|
+
/**
 * Default CPU queue concurrency: 1 under automated test runners (keeps tests
 * deterministic and light), otherwise the production default.
 *
 * @returns {number}
 */
function hftDefaultCpuQueueConcurrency() {
  if (hftIsAutomatedTestEnvironment()) {
    return 1;
  }
  return HFT_CPU_QUEUE_CONCURRENCY_PRODUCTION;
}
|
|
2203
|
+
/**
 * Resolves the CPU-queue concurrency from the provider's queue options.
 *
 * Only the object form (`{ cpu: n }`) can override the CPU queue. A bare
 * number is NOT applied here and falls back to the default — presumably the
 * numeric shorthand configures the default/GPU queue elsewhere (TODO confirm
 * against QueuedAiProvider's option handling).
 *
 * @param {number|{cpu?: number}|undefined} concurrency - Raw option value.
 * @param {() => number} defaultCpu - Lazy producer of the default CPU concurrency.
 * @returns {number}
 */
function resolveHftCpuQueueConcurrency(concurrency, defaultCpu) {
  const isObjectForm = concurrency !== undefined && typeof concurrency !== "number";
  if (!isObjectForm) {
    return defaultCpu();
  }
  return concurrency.cpu ?? defaultCpu();
}
|
|
2212
|
+
|
|
2213
|
+
/**
 * Queued AI provider for Hugging Face Transformers (ONNX) models.
 *
 * Adds a second, dedicated CPU queue alongside the base queued strategy:
 * models configured for a GPU device run on the base strategy, everything
 * else on the CPU queue (see getStrategyForModel).
 */
class HuggingFaceTransformersQueuedProvider extends QueuedAiProvider {
  name = HF_TRANSFORMERS_ONNX;
  displayName = "Hugging Face Transformers (ONNX)";
  isLocal = true;
  supportsBrowser = true;
  // Strategy for CPU-bound models; assigned in afterRegister().
  cpuStrategy;
  // Task types this provider can execute (keys of HFT_TASKS).
  taskTypes = [
    "AiChatTask",
    "DownloadModelTask",
    "UnloadModelTask",
    "ModelInfoTask",
    "CountTokensTask",
    "TextEmbeddingTask",
    "TextGenerationTask",
    "TextQuestionAnswerTask",
    "TextLanguageDetectionTask",
    "TextClassificationTask",
    "TextFillMaskTask",
    "TextNamedEntityRecognitionTask",
    "TextRewriterTask",
    "TextSummaryTask",
    "TextTranslationTask",
    "ImageSegmentationTask",
    "ImageToTextTask",
    "BackgroundRemovalTask",
    "ImageEmbeddingTask",
    "ImageClassificationTask",
    "ObjectDetectionTask",
    "ToolCallingTask",
    "StructuredGenerationTask",
    "ModelSearchTask"
  ];
  constructor(tasks, streamTasks, previewTasks) {
    super(tasks, streamTasks, previewTasks);
  }
  // Sessions are identified by opaque UUIDs; per-session state lives in the
  // module-level HFT session store, not on this instance.
  createSession(_model) {
    return crypto.randomUUID();
  }
  async disposeSession(sessionId) {
    deleteHftSession(sessionId);
  }
  async afterRegister(options) {
    await super.afterRegister(options);
    // Create the separate CPU queue so CPU jobs don't contend with GPU jobs.
    this.cpuStrategy = this.createQueuedStrategy(HF_TRANSFORMERS_ONNX_CPU, resolveHftCpuQueueConcurrency(options.queue?.concurrency, hftDefaultCpuQueueConcurrency), options);
  }
  // GPU-device models use the base queued strategy; all others the CPU queue.
  getStrategyForModel(model) {
    const device = model.provider_config?.device;
    if (device && GPU_DEVICES.has(device)) {
      return this.queuedStrategy;
    }
    return this.cpuStrategy;
  }
}
|
|
2266
|
+
|
|
2267
|
+
// src/ai-provider/registerHuggingFaceTransformersInline.ts
|
|
2268
|
+
/**
 * Registers the queued ONNX provider for inline (same-thread) execution.
 *
 * Wraps the provider's dispose() so the inline pipeline cache is cleared
 * before the base teardown runs.
 *
 * @param {object} options - Registration options forwarded to registerProviderInline.
 * @returns {Promise<void>}
 */
async function registerHuggingFaceTransformersInline(options) {
  const sdk = await loadTransformersSDK();
  // Enable onnxruntime's wasm proxy mode before any pipeline is created.
  sdk.env.backends.onnx.wasm.proxy = true;
  const provider = new HuggingFaceTransformersQueuedProvider(HFT_TASKS, HFT_STREAM_TASKS, HFT_PREVIEW_TASKS);
  const disposeBase = provider.dispose.bind(provider);
  provider.dispose = async () => {
    await clearHftInlinePipelineCache();
    await disposeBase();
  };
  await registerProviderInline(provider, "HuggingFaceTransformers", options);
}
|
|
2279
|
+
|
|
2280
|
+
// src/ai-provider/registerHuggingFaceTransformersWorker.ts
|
|
2281
|
+
import { registerProviderWorker } from "@workglow/ai-provider/common";
|
|
2282
|
+
|
|
2283
|
+
// src/ai-provider/HuggingFaceTransformersProvider.ts
|
|
2284
|
+
import { AiProvider } from "@workglow/ai/worker";
|
|
2285
|
+
init_HFT_Pipeline();
|
|
2286
|
+
|
|
2287
|
+
/**
 * Worker-side (non-queued) AI provider for Hugging Face Transformers (ONNX)
 * models. Queueing is handled outside the worker, so this variant extends the
 * plain AiProvider; task coverage matches HuggingFaceTransformersQueuedProvider.
 */
class HuggingFaceTransformersProvider extends AiProvider {
  name = HF_TRANSFORMERS_ONNX;
  displayName = "Hugging Face Transformers (ONNX)";
  isLocal = true;
  supportsBrowser = true;
  // Task types this provider can execute (keys of HFT_TASKS).
  taskTypes = [
    "AiChatTask",
    "DownloadModelTask",
    "UnloadModelTask",
    "ModelInfoTask",
    "CountTokensTask",
    "TextEmbeddingTask",
    "TextGenerationTask",
    "TextQuestionAnswerTask",
    "TextLanguageDetectionTask",
    "TextClassificationTask",
    "TextFillMaskTask",
    "TextNamedEntityRecognitionTask",
    "TextRewriterTask",
    "TextSummaryTask",
    "TextTranslationTask",
    "ImageSegmentationTask",
    "ImageToTextTask",
    "BackgroundRemovalTask",
    "ImageEmbeddingTask",
    "ImageClassificationTask",
    "ObjectDetectionTask",
    "ToolCallingTask",
    "StructuredGenerationTask",
    "ModelSearchTask"
  ];
  constructor(tasks, streamTasks, previewTasks) {
    super(tasks, streamTasks, previewTasks);
  }
  // Sessions are identified by opaque UUIDs; per-session state lives in the
  // module-level HFT session store, not on this instance.
  createSession(_model) {
    return crypto.randomUUID();
  }
  async disposeSession(sessionId) {
    deleteHftSession(sessionId);
  }
}
|
|
2328
|
+
|
|
2329
|
+
// src/ai-provider/registerHuggingFaceTransformersWorker.ts
|
|
2330
|
+
init_HFT_Pipeline();
|
|
2331
|
+
/**
 * Registers the ONNX provider inside a worker context.
 *
 * Loads the transformers SDK, publishes it on globalThis.__HFT__
 * (presumably for external access/inspection — confirm), enables the ONNX
 * wasm proxy setting, and wires the provider to the worker server.
 *
 * @returns {Promise<void>}
 */
async function registerHuggingFaceTransformersWorker() {
  const sdk = await loadTransformersSDK();
  globalThis.__HFT__ = sdk;
  sdk.env.backends.onnx.wasm.proxy = true;
  const connect = (ws) => {
    const provider = new HuggingFaceTransformersProvider(HFT_TASKS, HFT_STREAM_TASKS, HFT_PREVIEW_TASKS);
    return provider.registerOnWorkerServer(ws);
  };
  await registerProviderWorker(connect, "HuggingFaceTransformers");
}
|
|
2338
|
+
export {
|
|
2339
|
+
setHftSession,
|
|
2340
|
+
setHftCacheDir,
|
|
2341
|
+
removeCachedPipeline,
|
|
2342
|
+
registerHuggingFaceTransformersWorker,
|
|
2343
|
+
registerHuggingFaceTransformersInline,
|
|
2344
|
+
parseOnnxQuantizations,
|
|
2345
|
+
loadTransformersSDK,
|
|
2346
|
+
hasCachedPipeline,
|
|
2347
|
+
getPipelineCacheKey,
|
|
2348
|
+
getPipeline,
|
|
2349
|
+
getHftSession,
|
|
2350
|
+
disposeHftSessionsForModel,
|
|
2351
|
+
deleteHftSession,
|
|
2352
|
+
createToolCallMarkupFilter,
|
|
2353
|
+
clearPipelineCache,
|
|
2354
|
+
QuantizationDataType,
|
|
2355
|
+
PipelineUseCase,
|
|
2356
|
+
ONNX_QUANTIZATION_SUFFIX_MAPPING,
|
|
2357
|
+
HfTransformersOnnxModelSchema,
|
|
2358
|
+
HfTransformersOnnxModelRecordSchema,
|
|
2359
|
+
HfTransformersOnnxModelConfigSchema,
|
|
2360
|
+
HTF_CACHE_NAME,
|
|
2361
|
+
HF_TRANSFORMERS_ONNX_GPU,
|
|
2362
|
+
HF_TRANSFORMERS_ONNX_CPU,
|
|
2363
|
+
HF_TRANSFORMERS_ONNX,
|
|
2364
|
+
HFT_NULL_PROCESSOR_PREFIX
|
|
2365
|
+
};
|
|
2366
|
+
|
|
2367
|
+
//# debugId=FDAFA958F0DFFF6D64756E2164756E21
|