@workglow/ai-provider 0.0.126 → 0.1.0

This diff shows the contents of publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
Files changed (172)
  1. package/dist/provider-anthropic/AnthropicProvider.d.ts +1 -1
  2. package/dist/provider-anthropic/AnthropicProvider.d.ts.map +1 -1
  3. package/dist/provider-anthropic/AnthropicQueuedProvider.d.ts +4 -4
  4. package/dist/provider-anthropic/AnthropicQueuedProvider.d.ts.map +1 -1
  5. package/dist/provider-anthropic/common/Anthropic_Client.d.ts.map +1 -1
  6. package/dist/provider-anthropic/common/Anthropic_CountTokens.d.ts.map +1 -1
  7. package/dist/provider-anthropic/common/Anthropic_JobRunFns.d.ts.map +1 -1
  8. package/dist/provider-anthropic/common/Anthropic_TextGeneration.d.ts.map +1 -1
  9. package/dist/provider-anthropic/common/Anthropic_TextRewriter.d.ts.map +1 -1
  10. package/dist/provider-anthropic/common/Anthropic_TextSummary.d.ts.map +1 -1
  11. package/dist/provider-anthropic/index.js +3 -4
  12. package/dist/provider-anthropic/index.js.map +3 -3
  13. package/dist/provider-anthropic/runtime.js +19 -291
  14. package/dist/provider-anthropic/runtime.js.map +11 -12
  15. package/dist/provider-chrome/common/WebBrowser_TextGeneration.d.ts.map +1 -1
  16. package/dist/provider-chrome/common/WebBrowser_TextLanguageDetection.d.ts.map +1 -1
  17. package/dist/provider-chrome/common/WebBrowser_TextRewriter.d.ts.map +1 -1
  18. package/dist/provider-chrome/common/WebBrowser_TextSummary.d.ts.map +1 -1
  19. package/dist/provider-chrome/common/WebBrowser_TextTranslation.d.ts.map +1 -1
  20. package/dist/provider-chrome/runtime.js +3 -97
  21. package/dist/provider-chrome/runtime.js.map +8 -8
  22. package/dist/provider-gemini/GoogleGeminiProvider.d.ts +1 -1
  23. package/dist/provider-gemini/GoogleGeminiProvider.d.ts.map +1 -1
  24. package/dist/provider-gemini/GoogleGeminiQueuedProvider.d.ts +4 -4
  25. package/dist/provider-gemini/GoogleGeminiQueuedProvider.d.ts.map +1 -1
  26. package/dist/provider-gemini/common/Gemini_CountTokens.d.ts.map +1 -1
  27. package/dist/provider-gemini/common/Gemini_JobRunFns.d.ts.map +1 -1
  28. package/dist/provider-gemini/common/Gemini_TextEmbedding.d.ts.map +1 -1
  29. package/dist/provider-gemini/common/Gemini_TextGeneration.d.ts.map +1 -1
  30. package/dist/provider-gemini/common/Gemini_TextRewriter.d.ts.map +1 -1
  31. package/dist/provider-gemini/common/Gemini_TextSummary.d.ts.map +1 -1
  32. package/dist/provider-gemini/index.js +3 -4
  33. package/dist/provider-gemini/index.js.map +3 -3
  34. package/dist/provider-gemini/runtime.js +12 -257
  35. package/dist/provider-gemini/runtime.js.map +11 -12
  36. package/dist/provider-hf-inference/HfInferenceProvider.d.ts +1 -1
  37. package/dist/provider-hf-inference/HfInferenceProvider.d.ts.map +1 -1
  38. package/dist/provider-hf-inference/HfInferenceQueuedProvider.d.ts +4 -4
  39. package/dist/provider-hf-inference/HfInferenceQueuedProvider.d.ts.map +1 -1
  40. package/dist/provider-hf-inference/common/HFI_Client.d.ts.map +1 -1
  41. package/dist/provider-hf-inference/common/HFI_JobRunFns.d.ts.map +1 -1
  42. package/dist/provider-hf-inference/common/HFI_TextEmbedding.d.ts.map +1 -1
  43. package/dist/provider-hf-inference/common/HFI_TextGeneration.d.ts.map +1 -1
  44. package/dist/provider-hf-inference/common/HFI_TextRewriter.d.ts.map +1 -1
  45. package/dist/provider-hf-inference/common/HFI_TextSummary.d.ts.map +1 -1
  46. package/dist/provider-hf-inference/index.js +3 -4
  47. package/dist/provider-hf-inference/index.js.map +3 -3
  48. package/dist/provider-hf-inference/runtime.js +13 -206
  49. package/dist/provider-hf-inference/runtime.js.map +11 -12
  50. package/dist/provider-hf-transformers/HuggingFaceTransformersProvider.d.ts +1 -1
  51. package/dist/provider-hf-transformers/HuggingFaceTransformersProvider.d.ts.map +1 -1
  52. package/dist/provider-hf-transformers/HuggingFaceTransformersQueuedProvider.d.ts +13 -3
  53. package/dist/provider-hf-transformers/HuggingFaceTransformersQueuedProvider.d.ts.map +1 -1
  54. package/dist/provider-hf-transformers/common/HFT_Constants.d.ts +4 -0
  55. package/dist/provider-hf-transformers/common/HFT_Constants.d.ts.map +1 -1
  56. package/dist/provider-hf-transformers/common/HFT_CountTokens.d.ts.map +1 -1
  57. package/dist/provider-hf-transformers/common/HFT_Download.d.ts.map +1 -1
  58. package/dist/provider-hf-transformers/common/HFT_ImageEmbedding.d.ts.map +1 -1
  59. package/dist/provider-hf-transformers/common/HFT_JobRunFns.d.ts +116 -87
  60. package/dist/provider-hf-transformers/common/HFT_JobRunFns.d.ts.map +1 -1
  61. package/dist/provider-hf-transformers/common/HFT_ModelInfo.d.ts.map +1 -1
  62. package/dist/provider-hf-transformers/common/HFT_ModelSchema.d.ts +30 -0
  63. package/dist/provider-hf-transformers/common/HFT_ModelSchema.d.ts.map +1 -1
  64. package/dist/provider-hf-transformers/common/HFT_OnnxDtypes.d.ts.map +1 -1
  65. package/dist/provider-hf-transformers/common/HFT_Pipeline.d.ts +9 -2
  66. package/dist/provider-hf-transformers/common/HFT_Pipeline.d.ts.map +1 -1
  67. package/dist/provider-hf-transformers/common/HFT_Streaming.d.ts +2 -2
  68. package/dist/provider-hf-transformers/common/HFT_Streaming.d.ts.map +1 -1
  69. package/dist/provider-hf-transformers/common/HFT_TextClassification.d.ts.map +1 -1
  70. package/dist/provider-hf-transformers/common/HFT_TextFillMask.d.ts.map +1 -1
  71. package/dist/provider-hf-transformers/common/HFT_TextGeneration.d.ts.map +1 -1
  72. package/dist/provider-hf-transformers/common/HFT_TextLanguageDetection.d.ts.map +1 -1
  73. package/dist/provider-hf-transformers/common/HFT_TextNamedEntityRecognition.d.ts.map +1 -1
  74. package/dist/provider-hf-transformers/common/HFT_TextQuestionAnswer.d.ts.map +1 -1
  75. package/dist/provider-hf-transformers/common/HFT_TextRewriter.d.ts.map +1 -1
  76. package/dist/provider-hf-transformers/common/HFT_TextSummary.d.ts.map +1 -1
  77. package/dist/provider-hf-transformers/common/HFT_TextTranslation.d.ts.map +1 -1
  78. package/dist/provider-hf-transformers/index.d.ts +0 -1
  79. package/dist/provider-hf-transformers/index.d.ts.map +1 -1
  80. package/dist/provider-hf-transformers/index.js +49 -177
  81. package/dist/provider-hf-transformers/index.js.map +8 -9
  82. package/dist/provider-hf-transformers/registerHuggingFaceTransformersWorker.d.ts.map +1 -1
  83. package/dist/provider-hf-transformers/runtime.d.ts +0 -1
  84. package/dist/provider-hf-transformers/runtime.d.ts.map +1 -1
  85. package/dist/provider-hf-transformers/runtime.js +208 -513
  86. package/dist/provider-hf-transformers/runtime.js.map +27 -29
  87. package/dist/provider-llamacpp/LlamaCppProvider.d.ts +1 -1
  88. package/dist/provider-llamacpp/LlamaCppProvider.d.ts.map +1 -1
  89. package/dist/provider-llamacpp/LlamaCppQueuedProvider.d.ts +1 -1
  90. package/dist/provider-llamacpp/LlamaCppQueuedProvider.d.ts.map +1 -1
  91. package/dist/provider-llamacpp/common/LlamaCpp_CountTokens.d.ts.map +1 -1
  92. package/dist/provider-llamacpp/common/LlamaCpp_JobRunFns.d.ts.map +1 -1
  93. package/dist/provider-llamacpp/common/LlamaCpp_ModelSchema.d.ts +15 -0
  94. package/dist/provider-llamacpp/common/LlamaCpp_ModelSchema.d.ts.map +1 -1
  95. package/dist/provider-llamacpp/common/LlamaCpp_Runtime.d.ts +10 -0
  96. package/dist/provider-llamacpp/common/LlamaCpp_Runtime.d.ts.map +1 -1
  97. package/dist/provider-llamacpp/common/LlamaCpp_StructuredGeneration.d.ts.map +1 -1
  98. package/dist/provider-llamacpp/common/LlamaCpp_TextEmbedding.d.ts.map +1 -1
  99. package/dist/provider-llamacpp/common/LlamaCpp_TextGeneration.d.ts.map +1 -1
  100. package/dist/provider-llamacpp/common/LlamaCpp_TextRewriter.d.ts.map +1 -1
  101. package/dist/provider-llamacpp/common/LlamaCpp_TextSummary.d.ts.map +1 -1
  102. package/dist/provider-llamacpp/index.js +6 -2
  103. package/dist/provider-llamacpp/index.js.map +4 -4
  104. package/dist/provider-llamacpp/runtime.js +82 -230
  105. package/dist/provider-llamacpp/runtime.js.map +13 -14
  106. package/dist/provider-ollama/OllamaProvider.d.ts +1 -1
  107. package/dist/provider-ollama/OllamaProvider.d.ts.map +1 -1
  108. package/dist/provider-ollama/OllamaQueuedProvider.d.ts +4 -4
  109. package/dist/provider-ollama/OllamaQueuedProvider.d.ts.map +1 -1
  110. package/dist/provider-ollama/common/Ollama_JobRunFns.browser.d.ts +13 -71
  111. package/dist/provider-ollama/common/Ollama_JobRunFns.browser.d.ts.map +1 -1
  112. package/dist/provider-ollama/common/Ollama_JobRunFns.d.ts +13 -71
  113. package/dist/provider-ollama/common/Ollama_JobRunFns.d.ts.map +1 -1
  114. package/dist/provider-ollama/common/Ollama_TextGeneration.d.ts.map +1 -1
  115. package/dist/provider-ollama/common/Ollama_TextRewriter.d.ts.map +1 -1
  116. package/dist/provider-ollama/common/Ollama_TextSummary.d.ts.map +1 -1
  117. package/dist/provider-ollama/index.browser.js +3 -4
  118. package/dist/provider-ollama/index.browser.js.map +3 -3
  119. package/dist/provider-ollama/index.js +3 -4
  120. package/dist/provider-ollama/index.js.map +3 -3
  121. package/dist/provider-ollama/runtime.browser.js +8 -179
  122. package/dist/provider-ollama/runtime.browser.js.map +9 -10
  123. package/dist/provider-ollama/runtime.js +8 -174
  124. package/dist/provider-ollama/runtime.js.map +9 -10
  125. package/dist/provider-openai/OpenAiProvider.d.ts +1 -1
  126. package/dist/provider-openai/OpenAiProvider.d.ts.map +1 -1
  127. package/dist/provider-openai/OpenAiQueuedProvider.d.ts +4 -4
  128. package/dist/provider-openai/OpenAiQueuedProvider.d.ts.map +1 -1
  129. package/dist/provider-openai/common/OpenAI_Client.d.ts.map +1 -1
  130. package/dist/provider-openai/common/OpenAI_CountTokens.browser.d.ts.map +1 -1
  131. package/dist/provider-openai/common/OpenAI_CountTokens.d.ts.map +1 -1
  132. package/dist/provider-openai/common/OpenAI_JobRunFns.browser.d.ts.map +1 -1
  133. package/dist/provider-openai/common/OpenAI_JobRunFns.d.ts.map +1 -1
  134. package/dist/provider-openai/common/OpenAI_TextEmbedding.d.ts.map +1 -1
  135. package/dist/provider-openai/common/OpenAI_TextGeneration.d.ts.map +1 -1
  136. package/dist/provider-openai/common/OpenAI_TextRewriter.d.ts.map +1 -1
  137. package/dist/provider-openai/common/OpenAI_TextSummary.d.ts.map +1 -1
  138. package/dist/provider-openai/index.browser.js +3 -4
  139. package/dist/provider-openai/index.browser.js.map +3 -3
  140. package/dist/provider-openai/index.js +3 -4
  141. package/dist/provider-openai/index.js.map +3 -3
  142. package/dist/provider-openai/runtime.browser.js +22 -224
  143. package/dist/provider-openai/runtime.browser.js.map +12 -13
  144. package/dist/provider-openai/runtime.js +22 -224
  145. package/dist/provider-openai/runtime.js.map +12 -13
  146. package/dist/provider-tf-mediapipe/TensorFlowMediaPipeQueuedProvider.d.ts +3 -3
  147. package/dist/provider-tf-mediapipe/TensorFlowMediaPipeQueuedProvider.d.ts.map +1 -1
  148. package/dist/provider-tf-mediapipe/common/TFMP_ImageEmbedding.d.ts.map +1 -1
  149. package/dist/provider-tf-mediapipe/common/TFMP_JobRunFns.d.ts +17 -10
  150. package/dist/provider-tf-mediapipe/common/TFMP_JobRunFns.d.ts.map +1 -1
  151. package/dist/provider-tf-mediapipe/common/TFMP_Unload.d.ts.map +1 -1
  152. package/dist/provider-tf-mediapipe/index.js +3 -3
  153. package/dist/provider-tf-mediapipe/index.js.map +3 -3
  154. package/dist/provider-tf-mediapipe/runtime.js +16 -5
  155. package/dist/provider-tf-mediapipe/runtime.js.map +5 -5
  156. package/package.json +15 -15
  157. package/dist/provider-anthropic/common/Anthropic_ToolCalling.d.ts +0 -10
  158. package/dist/provider-anthropic/common/Anthropic_ToolCalling.d.ts.map +0 -1
  159. package/dist/provider-gemini/common/Gemini_ToolCalling.d.ts +0 -10
  160. package/dist/provider-gemini/common/Gemini_ToolCalling.d.ts.map +0 -1
  161. package/dist/provider-hf-inference/common/HFI_ToolCalling.d.ts +0 -10
  162. package/dist/provider-hf-inference/common/HFI_ToolCalling.d.ts.map +0 -1
  163. package/dist/provider-hf-transformers/common/HFT_ToolCalling.d.ts +0 -10
  164. package/dist/provider-hf-transformers/common/HFT_ToolCalling.d.ts.map +0 -1
  165. package/dist/provider-hf-transformers/common/HFT_ToolMarkup.d.ts +0 -40
  166. package/dist/provider-hf-transformers/common/HFT_ToolMarkup.d.ts.map +0 -1
  167. package/dist/provider-llamacpp/common/LlamaCpp_ToolCalling.d.ts +0 -10
  168. package/dist/provider-llamacpp/common/LlamaCpp_ToolCalling.d.ts.map +0 -1
  169. package/dist/provider-ollama/common/Ollama_ToolCalling.d.ts +0 -16
  170. package/dist/provider-ollama/common/Ollama_ToolCalling.d.ts.map +0 -1
  171. package/dist/provider-openai/common/OpenAI_ToolCalling.d.ts +0 -10
  172. package/dist/provider-openai/common/OpenAI_ToolCalling.d.ts.map +0 -1
@@ -30,7 +30,8 @@ __export(exports_HFT_Pipeline, {
30
30
  hasCachedPipeline: () => hasCachedPipeline,
31
31
  getPipelineCacheKey: () => getPipelineCacheKey,
32
32
  getPipeline: () => getPipeline,
33
- clearPipelineCache: () => clearPipelineCache
33
+ clearPipelineCache: () => clearPipelineCache,
34
+ HFT_NULL_PROCESSOR_PREFIX: () => HFT_NULL_PROCESSOR_PREFIX
34
35
  });
35
36
  import { getLogger } from "@workglow/util/worker";
36
37
  function setHftCacheDir(dir) {
@@ -53,18 +54,99 @@ async function loadTransformersSDK() {
53
54
  }
54
55
  return _transformersSdk;
55
56
  }
57
+ function combineAbortSignals(existingSignal, modelSignal) {
58
+ if (!existingSignal) {
59
+ return modelSignal;
60
+ }
61
+ if (!modelSignal) {
62
+ return existingSignal;
63
+ }
64
+ if (existingSignal.aborted || modelSignal.aborted) {
65
+ return AbortSignal.abort(existingSignal.reason ?? modelSignal.reason);
66
+ }
67
+ if (typeof AbortSignal.any === "function") {
68
+ return AbortSignal.any([existingSignal, modelSignal]);
69
+ }
70
+ const controller = new AbortController;
71
+ const abort = (event) => {
72
+ const signal = event.target;
73
+ controller.abort(signal.reason);
74
+ };
75
+ existingSignal.addEventListener("abort", abort, { once: true });
76
+ modelSignal.addEventListener("abort", abort, { once: true });
77
+ return controller.signal;
78
+ }
79
+ function createAbortError(signal) {
80
+ const reason = signal.reason;
81
+ if (reason instanceof Error) {
82
+ return reason;
83
+ }
84
+ return new Error(String(reason ?? "Fetch aborted"));
85
+ }
86
+ function wrapAbortableResponse(response, signal) {
87
+ if (!signal || !response.body) {
88
+ return response;
89
+ }
90
+ const contentLengthHeader = response.headers.get("content-length");
91
+ const expectedSize = contentLengthHeader && /^\d+$/.test(contentLengthHeader) ? Number.parseInt(contentLengthHeader, 10) : undefined;
92
+ const sourceBody = response.body;
93
+ const body = new ReadableStream({
94
+ async start(controller) {
95
+ const reader = sourceBody.getReader();
96
+ const abort = () => controller.error(createAbortError(signal));
97
+ signal.addEventListener("abort", abort, { once: true });
98
+ let loaded = 0;
99
+ try {
100
+ while (true) {
101
+ if (signal.aborted) {
102
+ throw createAbortError(signal);
103
+ }
104
+ const { done, value } = await reader.read();
105
+ if (done) {
106
+ if (signal.aborted) {
107
+ throw createAbortError(signal);
108
+ }
109
+ if (expectedSize !== undefined && loaded < expectedSize) {
110
+ throw new Error(`Fetch ended before reading the full response body (${loaded}/${expectedSize} bytes)`);
111
+ }
112
+ controller.close();
113
+ return;
114
+ }
115
+ loaded += value.length;
116
+ controller.enqueue(value);
117
+ }
118
+ } catch (error) {
119
+ controller.error(error);
120
+ } finally {
121
+ signal.removeEventListener("abort", abort);
122
+ reader.releaseLock();
123
+ }
124
+ },
125
+ async cancel(reason) {
126
+ try {
127
+ await sourceBody.cancel(reason);
128
+ } catch {}
129
+ }
130
+ });
131
+ return new Response(body, {
132
+ headers: new Headers(response.headers),
133
+ status: response.status,
134
+ statusText: response.statusText
135
+ });
136
+ }
56
137
  function abortableFetch(url, options) {
57
- let signal;
138
+ let modelSignal;
58
139
  try {
59
140
  const pathname = new URL(url).pathname;
60
141
  for (const [modelPath, controller] of modelAbortControllers) {
61
142
  if (pathname.includes(`/${modelPath}/`)) {
62
- signal = controller.signal;
143
+ modelSignal = controller.signal;
63
144
  break;
64
145
  }
65
146
  }
66
147
  } catch {}
67
- return fetch(url, { ...options, ...signal ? { signal } : {} });
148
+ const combinedSignal = combineAbortSignals(options.signal, modelSignal);
149
+ return fetch(url, { ...options, ...combinedSignal ? { signal: combinedSignal } : {} }).then((response) => wrapAbortableResponse(response, combinedSignal));
68
150
  }
69
151
  function clearPipelineCache() {
70
152
  pipelines.clear();
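Note: the combineAbortSignals / wrapAbortableResponse helpers added above merge the per-model abort controller with any signal transformers.js already passed to fetch, and re-wrap the response body so that an abort or a short read surfaces as a stream error rather than a silently truncated download. A minimal sketch of the merge semantics using standard AbortSignal APIs (the variable names are illustrative, not package API):

    // Illustrative sketch (assumes AbortSignal.any is available; older runtimes use the listener fallback shown above).
    const sdkController = new AbortController();   // signal the SDK passes in fetch options
    const modelController = new AbortController(); // per-model controller kept in modelAbortControllers

    const combined = AbortSignal.any([sdkController.signal, modelController.signal]);

    modelController.abort(new Error("model download cancelled"));
    console.log(combined.aborted); // true - the wrapped Response body then errors with this reason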
@@ -85,9 +167,10 @@ function isBrowserEnv() {
85
167
  return false;
86
168
  }
87
169
  function getPipelineCacheKey(model) {
88
- const dtype = model.provider_config.dtype || "q8";
170
+ const dtype = model.provider_config.dtype || "";
89
171
  const device = model.provider_config.device || "";
90
- return `${model.provider_config.model_path}:${model.provider_config.pipeline}:${dtype}:${device}`;
172
+ const revision = model.provider_config.revision || "main";
173
+ return `${model.provider_config.model_path}:${model.provider_config.pipeline}:${dtype}:${device}:${revision}`;
91
174
  }
92
175
  async function getPipeline(model, onProgress, options = {}, signal, progressScaleMax = 10) {
93
176
  const cacheKey = getPipelineCacheKey(model);
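Note: with the revision segment appended, two configurations that differ only in revision no longer collide in the pipeline cache. An illustrative key for a made-up config (values are examples only, chosen to show the key shape):

    // Example only - config values are invented.
    const model = {
      provider_config: {
        model_path: "Xenova/all-MiniLM-L6-v2",
        pipeline: "feature-extraction",
        dtype: "q8",
        device: "webgpu",
        revision: "main",
      },
    };
    // getPipelineCacheKey(model)
    //   => "Xenova/all-MiniLM-L6-v2:feature-extraction:q8:webgpu:main"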
@@ -108,7 +191,7 @@ async function getPipeline(model, onProgress, options = {}, signal, progressScal
108
191
  pipelineLoadPromises.set(cacheKey, loadPromise);
109
192
  return loadPromise;
110
193
  }
111
- var _transformersSdk, _cacheDir, modelAbortControllers, pipelines, pipelineLoadPromises, doGetPipeline = async (model, onProgress, options, progressScaleMax, cacheKey, signal) => {
194
+ var _transformersSdk, _cacheDir, modelAbortControllers, pipelines, pipelineLoadPromises, IMAGE_PIPELINE_TYPES, HFT_NULL_PROCESSOR_PREFIX = "HFT_NULL_PROCESSOR:", doGetPipeline = async (model, onProgress, options, progressScaleMax, cacheKey, signal) => {
112
195
  let lastProgressTime = 0;
113
196
  let pendingProgress = null;
114
197
  let throttleTimer = null;
@@ -206,16 +289,18 @@ var _transformersSdk, _cacheDir, modelAbortControllers, pipelines, pipelineLoadP
206
289
  device = "wasm";
207
290
  }
208
291
  if (device !== "wasm" && device !== "webgpu") {
209
- device = "webgpu";
292
+ device = "wasm";
210
293
  }
211
294
  } else {
212
295
  if (device === "wasm" || device === "webgpu") {
213
296
  device = undefined;
214
297
  }
215
298
  }
299
+ const dtype = model.provider_config.dtype || "";
216
300
  const pipelineOptions = {
217
- dtype: model.provider_config.dtype || "q8",
301
+ revision: model.provider_config.revision || "main",
218
302
  ...model.provider_config.use_external_data_format ? { useExternalDataFormat: model.provider_config.use_external_data_format } : {},
303
+ ...dtype ? { dtype } : {},
219
304
  ...device ? { device } : {},
220
305
  ...options,
221
306
  progress_callback: progressCallback
@@ -244,27 +329,44 @@ var _transformersSdk, _cacheDir, modelAbortControllers, pipelines, pipelineLoadP
244
329
  logger.timeEnd(pipelineTimerLabel, { status: "aborted" });
245
330
  throw new Error("Operation aborted after pipeline creation");
246
331
  }
332
+ if (IMAGE_PIPELINE_TYPES.has(pipelineType) && result.processor == null) {
333
+ throw new Error(`${HFT_NULL_PROCESSOR_PREFIX} Image processor not initialized for ${pipelineType}/${modelPath}. Model cache may be incomplete.`);
334
+ }
247
335
  logger.timeEnd(pipelineTimerLabel, { status: "loaded" });
248
336
  pipelines.set(cacheKey, result);
249
337
  return result;
250
338
  } catch (error) {
251
339
  logger.timeEnd(pipelineTimerLabel, { status: "error", error: String(error) });
252
- if (abortSignal?.aborted || modelController.signal.aborted) {
340
+ if (!error?.message?.startsWith(HFT_NULL_PROCESSOR_PREFIX) && (abortSignal?.aborted || modelController.signal.aborted)) {
253
341
  throw new Error("Pipeline download aborted");
254
342
  }
255
343
  throw error;
256
344
  } finally {
257
345
  modelAbortControllers.delete(modelPath);
346
+ const { random } = await loadTransformersSDK();
347
+ random.seed(model.provider_config.seed ?? undefined);
258
348
  }
259
349
  };
260
350
  var init_HFT_Pipeline = __esm(() => {
261
351
  modelAbortControllers = new Map;
262
352
  pipelines = new Map;
263
353
  pipelineLoadPromises = new Map;
354
+ IMAGE_PIPELINE_TYPES = new Set([
355
+ "image-classification",
356
+ "image-segmentation",
357
+ "object-detection",
358
+ "image-to-text",
359
+ "image-feature-extraction",
360
+ "zero-shot-image-classification",
361
+ "depth-estimation",
362
+ "mask-generation"
363
+ ]);
264
364
  });
265
365
 
266
366
  // src/provider-hf-transformers/common/HFT_Constants.ts
267
367
  var HF_TRANSFORMERS_ONNX = "HF_TRANSFORMERS_ONNX";
368
+ var HF_TRANSFORMERS_ONNX_GPU = `${HF_TRANSFORMERS_ONNX}_gpu`;
369
+ var HF_TRANSFORMERS_ONNX_CPU = `${HF_TRANSFORMERS_ONNX}_cpu`;
268
370
  var HTF_CACHE_NAME = "transformers-cache";
269
371
  var QuantizationDataType = {
270
372
  auto: "auto",
@@ -340,6 +442,11 @@ var HfTransformersOnnxModelSchema = {
340
442
  type: "string",
341
443
  description: "Filesystem path or URI for the ONNX model."
342
444
  },
445
+ revision: {
446
+ type: "string",
447
+ description: "Git revision (branch, tag, or commit hash) of the model repository.",
448
+ default: "main"
449
+ },
343
450
  dtype: {
344
451
  type: "string",
345
452
  enum: Object.values(QuantizationDataType),
@@ -389,6 +496,11 @@ var HfTransformersOnnxModelSchema = {
389
496
  type: "string",
390
497
  description: "The language style of the model."
391
498
  },
499
+ seed: {
500
+ type: "integer",
501
+ description: "RNG seed passed to transformers.js sampling. Omit for time-based seeding; set for reproducible generation.",
502
+ minimum: 0
503
+ },
392
504
  mrl: {
393
505
  type: "boolean",
394
506
  description: "Whether the model uses matryoshka.",
@@ -487,178 +599,6 @@ function parseOnnxQuantizations(params) {
487
599
  return set !== undefined && set.size === allBaseNames.size;
488
600
  });
489
601
  }
490
- // src/provider-hf-transformers/common/HFT_ToolMarkup.ts
491
- function parseToolCallsFromText(responseText) {
492
- const toolCalls = [];
493
- let callIndex = 0;
494
- let cleanedText = responseText;
495
- const toolCallTagRegex = /<tool_call>([\s\S]*?)<\/tool_call>/g;
496
- let tagMatch;
497
- while ((tagMatch = toolCallTagRegex.exec(responseText)) !== null) {
498
- try {
499
- const parsed = JSON.parse(tagMatch[1].trim());
500
- const id = `call_${callIndex++}`;
501
- toolCalls.push({
502
- id,
503
- name: parsed.name ?? parsed.function?.name ?? "",
504
- input: parsed.arguments ?? parsed.function?.arguments ?? parsed.parameters ?? {}
505
- });
506
- } catch {}
507
- }
508
- if (toolCalls.length > 0) {
509
- cleanedText = responseText.replace(/<tool_call>[\s\S]*?<\/tool_call>/g, "").trim();
510
- return { text: cleanedText, toolCalls };
511
- }
512
- const jsonCandidates = [];
513
- (function collectBalancedJsonBlocks(source) {
514
- const length = source.length;
515
- let i = 0;
516
- while (i < length) {
517
- if (source[i] !== "{") {
518
- i++;
519
- continue;
520
- }
521
- let depth = 1;
522
- let j = i + 1;
523
- let inString = false;
524
- let escape = false;
525
- while (j < length && depth > 0) {
526
- const ch = source[j];
527
- if (inString) {
528
- if (escape) {
529
- escape = false;
530
- } else if (ch === "\\") {
531
- escape = true;
532
- } else if (ch === '"') {
533
- inString = false;
534
- }
535
- } else {
536
- if (ch === '"') {
537
- inString = true;
538
- } else if (ch === "{") {
539
- depth++;
540
- } else if (ch === "}") {
541
- depth--;
542
- }
543
- }
544
- j++;
545
- }
546
- if (depth === 0) {
547
- jsonCandidates.push({ text: source.slice(i, j), start: i, end: j });
548
- i = j;
549
- } else {
550
- break;
551
- }
552
- }
553
- })(responseText);
554
- const matchedRanges = [];
555
- for (const candidate of jsonCandidates) {
556
- try {
557
- const parsed = JSON.parse(candidate.text);
558
- if (parsed.name && (parsed.arguments !== undefined || parsed.parameters !== undefined)) {
559
- const id = `call_${callIndex++}`;
560
- toolCalls.push({
561
- id,
562
- name: parsed.name,
563
- input: parsed.arguments ?? parsed.parameters ?? {}
564
- });
565
- matchedRanges.push({ start: candidate.start, end: candidate.end });
566
- } else if (parsed.function?.name) {
567
- let functionArgs = parsed.function.arguments ?? {};
568
- if (typeof functionArgs === "string") {
569
- try {
570
- functionArgs = JSON.parse(functionArgs);
571
- } catch (innerError) {
572
- console.warn("Failed to parse tool call function.arguments as JSON", innerError);
573
- functionArgs = {};
574
- }
575
- }
576
- const id = `call_${callIndex++}`;
577
- toolCalls.push({
578
- id,
579
- name: parsed.function.name,
580
- input: functionArgs ?? {}
581
- });
582
- matchedRanges.push({ start: candidate.start, end: candidate.end });
583
- }
584
- } catch {}
585
- }
586
- if (toolCalls.length > 0) {
587
- let result = "";
588
- let lastIndex = 0;
589
- for (const range of matchedRanges) {
590
- result += responseText.slice(lastIndex, range.start);
591
- lastIndex = range.end;
592
- }
593
- result += responseText.slice(lastIndex);
594
- cleanedText = result.trim();
595
- }
596
- return { text: cleanedText, toolCalls };
597
- }
598
- function createToolCallMarkupFilter(emit) {
599
- const OPEN_TAG = "<tool_call>";
600
- const CLOSE_TAG = "</tool_call>";
601
- let state = "text";
602
- let pending = "";
603
- function feed(token) {
604
- if (state === "tag") {
605
- pending += token;
606
- const closeIdx = pending.indexOf(CLOSE_TAG);
607
- if (closeIdx !== -1) {
608
- const afterClose = pending.slice(closeIdx + CLOSE_TAG.length);
609
- pending = "";
610
- state = "text";
611
- if (afterClose.length > 0) {
612
- feed(afterClose);
613
- }
614
- }
615
- return;
616
- }
617
- const combined = pending + token;
618
- const openIdx = combined.indexOf(OPEN_TAG);
619
- if (openIdx !== -1) {
620
- const before = combined.slice(0, openIdx);
621
- if (before.length > 0) {
622
- emit(before);
623
- }
624
- pending = "";
625
- state = "tag";
626
- const afterOpen = combined.slice(openIdx + OPEN_TAG.length);
627
- if (afterOpen.length > 0) {
628
- feed(afterOpen);
629
- }
630
- return;
631
- }
632
- let prefixLen = 0;
633
- for (let len = Math.min(combined.length, OPEN_TAG.length - 1);len >= 1; len--) {
634
- if (combined.endsWith(OPEN_TAG.slice(0, len))) {
635
- prefixLen = len;
636
- break;
637
- }
638
- }
639
- if (prefixLen > 0) {
640
- const safe = combined.slice(0, combined.length - prefixLen);
641
- if (safe.length > 0) {
642
- emit(safe);
643
- }
644
- pending = combined.slice(combined.length - prefixLen);
645
- } else {
646
- if (combined.length > 0) {
647
- emit(combined);
648
- }
649
- pending = "";
650
- }
651
- }
652
- function flush() {
653
- if (pending.length > 0 && state === "text") {
654
- emit(pending);
655
- pending = "";
656
- }
657
- pending = "";
658
- state = "text";
659
- }
660
- return { feed, flush };
661
- }
662
602
  // src/provider-hf-transformers/common/HFT_InlineLifecycle.ts
663
603
  async function clearHftInlinePipelineCache() {
664
604
  const { clearPipelineCache: clearPipelineCache2 } = await Promise.resolve().then(() => (init_HFT_Pipeline(), exports_HFT_Pipeline));
@@ -792,16 +732,10 @@ var HFT_BackgroundRemoval = async (input, model, onProgress, signal) => {
792
732
  // src/provider-hf-transformers/common/HFT_CountTokens.ts
793
733
  init_HFT_Pipeline();
794
734
  var HFT_CountTokens = async (input, model, onProgress, _signal) => {
795
- const isArrayInput = Array.isArray(input.text);
796
735
  const { AutoTokenizer } = await loadTransformersSDK();
797
736
  const tokenizer = await AutoTokenizer.from_pretrained(model.provider_config.model_path, {
798
737
  progress_callback: (progress) => onProgress(progress?.progress ?? 0)
799
738
  });
800
- if (isArrayInput) {
801
- const texts = input.text;
802
- const counts = texts.map((t) => tokenizer.encode(t).length);
803
- return { count: counts };
804
- }
805
739
  const tokenIds = tokenizer.encode(input.text);
806
740
  return { count: tokenIds.length };
807
741
  };
@@ -865,6 +799,15 @@ var HFT_ImageEmbedding = async (input, model, onProgress, signal) => {
865
799
  logger.debug("HFT ImageEmbedding: pipeline ready, generating embedding", {
866
800
  model: model?.provider_config.model_path
867
801
  });
802
+ if (Array.isArray(input.image)) {
803
+ const vectors = [];
804
+ for (const image of input.image) {
805
+ const result2 = await embedder(image);
806
+ vectors.push(result2.data);
807
+ }
808
+ logger.timeEnd(timerLabel, { count: vectors.length });
809
+ return { vector: vectors };
810
+ }
868
811
  const result = await embedder(input.image);
869
812
  logger.timeEnd(timerLabel, { dimensions: result?.data?.length });
870
813
  return {
@@ -914,18 +857,14 @@ var HFT_ModelInfo = async (input, model) => {
914
857
  logger.time(timerLabel, { model: model?.provider_config.model_path });
915
858
  const detail = input.detail;
916
859
  const is_loaded = hasCachedPipeline(getPipelineCacheKey(model));
917
- const { pipeline: pipelineType, model_path, dtype } = model.provider_config;
918
- const cacheStatus = await ModelRegistry.is_pipeline_cached_files(pipelineType, model_path, {
919
- ...dtype ? { dtype } : {}
920
- });
860
+ const { pipeline: pipelineType, model_path, dtype, device } = model.provider_config;
861
+ const cacheOptions = {
862
+ ...dtype ? { dtype } : {},
863
+ ...device ? { device } : {}
864
+ };
865
+ const cacheStatus = await ModelRegistry.is_pipeline_cached_files(pipelineType, model_path, cacheOptions);
921
866
  logger.debug("is_pipeline_cached", {
922
- input: [
923
- pipelineType,
924
- model_path,
925
- {
926
- ...dtype ? { dtype } : {}
927
- }
928
- ],
867
+ input: [pipelineType, model_path, cacheOptions],
929
868
  result: cacheStatus
930
869
  });
931
870
  const is_cached = is_loaded || cacheStatus.allCached;
@@ -1009,6 +948,7 @@ init_HFT_Pipeline();
1009
948
  import { parsePartialJson } from "@workglow/util/worker";
1010
949
 
1011
950
  // src/provider-hf-transformers/common/HFT_Streaming.ts
951
+ import { TaskAbortedError } from "@workglow/task-graph";
1012
952
  function createStreamEventQueue() {
1013
953
  const buffer = [];
1014
954
  let resolve = null;
@@ -1060,21 +1000,27 @@ function createStreamEventQueue() {
1060
1000
  };
1061
1001
  return { push, done, error, iterable };
1062
1002
  }
1063
- function createStreamingTextStreamer(tokenizer, queue, textStreamer) {
1003
+ function createStreamingTextStreamer(tokenizer, queue, textStreamer, signal) {
1064
1004
  return new textStreamer(tokenizer, {
1065
1005
  skip_prompt: true,
1066
1006
  decode_kwargs: { skip_special_tokens: true },
1067
1007
  callback_function: (text) => {
1008
+ if (signal?.aborted) {
1009
+ throw signal.reason ?? new TaskAbortedError("Generation aborted");
1010
+ }
1068
1011
  queue.push({ type: "text-delta", port: "text", textDelta: text });
1069
1012
  }
1070
1013
  });
1071
1014
  }
1072
- function createTextStreamer(tokenizer, updateProgress, textStreamer) {
1015
+ function createTextStreamer(tokenizer, updateProgress, textStreamer, signal) {
1073
1016
  let count = 0;
1074
1017
  return new textStreamer(tokenizer, {
1075
1018
  skip_prompt: true,
1076
1019
  decode_kwargs: { skip_special_tokens: true },
1077
1020
  callback_function: (text) => {
1021
+ if (signal?.aborted) {
1022
+ throw signal.reason ?? new TaskAbortedError("Generation aborted");
1023
+ }
1078
1024
  count++;
1079
1025
  const result = 100 * (1 - Math.exp(-0.05 * count));
1080
1026
  const progress = Math.round(Math.min(result, 100));
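Note: the token-count progress estimate follows a saturating curve, progress(n) = 100 * (1 - e^(-0.05 * n)), so it rises quickly at first and approaches 100 only asymptotically. A quick check of the values it produces:

    const progressFor = (n: number) => Math.round(Math.min(100 * (1 - Math.exp(-0.05 * n)), 100));
    progressFor(14);  // ~50 after 14 tokens
    progressFor(46);  // ~90 after 46 tokens
    progressFor(200); // reported as 100, though the curve never truly reaches it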
@@ -1137,7 +1083,7 @@ var HFT_StructuredGeneration = async (input, model, onProgress, signal) => {
1137
1083
  tokenize: false,
1138
1084
  add_generation_prompt: true
1139
1085
  });
1140
- const streamer = createTextStreamer(generateText.tokenizer, onProgress, TextStreamer);
1086
+ const streamer = createTextStreamer(generateText.tokenizer, onProgress, TextStreamer, signal);
1141
1087
  let results = await generateText(formattedPrompt, {
1142
1088
  max_new_tokens: input.maxTokens ?? 1024,
1143
1089
  temperature: input.temperature ?? undefined,
@@ -1162,7 +1108,7 @@ var HFT_StructuredGeneration_Stream = async function* (input, model, signal) {
1162
1108
  add_generation_prompt: true
1163
1109
  });
1164
1110
  const queue = createStreamEventQueue();
1165
- const streamer = createStreamingTextStreamer(generateText.tokenizer, queue, TextStreamer);
1111
+ const streamer = createStreamingTextStreamer(generateText.tokenizer, queue, TextStreamer, signal);
1166
1112
  let fullText = "";
1167
1113
  const originalPush = queue.push;
1168
1114
  queue.push = (event) => {
@@ -1198,22 +1144,12 @@ var HFT_StructuredGeneration_Stream = async function* (input, model, signal) {
1198
1144
  // src/provider-hf-transformers/common/HFT_TextClassification.ts
1199
1145
  init_HFT_Pipeline();
1200
1146
  var HFT_TextClassification = async (input, model, onProgress, signal) => {
1201
- const isArrayInput = Array.isArray(input.text);
1202
1147
  if (model?.provider_config?.pipeline === "zero-shot-classification") {
1203
1148
  if (!input.candidateLabels || !Array.isArray(input.candidateLabels) || input.candidateLabels.length === 0) {
1204
1149
  throw new Error("Zero-shot text classification requires candidate labels");
1205
1150
  }
1206
1151
  const zeroShotClassifier = await getPipeline(model, onProgress, {}, signal);
1207
1152
  const result2 = await zeroShotClassifier(input.text, input.candidateLabels, {});
1208
- if (isArrayInput) {
1209
- const results = Array.isArray(result2) && Array.isArray(result2[0]?.labels) ? result2 : [result2];
1210
- return {
1211
- categories: results.map((r) => r.labels.map((label, idx) => ({
1212
- label,
1213
- score: r.scores[idx]
1214
- })))
1215
- };
1216
- }
1217
1153
  return {
1218
1154
  categories: result2.labels.map((label, idx) => ({
1219
1155
  label,
@@ -1225,27 +1161,9 @@ var HFT_TextClassification = async (input, model, onProgress, signal) => {
1225
1161
  const result = await TextClassification(input.text, {
1226
1162
  top_k: input.maxCategories || undefined
1227
1163
  });
1228
- if (isArrayInput) {
1229
- return {
1230
- categories: result.map((perInput) => {
1231
- const items = Array.isArray(perInput) ? perInput : [perInput];
1232
- return items.map((category) => ({
1233
- label: category.label,
1234
- score: category.score
1235
- }));
1236
- })
1237
- };
1238
- }
1239
- if (Array.isArray(result[0])) {
1240
- return {
1241
- categories: result[0].map((category) => ({
1242
- label: category.label,
1243
- score: category.score
1244
- }))
1245
- };
1246
- }
1164
+ const items = Array.isArray(result[0]) ? result[0] : result;
1247
1165
  return {
1248
- categories: result.map((category) => ({
1166
+ categories: items.map((category) => ({
1249
1167
  label: category.label,
1250
1168
  score: category.score
1251
1169
  }))
@@ -1295,21 +1213,8 @@ var HFT_TextEmbedding = async (input, model, onProgress, signal) => {
1295
1213
  // src/provider-hf-transformers/common/HFT_TextFillMask.ts
1296
1214
  init_HFT_Pipeline();
1297
1215
  var HFT_TextFillMask = async (input, model, onProgress, signal) => {
1298
- const isArrayInput = Array.isArray(input.text);
1299
1216
  const unmasker = await getPipeline(model, onProgress, {}, signal);
1300
1217
  const results = await unmasker(input.text);
1301
- if (isArrayInput) {
1302
- return {
1303
- predictions: results.map((perInput) => {
1304
- const items = Array.isArray(perInput) ? perInput : [perInput];
1305
- return items.map((prediction) => ({
1306
- entity: prediction.token_str,
1307
- score: prediction.score,
1308
- sequence: prediction.sequence
1309
- }));
1310
- })
1311
- };
1312
- }
1313
1218
  let predictions = [];
1314
1219
  if (!Array.isArray(results)) {
1315
1220
  predictions = [results];
@@ -1332,26 +1237,16 @@ var HFT_TextGeneration = async (input, model, onProgress, signal) => {
1332
1237
  const logger = getLogger6();
1333
1238
  const timerLabel = `hft:TextGeneration:${model?.provider_config.model_path}`;
1334
1239
  logger.time(timerLabel, { model: model?.provider_config.model_path });
1335
- const isArrayInput = Array.isArray(input.prompt);
1336
1240
  const generateText = await getPipeline(model, onProgress, {}, signal);
1337
1241
  const { TextStreamer } = await loadTransformersSDK();
1338
1242
  logger.debug("HFT TextGeneration: pipeline ready, generating text", {
1339
1243
  model: model?.provider_config.model_path,
1340
- promptLength: isArrayInput ? input.prompt.length : input.prompt?.length
1244
+ promptLength: input.prompt?.length
1341
1245
  });
1342
- const streamer = isArrayInput ? undefined : createTextStreamer(generateText.tokenizer, onProgress, TextStreamer);
1246
+ const streamer = createTextStreamer(generateText.tokenizer, onProgress, TextStreamer, signal);
1343
1247
  let results = await generateText(input.prompt, {
1344
- ...streamer ? { streamer } : {}
1248
+ streamer
1345
1249
  });
1346
- if (isArrayInput) {
1347
- const batchResults = Array.isArray(results) ? results : [results];
1348
- const texts = batchResults.map((r) => {
1349
- const seqs = Array.isArray(r) ? r : [r];
1350
- return extractGeneratedText(seqs[0]?.generated_text);
1351
- });
1352
- logger.timeEnd(timerLabel, { batchSize: texts.length });
1353
- return { text: texts };
1354
- }
1355
1250
  if (!Array.isArray(results)) {
1356
1251
  results = [results];
1357
1252
  }
@@ -1366,7 +1261,7 @@ var HFT_TextGeneration_Stream = async function* (input, model, signal) {
1366
1261
  const generateText = await getPipeline(model, noopProgress, {}, signal);
1367
1262
  const { TextStreamer } = await loadTransformersSDK();
1368
1263
  const queue = createStreamEventQueue();
1369
- const streamer = createStreamingTextStreamer(generateText.tokenizer, queue, TextStreamer);
1264
+ const streamer = createStreamingTextStreamer(generateText.tokenizer, queue, TextStreamer, signal);
1370
1265
  const pipelinePromise = generateText(input.prompt, {
1371
1266
  streamer
1372
1267
  }).then(() => queue.done(), (err) => queue.error(err));
@@ -1378,22 +1273,10 @@ var HFT_TextGeneration_Stream = async function* (input, model, signal) {
1378
1273
  // src/provider-hf-transformers/common/HFT_TextLanguageDetection.ts
1379
1274
  init_HFT_Pipeline();
1380
1275
  var HFT_TextLanguageDetection = async (input, model, onProgress, signal) => {
1381
- const isArrayInput = Array.isArray(input.text);
1382
1276
  const TextClassification = await getPipeline(model, onProgress, {}, signal);
1383
1277
  const result = await TextClassification(input.text, {
1384
1278
  top_k: input.maxLanguages || undefined
1385
1279
  });
1386
- if (isArrayInput) {
1387
- return {
1388
- languages: result.map((perInput) => {
1389
- const items = Array.isArray(perInput) ? perInput : [perInput];
1390
- return items.map((category) => ({
1391
- language: category.label,
1392
- score: category.score
1393
- }));
1394
- })
1395
- };
1396
- }
1397
1280
  if (Array.isArray(result[0])) {
1398
1281
  return {
1399
1282
  languages: result[0].map((category) => ({
@@ -1413,23 +1296,10 @@ var HFT_TextLanguageDetection = async (input, model, onProgress, signal) => {
1413
1296
  // src/provider-hf-transformers/common/HFT_TextNamedEntityRecognition.ts
1414
1297
  init_HFT_Pipeline();
1415
1298
  var HFT_TextNamedEntityRecognition = async (input, model, onProgress, signal) => {
1416
- const isArrayInput = Array.isArray(input.text);
1417
1299
  const textNamedEntityRecognition = await getPipeline(model, onProgress, {}, signal);
1418
1300
  const results = await textNamedEntityRecognition(input.text, {
1419
1301
  ignore_labels: input.blockList
1420
1302
  });
1421
- if (isArrayInput) {
1422
- return {
1423
- entities: results.map((perInput) => {
1424
- const items = Array.isArray(perInput) ? perInput : [perInput];
1425
- return items.map((entity) => ({
1426
- entity: entity.entity,
1427
- score: entity.score,
1428
- word: entity.word
1429
- }));
1430
- })
1431
- };
1432
- }
1433
1303
  let entities = [];
1434
1304
  if (!Array.isArray(results)) {
1435
1305
  entities = [results];
@@ -1448,29 +1318,9 @@ var HFT_TextNamedEntityRecognition = async (input, model, onProgress, signal) =>
1448
1318
  // src/provider-hf-transformers/common/HFT_TextQuestionAnswer.ts
1449
1319
  init_HFT_Pipeline();
1450
1320
  var HFT_TextQuestionAnswer = async (input, model, onProgress, signal) => {
1451
- const isArrayInput = Array.isArray(input.question);
1452
1321
  const generateAnswer = await getPipeline(model, onProgress, {}, signal);
1453
- if (isArrayInput) {
1454
- const questions = input.question;
1455
- const contexts = input.context;
1456
- if (questions.length !== contexts.length) {
1457
- throw new Error(`question[] and context[] must have the same length: ${questions.length} != ${contexts.length}`);
1458
- }
1459
- const answers = [];
1460
- for (let i = 0;i < questions.length; i++) {
1461
- const result2 = await generateAnswer(questions[i], contexts[i], {});
1462
- let answerText2 = "";
1463
- if (Array.isArray(result2)) {
1464
- answerText2 = result2[0]?.answer || "";
1465
- } else {
1466
- answerText2 = result2?.answer || "";
1467
- }
1468
- answers.push(answerText2);
1469
- }
1470
- return { text: answers };
1471
- }
1472
1322
  const { TextStreamer } = await loadTransformersSDK();
1473
- const streamer = createTextStreamer(generateAnswer.tokenizer, onProgress, TextStreamer);
1323
+ const streamer = createTextStreamer(generateAnswer.tokenizer, onProgress, TextStreamer, signal);
1474
1324
  const result = await generateAnswer(input.question, input.context, {
1475
1325
  streamer
1476
1326
  });
@@ -1489,7 +1339,7 @@ var HFT_TextQuestionAnswer_Stream = async function* (input, model, signal) {
1489
1339
  const generateAnswer = await getPipeline(model, noopProgress, {}, signal);
1490
1340
  const { TextStreamer } = await loadTransformersSDK();
1491
1341
  const queue = createStreamEventQueue();
1492
- const streamer = createStreamingTextStreamer(generateAnswer.tokenizer, queue, TextStreamer);
1342
+ const streamer = createStreamingTextStreamer(generateAnswer.tokenizer, queue, TextStreamer, signal);
1493
1343
  let pipelineResult;
1494
1344
  const pipelinePromise = generateAnswer(input.question, input.context, {
1495
1345
  streamer
@@ -1513,30 +1363,13 @@ var HFT_TextQuestionAnswer_Stream = async function* (input, model, signal) {
1513
1363
  // src/provider-hf-transformers/common/HFT_TextRewriter.ts
1514
1364
  init_HFT_Pipeline();
1515
1365
  var HFT_TextRewriter = async (input, model, onProgress, signal) => {
1516
- const isArrayInput = Array.isArray(input.text);
1517
1366
  const generateText = await getPipeline(model, onProgress, {}, signal);
1518
1367
  const { TextStreamer } = await loadTransformersSDK();
1519
- const streamer = isArrayInput ? undefined : createTextStreamer(generateText.tokenizer, onProgress, TextStreamer);
1520
- if (isArrayInput) {
1521
- const texts = input.text;
1522
- const promptedTexts = texts.map((t) => (input.prompt ? input.prompt + `
1523
- ` : "") + t);
1524
- let results2 = await generateText(promptedTexts, {});
1525
- const batchResults = Array.isArray(results2) ? results2 : [results2];
1526
- const outputTexts = batchResults.map((r, i) => {
1527
- const seqs = Array.isArray(r) ? r : [r];
1528
- const text2 = extractGeneratedText(seqs[0]?.generated_text);
1529
- if (text2 === promptedTexts[i]) {
1530
- throw new Error("Rewriter failed to generate new text");
1531
- }
1532
- return text2;
1533
- });
1534
- return { text: outputTexts };
1535
- }
1368
+ const streamer = createTextStreamer(generateText.tokenizer, onProgress, TextStreamer, signal);
1536
1369
  const promptedText = (input.prompt ? input.prompt + `
1537
1370
  ` : "") + input.text;
1538
1371
  let results = await generateText(promptedText, {
1539
- ...streamer ? { streamer } : {}
1372
+ streamer
1540
1373
  });
1541
1374
  if (!Array.isArray(results)) {
1542
1375
  results = [results];
@@ -1554,7 +1387,7 @@ var HFT_TextRewriter_Stream = async function* (input, model, signal) {
1554
1387
  const generateText = await getPipeline(model, noopProgress, {}, signal);
1555
1388
  const { TextStreamer } = await loadTransformersSDK();
1556
1389
  const queue = createStreamEventQueue();
1557
- const streamer = createStreamingTextStreamer(generateText.tokenizer, queue, TextStreamer);
1390
+ const streamer = createStreamingTextStreamer(generateText.tokenizer, queue, TextStreamer, signal);
1558
1391
  const promptedText = (input.prompt ? input.prompt + `
1559
1392
  ` : "") + input.text;
1560
1393
  const pipelinePromise = generateText(promptedText, {
@@ -1568,19 +1401,12 @@ var HFT_TextRewriter_Stream = async function* (input, model, signal) {
1568
1401
  // src/provider-hf-transformers/common/HFT_TextSummary.ts
1569
1402
  init_HFT_Pipeline();
1570
1403
  var HFT_TextSummary = async (input, model, onProgress, signal) => {
1571
- const isArrayInput = Array.isArray(input.text);
1572
1404
  const generateSummary = await getPipeline(model, onProgress, {}, signal);
1573
1405
  const { TextStreamer } = await loadTransformersSDK();
1574
- const streamer = isArrayInput ? undefined : createTextStreamer(generateSummary.tokenizer, onProgress, TextStreamer);
1406
+ const streamer = createTextStreamer(generateSummary.tokenizer, onProgress, TextStreamer, signal);
1575
1407
  const result = await generateSummary(input.text, {
1576
- ...streamer ? { streamer } : {}
1408
+ streamer
1577
1409
  });
1578
- if (isArrayInput) {
1579
- const batchResults = Array.isArray(result) ? result : [result];
1580
- return {
1581
- text: batchResults.map((r) => r?.summary_text || "")
1582
- };
1583
- }
1584
1410
  let summaryText = "";
1585
1411
  if (Array.isArray(result)) {
1586
1412
  summaryText = result[0]?.summary_text || "";
@@ -1596,7 +1422,7 @@ var HFT_TextSummary_Stream = async function* (input, model, signal) {
1596
1422
  const generateSummary = await getPipeline(model, noopProgress, {}, signal);
1597
1423
  const { TextStreamer } = await loadTransformersSDK();
1598
1424
  const queue = createStreamEventQueue();
1599
- const streamer = createStreamingTextStreamer(generateSummary.tokenizer, queue, TextStreamer);
1425
+ const streamer = createStreamingTextStreamer(generateSummary.tokenizer, queue, TextStreamer, signal);
1600
1426
  const pipelinePromise = generateSummary(input.text, {
1601
1427
  streamer
1602
1428
  }).then(() => queue.done(), (err) => queue.error(err));
@@ -1608,22 +1434,14 @@ var HFT_TextSummary_Stream = async function* (input, model, signal) {
1608
1434
  // src/provider-hf-transformers/common/HFT_TextTranslation.ts
1609
1435
  init_HFT_Pipeline();
1610
1436
  var HFT_TextTranslation = async (input, model, onProgress, signal) => {
1611
- const isArrayInput = Array.isArray(input.text);
1612
1437
  const translate = await getPipeline(model, onProgress, {}, signal);
1613
1438
  const { TextStreamer } = await loadTransformersSDK();
1614
- const streamer = isArrayInput ? undefined : createTextStreamer(translate.tokenizer, onProgress, TextStreamer);
1439
+ const streamer = createTextStreamer(translate.tokenizer, onProgress, TextStreamer, signal);
1615
1440
  const result = await translate(input.text, {
1616
1441
  src_lang: input.source_lang,
1617
1442
  tgt_lang: input.target_lang,
1618
- ...streamer ? { streamer } : {}
1443
+ streamer
1619
1444
  });
1620
- if (isArrayInput) {
1621
- const batchResults = Array.isArray(result) ? result : [result];
1622
- return {
1623
- text: batchResults.map((r) => r?.translation_text || ""),
1624
- target_lang: input.target_lang
1625
- };
1626
- }
1627
1445
  const translatedText = Array.isArray(result) ? result[0]?.translation_text || "" : result?.translation_text || "";
1628
1446
  return {
1629
1447
  text: translatedText,
@@ -1635,7 +1453,7 @@ var HFT_TextTranslation_Stream = async function* (input, model, signal) {
1635
1453
  const translate = await getPipeline(model, noopProgress, {}, signal);
1636
1454
  const { TextStreamer } = await loadTransformersSDK();
1637
1455
  const queue = createStreamEventQueue();
1638
- const streamer = createStreamingTextStreamer(translate.tokenizer, queue, TextStreamer);
1456
+ const streamer = createStreamingTextStreamer(translate.tokenizer, queue, TextStreamer, signal);
1639
1457
  const pipelinePromise = translate(input.text, {
1640
1458
  src_lang: input.source_lang,
1641
1459
  tgt_lang: input.target_lang,
@@ -1646,162 +1464,6 @@ var HFT_TextTranslation_Stream = async function* (input, model, signal) {
1646
1464
  yield { type: "finish", data: { target_lang: input.target_lang } };
1647
1465
  };
1648
1466
 
1649
- // src/provider-hf-transformers/common/HFT_ToolCalling.ts
1650
- init_HFT_Pipeline();
1651
- import {
1652
- buildToolDescription,
1653
- filterValidToolCalls,
1654
- toTextFlatMessages
1655
- } from "@workglow/ai/worker";
1656
- function mapHFTTools(tools) {
1657
- return tools.map((t) => ({
1658
- type: "function",
1659
- function: {
1660
- name: t.name,
1661
- description: buildToolDescription(t),
1662
- parameters: t.inputSchema
1663
- }
1664
- }));
1665
- }
1666
- function resolveHFTToolsAndMessages(input, messages) {
1667
- if (input.toolChoice === "none") {
1668
- return;
1669
- }
1670
- if (input.toolChoice === "required") {
1671
- const requiredInstruction = "You must call at least one tool from the provided tool list when answering.";
1672
- if (messages.length > 0 && messages[0].role === "system") {
1673
- messages[0] = { ...messages[0], content: `${messages[0].content}
1674
-
1675
- ${requiredInstruction}` };
1676
- } else {
1677
- messages.unshift({ role: "system", content: requiredInstruction });
1678
- }
1679
- return mapHFTTools(input.tools);
1680
- }
1681
- if (typeof input.toolChoice === "string" && input.toolChoice !== "auto") {
1682
- const selectedTools = input.tools?.filter((tool) => tool.name === input.toolChoice);
1683
- const toolsToMap = selectedTools && selectedTools.length > 0 ? selectedTools : input.tools;
1684
- return mapHFTTools(toolsToMap);
1685
- }
1686
- return mapHFTTools(input.tools);
1687
- }
1688
- var HFT_ToolCalling = async (input, model, onProgress, signal) => {
1689
- const isArrayInput = Array.isArray(input.prompt);
1690
- const generateText = await getPipeline(model, onProgress, {}, signal);
1691
- const { TextStreamer } = await loadTransformersSDK();
1692
- if (isArrayInput) {
1693
- const prompts = input.prompt;
1694
- const texts = [];
1695
- const toolCallsList = [];
1696
- for (const singlePrompt of prompts) {
1697
- const singleInput = { ...input, prompt: singlePrompt };
1698
- const messages2 = toTextFlatMessages(singleInput);
1699
- const tools2 = resolveHFTToolsAndMessages(singleInput, messages2);
1700
- const prompt2 = generateText.tokenizer.apply_chat_template(messages2, {
1701
- tools: tools2,
1702
- tokenize: false,
1703
- add_generation_prompt: true
1704
- });
1705
- const streamer2 = createTextStreamer(generateText.tokenizer, onProgress, TextStreamer);
1706
- let results2 = await generateText(prompt2, {
1707
- max_new_tokens: input.maxTokens ?? 1024,
1708
- temperature: input.temperature ?? undefined,
1709
- return_full_text: false,
1710
- streamer: streamer2
1711
- });
1712
- if (!Array.isArray(results2)) {
1713
- results2 = [results2];
1714
- }
1715
- const responseText2 = extractGeneratedText(results2[0]?.generated_text).trim();
1716
- const { text: text2, toolCalls: toolCalls2 } = parseToolCallsFromText(responseText2);
1717
- texts.push(text2);
1718
- toolCallsList.push(filterValidToolCalls(toolCalls2, singleInput.tools));
1719
- }
1720
- return { text: texts, toolCalls: toolCallsList };
1721
- }
1722
- const messages = toTextFlatMessages(input);
1723
- const tools = resolveHFTToolsAndMessages(input, messages);
1724
- const prompt = generateText.tokenizer.apply_chat_template(messages, {
1725
- tools,
1726
- tokenize: false,
1727
- add_generation_prompt: true
1728
- });
1729
- const streamer = createTextStreamer(generateText.tokenizer, onProgress, TextStreamer);
1730
- let results = await generateText(prompt, {
1731
- max_new_tokens: input.maxTokens ?? 1024,
1732
- temperature: input.temperature ?? undefined,
1733
- return_full_text: false,
1734
- streamer
1735
- });
1736
- if (!Array.isArray(results)) {
1737
- results = [results];
1738
- }
1739
- const responseText = extractGeneratedText(results[0]?.generated_text).trim();
1740
- const { text, toolCalls } = parseToolCallsFromText(responseText);
1741
- return {
1742
- text,
1743
- toolCalls: filterValidToolCalls(toolCalls, input.tools)
1744
- };
1745
- };
1746
- var HFT_ToolCalling_Stream = async function* (input, model, signal) {
1747
- const noopProgress = () => {};
1748
- const generateText = await getPipeline(model, noopProgress, {}, signal);
1749
- const { TextStreamer } = await loadTransformersSDK();
1750
- const messages = toTextFlatMessages(input);
1751
- const tools = resolveHFTToolsAndMessages(input, messages);
1752
- const prompt = generateText.tokenizer.apply_chat_template(messages, {
1753
- tools,
1754
- tokenize: false,
1755
- add_generation_prompt: true
1756
- });
1757
- const innerQueue = createStreamEventQueue();
1758
- const outerQueue = createStreamEventQueue();
1759
- const streamer = createStreamingTextStreamer(generateText.tokenizer, innerQueue, TextStreamer);
1760
- let fullText = "";
1761
- const filter = createToolCallMarkupFilter((text) => {
1762
- outerQueue.push({ type: "text-delta", port: "text", textDelta: text });
1763
- });
1764
- const originalPush = innerQueue.push;
1765
- innerQueue.push = (event) => {
1766
- if (event.type === "text-delta" && "textDelta" in event) {
1767
- fullText += event.textDelta;
1768
- filter.feed(event.textDelta);
1769
- } else {
1770
- outerQueue.push(event);
1771
- }
1772
- originalPush(event);
1773
- };
1774
- const originalDone = innerQueue.done;
1775
- innerQueue.done = () => {
1776
- filter.flush();
1777
- outerQueue.done();
1778
- originalDone();
1779
- };
1780
- const originalError = innerQueue.error;
1781
- innerQueue.error = (e) => {
1782
- filter.flush();
1783
- outerQueue.error(e);
1784
- originalError(e);
1785
- };
1786
- const pipelinePromise = generateText(prompt, {
1787
- max_new_tokens: input.maxTokens ?? 1024,
1788
- temperature: input.temperature ?? undefined,
1789
- return_full_text: false,
1790
- streamer
1791
- }).then(() => innerQueue.done(), (err) => innerQueue.error(err));
1792
- yield* outerQueue.iterable;
1793
- await pipelinePromise;
1794
- const { text: cleanedText, toolCalls } = parseToolCallsFromText(fullText);
1795
- const validToolCalls = filterValidToolCalls(toolCalls, input.tools);
1796
- if (validToolCalls.length > 0) {
1797
- yield { type: "object-delta", port: "toolCalls", objectDelta: [...validToolCalls] };
1798
- }
1799
- yield {
1800
- type: "finish",
1801
- data: { text: cleanedText, toolCalls: validToolCalls }
1802
- };
1803
- };
1804
-
1805
1467
  // src/provider-hf-transformers/common/HFT_Unload.ts
1806
1468
  init_HFT_Pipeline();
1807
1469
  function hasBrowserCacheStorage() {
@@ -1877,7 +1539,6 @@ var HFT_TASKS = {
1877
1539
  ImageEmbeddingTask: HFT_ImageEmbedding,
1878
1540
  ImageClassificationTask: HFT_ImageClassification,
1879
1541
  ObjectDetectionTask: HFT_ObjectDetection,
1880
- ToolCallingTask: HFT_ToolCalling,
1881
1542
  StructuredGenerationTask: HFT_StructuredGeneration,
1882
1543
  ModelSearchTask: HFT_ModelSearch
1883
1544
  };
@@ -1887,7 +1548,6 @@ var HFT_STREAM_TASKS = {
1887
1548
  TextSummaryTask: HFT_TextSummary_Stream,
1888
1549
  TextQuestionAnswerTask: HFT_TextQuestionAnswer_Stream,
1889
1550
  TextTranslationTask: HFT_TextTranslation_Stream,
1890
- ToolCallingTask: HFT_ToolCalling_Stream,
1891
1551
  StructuredGenerationTask: HFT_StructuredGeneration_Stream
1892
1552
  };
1893
1553
  var HFT_REACTIVE_TASKS = {
@@ -1898,11 +1558,34 @@ var HFT_REACTIVE_TASKS = {
1898
1558
  import {
1899
1559
  QueuedAiProvider
1900
1560
  } from "@workglow/ai";
1561
+ var GPU_DEVICES = new Set(["webgpu", "gpu", "metal"]);
1562
+ var HFT_CPU_QUEUE_CONCURRENCY_PRODUCTION = 4;
1563
+ function hftIsAutomatedTestEnvironment() {
1564
+ if (typeof process === "undefined") {
1565
+ return false;
1566
+ }
1567
+ const e = process.env;
1568
+ return e.VITEST === "true" || e.NODE_ENV === "test" || e.BUN_TEST === "1" || e.JEST_WORKER_ID !== undefined;
1569
+ }
1570
+ function hftDefaultCpuQueueConcurrency() {
1571
+ return hftIsAutomatedTestEnvironment() ? 1 : HFT_CPU_QUEUE_CONCURRENCY_PRODUCTION;
1572
+ }
1573
+ function resolveHftCpuQueueConcurrency(concurrency, defaultCpu) {
1574
+ if (concurrency === undefined) {
1575
+ return defaultCpu();
1576
+ }
1577
+ if (typeof concurrency === "number") {
1578
+ return defaultCpu();
1579
+ }
1580
+ return concurrency.cpu ?? defaultCpu();
1581
+ }
1582
+
1901
1583
  class HuggingFaceTransformersQueuedProvider extends QueuedAiProvider {
1902
1584
  name = HF_TRANSFORMERS_ONNX;
1903
1585
  displayName = "Hugging Face Transformers (ONNX)";
1904
1586
  isLocal = true;
1905
1587
  supportsBrowser = true;
1588
+ cpuStrategy;
1906
1589
  taskTypes = [
1907
1590
  "DownloadModelTask",
1908
1591
  "UnloadModelTask",
@@ -1924,12 +1607,22 @@ class HuggingFaceTransformersQueuedProvider extends QueuedAiProvider {
1924
1607
  "ImageEmbeddingTask",
1925
1608
  "ImageClassificationTask",
1926
1609
  "ObjectDetectionTask",
1927
- "ToolCallingTask",
1928
1610
  "ModelSearchTask"
1929
1611
  ];
1930
1612
  constructor(tasks, streamTasks, reactiveTasks) {
1931
1613
  super(tasks, streamTasks, reactiveTasks);
1932
1614
  }
1615
+ async afterRegister(options) {
1616
+ await super.afterRegister(options);
1617
+ this.cpuStrategy = this.createQueuedStrategy(HF_TRANSFORMERS_ONNX_CPU, resolveHftCpuQueueConcurrency(options.queue?.concurrency, hftDefaultCpuQueueConcurrency), options);
1618
+ }
1619
+ getStrategyForModel(model) {
1620
+ const device = model.provider_config?.device;
1621
+ if (device && GPU_DEVICES.has(device)) {
1622
+ return this.queuedStrategy;
1623
+ }
1624
+ return this.cpuStrategy;
1625
+ }
1933
1626
  }
1934
1627
 
1935
1628
  // src/provider-hf-transformers/registerHuggingFaceTransformersInline.ts
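Note: the queued provider now routes work onto two queues: models whose device is in GPU_DEVICES stay on the default queued strategy, while everything else goes to the HF_TRANSFORMERS_ONNX_CPU queue, whose concurrency defaults to 4 in production and 1 under test runners (or to the cpu field of a structured concurrency option). A rough sketch of the selection, mirroring getStrategyForModel above (queue labels are descriptive, not package API):

    // Sketch of the routing decision; labels are descriptive only.
    const GPU_DEVICES = new Set(["webgpu", "gpu", "metal"]);

    function pickQueue(device?: string): "gpu-default" | "cpu" {
      return device && GPU_DEVICES.has(device) ? "gpu-default" : "cpu";
    }

    pickQueue("webgpu");   // "gpu-default" - the provider's default queued strategy
    pickQueue("wasm");     // "cpu" - HF_TRANSFORMERS_ONNX_CPU queue, concurrency 4 (1 in tests)
    pickQueue(undefined);  // "cpu"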
@@ -1978,7 +1671,6 @@ class HuggingFaceTransformersProvider extends AiProvider {
1978
1671
  "ImageEmbeddingTask",
1979
1672
  "ImageClassificationTask",
1980
1673
  "ObjectDetectionTask",
1981
- "ToolCallingTask",
1982
1674
  "ModelSearchTask"
1983
1675
  ];
1984
1676
  constructor(tasks, streamTasks, reactiveTasks) {
@@ -1989,7 +1681,9 @@ class HuggingFaceTransformersProvider extends AiProvider {
1989
1681
  // src/provider-hf-transformers/registerHuggingFaceTransformersWorker.ts
1990
1682
  init_HFT_Pipeline();
1991
1683
  async function registerHuggingFaceTransformersWorker() {
1992
- const { env } = await loadTransformersSDK();
1684
+ const sdk = await loadTransformersSDK();
1685
+ globalThis.__HFT__ = sdk;
1686
+ const { env } = sdk;
1993
1687
  env.backends.onnx.wasm.proxy = true;
1994
1688
  const workerServer = globalServiceRegistry.get(WORKER_SERVER);
1995
1689
  new HuggingFaceTransformersProvider(HFT_TASKS, HFT_STREAM_TASKS, HFT_REACTIVE_TASKS).registerOnWorkerServer(workerServer);
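Note: worker registration now stashes the loaded SDK on globalThis.__HFT__, so the transformers.js module can be inspected from inside the worker after startup. For example (assuming the worker has finished registering):

    // Run inside the worker once registerHuggingFaceTransformersWorker() has completed.
    const sdk = (globalThis as any).__HFT__;
    console.log(Boolean(sdk?.env), Boolean(sdk?.AutoTokenizer)); // true, true once the SDK is loaded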
@@ -2004,13 +1698,11 @@ export {
2004
1698
  removeCachedPipeline,
2005
1699
  registerHuggingFaceTransformersWorker,
2006
1700
  registerHuggingFaceTransformersInline,
2007
- parseToolCallsFromText,
2008
1701
  parseOnnxQuantizations,
2009
1702
  loadTransformersSDK,
2010
1703
  hasCachedPipeline,
2011
1704
  getPipelineCacheKey,
2012
1705
  getPipeline,
2013
- createToolCallMarkupFilter,
2014
1706
  clearPipelineCache,
2015
1707
  QuantizationDataType,
2016
1708
  PipelineUseCase,
@@ -2019,7 +1711,10 @@ export {
2019
1711
  HfTransformersOnnxModelRecordSchema,
2020
1712
  HfTransformersOnnxModelConfigSchema,
2021
1713
  HTF_CACHE_NAME,
2022
- HF_TRANSFORMERS_ONNX
1714
+ HF_TRANSFORMERS_ONNX_GPU,
1715
+ HF_TRANSFORMERS_ONNX_CPU,
1716
+ HF_TRANSFORMERS_ONNX,
1717
+ HFT_NULL_PROCESSOR_PREFIX
2023
1718
  };
2024
1719
 
2025
- //# debugId=6F10F5E049CF8D0264756E2164756E21
1720
+ //# debugId=E4863FB6D65AEC2D64756E2164756E21