@workglow/ai-provider 0.0.108 → 0.0.110

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78) hide show
  1. package/README.md +1 -1
  2. package/dist/{HFT_JobRunFns-c40ex37f.js → HFT_JobRunFns-0bwh5cmx.js} +4 -2
  3. package/dist/{HFT_JobRunFns-c40ex37f.js.map → HFT_JobRunFns-0bwh5cmx.js.map} +1 -1
  4. package/dist/anthropic/AnthropicProvider.d.ts +3 -1
  5. package/dist/anthropic/AnthropicProvider.d.ts.map +1 -1
  6. package/dist/anthropic/common/Anthropic_JobRunFns.d.ts +2 -1
  7. package/dist/anthropic/common/Anthropic_JobRunFns.d.ts.map +1 -1
  8. package/dist/anthropic/index.js +68 -2
  9. package/dist/anthropic/index.js.map +3 -3
  10. package/dist/google-gemini/GoogleGeminiProvider.d.ts +3 -1
  11. package/dist/google-gemini/GoogleGeminiProvider.d.ts.map +1 -1
  12. package/dist/google-gemini/common/Gemini_JobRunFns.d.ts +2 -1
  13. package/dist/google-gemini/common/Gemini_JobRunFns.d.ts.map +1 -1
  14. package/dist/google-gemini/index.js +68 -2
  15. package/dist/google-gemini/index.js.map +3 -3
  16. package/dist/hf-transformers/HuggingFaceTransformersProvider.d.ts +5 -3
  17. package/dist/hf-transformers/HuggingFaceTransformersProvider.d.ts.map +1 -1
  18. package/dist/hf-transformers/common/HFT_JobRunFns.d.ts +121 -25
  19. package/dist/hf-transformers/common/HFT_JobRunFns.d.ts.map +1 -1
  20. package/dist/hf-transformers/index.js +5 -3
  21. package/dist/hf-transformers/index.js.map +2 -2
  22. package/dist/{index-5qjdc78z.js → index-60ev6k93.js} +4 -1
  23. package/dist/{index-5qjdc78z.js.map → index-60ev6k93.js.map} +3 -3
  24. package/dist/{index-5hjgs1z7.js → index-8651nz8y.js} +4 -1
  25. package/dist/{index-5hjgs1z7.js.map → index-8651nz8y.js.map} +3 -3
  26. package/dist/{index-4fr8p4gy.js → index-dmrxc6ek.js} +302 -175
  27. package/dist/index-dmrxc6ek.js.map +10 -0
  28. package/dist/{index-drcnh4z5.js → index-q2t627d5.js} +4 -1
  29. package/dist/{index-drcnh4z5.js.map → index-q2t627d5.js.map} +3 -3
  30. package/dist/{index-14pbwsc9.js → index-tp5s7355.js} +4 -1
  31. package/dist/{index-14pbwsc9.js.map → index-tp5s7355.js.map} +3 -3
  32. package/dist/{index-aef54vq3.js → index-v72vr07f.js} +4 -1
  33. package/dist/{index-aef54vq3.js.map → index-v72vr07f.js.map} +3 -3
  34. package/dist/{index-cejxxqcz.js → index-weaycaap.js} +6 -3
  35. package/dist/index-weaycaap.js.map +10 -0
  36. package/dist/{index-xc6m9mcp.js → index-wr57rwyx.js} +4 -1
  37. package/dist/{index-xc6m9mcp.js.map → index-wr57rwyx.js.map} +3 -3
  38. package/dist/index.js +11 -8
  39. package/dist/index.js.map +3 -3
  40. package/dist/provider-hf-inference/HfInferenceProvider.d.ts +3 -1
  41. package/dist/provider-hf-inference/HfInferenceProvider.d.ts.map +1 -1
  42. package/dist/provider-hf-inference/common/HFI_JobRunFns.d.ts +2 -1
  43. package/dist/provider-hf-inference/common/HFI_JobRunFns.d.ts.map +1 -1
  44. package/dist/provider-hf-inference/index.js +59 -3
  45. package/dist/provider-hf-inference/index.js.map +3 -3
  46. package/dist/provider-llamacpp/LlamaCppProvider.d.ts +3 -1
  47. package/dist/provider-llamacpp/LlamaCppProvider.d.ts.map +1 -1
  48. package/dist/provider-llamacpp/common/LlamaCpp_JobRunFns.d.ts +2 -1
  49. package/dist/provider-llamacpp/common/LlamaCpp_JobRunFns.d.ts.map +1 -1
  50. package/dist/provider-llamacpp/index.js +87 -4
  51. package/dist/provider-llamacpp/index.js.map +3 -3
  52. package/dist/provider-ollama/OllamaProvider.d.ts +3 -1
  53. package/dist/provider-ollama/OllamaProvider.d.ts.map +1 -1
  54. package/dist/provider-ollama/common/Ollama_JobRunFns.browser.d.ts +2 -1
  55. package/dist/provider-ollama/common/Ollama_JobRunFns.browser.d.ts.map +1 -1
  56. package/dist/provider-ollama/common/Ollama_JobRunFns.d.ts +2 -1
  57. package/dist/provider-ollama/common/Ollama_JobRunFns.d.ts.map +1 -1
  58. package/dist/provider-ollama/index.browser.js +78 -2
  59. package/dist/provider-ollama/index.browser.js.map +4 -4
  60. package/dist/provider-ollama/index.js +78 -5
  61. package/dist/provider-ollama/index.js.map +3 -3
  62. package/dist/provider-openai/OpenAiProvider.d.ts +3 -1
  63. package/dist/provider-openai/OpenAiProvider.d.ts.map +1 -1
  64. package/dist/provider-openai/common/OpenAI_JobRunFns.d.ts +2 -1
  65. package/dist/provider-openai/common/OpenAI_JobRunFns.d.ts.map +1 -1
  66. package/dist/provider-openai/index.js +68 -2
  67. package/dist/provider-openai/index.js.map +3 -3
  68. package/dist/tf-mediapipe/TensorFlowMediaPipeProvider.d.ts +3 -1
  69. package/dist/tf-mediapipe/TensorFlowMediaPipeProvider.d.ts.map +1 -1
  70. package/dist/tf-mediapipe/common/TFMP_JobRunFns.d.ts +67 -5
  71. package/dist/tf-mediapipe/common/TFMP_JobRunFns.d.ts.map +1 -1
  72. package/dist/tf-mediapipe/index.js +20 -1
  73. package/dist/tf-mediapipe/index.js.map +4 -4
  74. package/package.json +20 -19
  75. package/dist/hf-transformers/common/HFT_CallbackStatus.d.ts +0 -36
  76. package/dist/hf-transformers/common/HFT_CallbackStatus.d.ts.map +0 -1
  77. package/dist/index-4fr8p4gy.js.map +0 -10
  78. package/dist/index-cejxxqcz.js.map +0 -10
@@ -12,10 +12,10 @@ var _transformersSdk;
12
12
  async function loadTransformersSDK() {
13
13
  if (!_transformersSdk) {
14
14
  try {
15
- _transformersSdk = await import("@sroussey/transformers");
15
+ _transformersSdk = await import("@huggingface/transformers");
16
16
  _transformersSdk.env.fetch = abortableFetch;
17
17
  } catch {
18
- throw new Error("@sroussey/transformers is required for HuggingFace Transformers tasks. Install it with: bun add @sroussey/transformers");
18
+ throw new Error("@huggingface/transformers is required for HuggingFace Transformers tasks. Install it with: bun add @huggingface/transformers");
19
19
  }
20
20
  }
21
21
  return _transformersSdk;
@@ -64,61 +64,47 @@ var getPipeline = async (model, onProgress, options = {}, signal, progressScaleM
64
64
  return loadPromise;
65
65
  };
66
66
  var doGetPipeline = async (model, onProgress, options, progressScaleMax, cacheKey, signal) => {
67
- const fileSizes = new Map;
68
- const fileProgress = new Map;
69
- const fileCompleted = new Set;
70
- const fileFirstSent = new Set;
71
- const fileLastSent = new Set;
72
- const fileLastEventTime = new Map;
73
- const pendingProgressByFile = new Map;
67
+ let lastProgressTime = 0;
68
+ let pendingProgress = null;
74
69
  let throttleTimer = null;
75
70
  const THROTTLE_MS = 160;
76
- const estimatedTinyFiles = 3;
77
- const estimatedMediumFiles = 1;
78
- const estimatedTinySize = 1024;
79
- const estimatedMediumSize = 20971520;
80
- const estimatedLargeSize = 1073741824;
81
- const baseEstimate = estimatedTinyFiles * estimatedTinySize + estimatedMediumFiles * estimatedMediumSize;
82
- const sendProgress = (overallProgress, file, fileProgressValue, isFirst, isLast) => {
71
+ const sendProgress = (progress, file, fileProgress) => {
83
72
  const now = Date.now();
84
- const lastTime = fileLastEventTime.get(file) || 0;
85
- const timeSinceLastEvent = now - lastTime;
86
- const shouldThrottle = !isFirst && !isLast && timeSinceLastEvent < THROTTLE_MS;
87
- if (shouldThrottle) {
88
- pendingProgressByFile.set(file, {
89
- progress: overallProgress,
90
- file,
91
- fileProgress: fileProgressValue
92
- });
73
+ const timeSinceLastEvent = now - lastProgressTime;
74
+ const isFirst = lastProgressTime === 0;
75
+ const isFinal = progress >= progressScaleMax;
76
+ if (isFirst || isFinal) {
77
+ if (throttleTimer) {
78
+ clearTimeout(throttleTimer);
79
+ throttleTimer = null;
80
+ }
81
+ pendingProgress = null;
82
+ onProgress(Math.round(progress), "Downloading model", { file, progress: fileProgress });
83
+ lastProgressTime = now;
84
+ return;
85
+ }
86
+ if (timeSinceLastEvent < THROTTLE_MS) {
87
+ pendingProgress = { progress, file, fileProgress };
93
88
  if (!throttleTimer) {
94
89
  const timeRemaining = Math.max(1, THROTTLE_MS - timeSinceLastEvent);
95
90
  throttleTimer = setTimeout(() => {
96
- for (const [pendingFile, pending] of pendingProgressByFile.entries()) {
97
- onProgress(Math.round(pending.progress), "Downloading model", {
98
- file: pendingFile,
99
- progress: pending.fileProgress
91
+ throttleTimer = null;
92
+ if (pendingProgress) {
93
+ onProgress(Math.round(pendingProgress.progress), "Downloading model", {
94
+ file: pendingProgress.file,
95
+ progress: pendingProgress.fileProgress
100
96
  });
101
- fileLastEventTime.set(pendingFile, Date.now());
97
+ lastProgressTime = Date.now();
98
+ pendingProgress = null;
102
99
  }
103
- pendingProgressByFile.clear();
104
- throttleTimer = null;
105
100
  }, timeRemaining);
106
101
  }
107
102
  return;
108
103
  }
109
- onProgress(Math.round(overallProgress), "Downloading model", {
110
- file,
111
- progress: fileProgressValue
112
- });
113
- fileLastEventTime.set(file, now);
114
- pendingProgressByFile.delete(file);
115
- if (throttleTimer && pendingProgressByFile.size === 0) {
116
- clearTimeout(throttleTimer);
117
- throttleTimer = null;
118
- }
104
+ onProgress(Math.round(progress), "Downloading model", { file, progress: fileProgress });
105
+ lastProgressTime = now;
106
+ pendingProgress = null;
119
107
  };
120
- let hasSeenSubstantialFile = false;
121
- const substantialFileThreshold = 1048576;
122
108
  const abortSignal = signal;
123
109
  const modelPath = model.provider_config.model_path;
124
110
  const modelController = new AbortController;
@@ -131,116 +117,31 @@ var doGetPipeline = async (model, onProgress, options, progressScaleMax, cacheKe
131
117
  }
132
118
  }
133
119
  const progressCallback = (status) => {
134
- if (abortSignal?.aborted) {
120
+ if (abortSignal?.aborted)
135
121
  return;
136
- }
137
- if (status.status === "progress") {
138
- const file = status.file;
139
- const fileTotal = status.total;
140
- const fileProgressValue = status.progress;
141
- if (!fileSizes.has(file)) {
142
- fileSizes.set(file, fileTotal);
143
- fileProgress.set(file, 0);
144
- if (fileTotal >= substantialFileThreshold) {
145
- hasSeenSubstantialFile = true;
146
- }
147
- }
148
- fileProgress.set(file, fileProgressValue);
149
- const isComplete = fileProgressValue >= 100;
150
- if (isComplete && !fileCompleted.has(file)) {
151
- fileCompleted.add(file);
152
- fileProgress.set(file, 100);
153
- }
154
- let actualLoadedSize = 0;
155
- let actualTotalSize = 0;
156
- const tinyThreshold = 102400;
157
- const mediumThreshold = 104857600;
158
- let seenTinyCount = 0;
159
- let seenMediumCount = 0;
160
- let seenLargeCount = 0;
161
- for (const [trackedFile, size] of fileSizes.entries()) {
162
- actualTotalSize += size;
163
- const progress = fileProgress.get(trackedFile) || 0;
164
- actualLoadedSize += size * progress / 100;
165
- if (size < tinyThreshold) {
166
- seenTinyCount++;
167
- } else if (size < mediumThreshold) {
168
- seenMediumCount++;
169
- } else {
170
- seenLargeCount++;
171
- }
172
- }
173
- const unseenTinyFiles = Math.max(0, estimatedTinyFiles - seenTinyCount);
174
- const unseenMediumFiles = Math.max(0, estimatedMediumFiles - seenMediumCount);
175
- let estimatedLargeFiles;
176
- if (seenLargeCount > 0) {
177
- estimatedLargeFiles = 2;
178
- } else {
179
- estimatedLargeFiles = 1;
180
- }
181
- const unseenLargeFiles = Math.max(0, estimatedLargeFiles - seenLargeCount);
182
- const adjustedTotalSize = actualTotalSize + unseenTinyFiles * estimatedTinySize + unseenMediumFiles * estimatedMediumSize + unseenLargeFiles * estimatedLargeSize;
183
- const rawProgress = adjustedTotalSize > 0 ? actualLoadedSize / adjustedTotalSize * 100 : 0;
184
- const overallProgress = rawProgress * progressScaleMax / 100;
185
- const isFirst = !fileFirstSent.has(file);
186
- const isLast = isComplete && !fileLastSent.has(file);
187
- if (isFirst) {
188
- fileFirstSent.add(file);
189
- }
190
- if (isLast) {
191
- fileLastSent.add(file);
192
- }
193
- if (hasSeenSubstantialFile) {
194
- sendProgress(overallProgress, file, fileProgressValue, isFirst, isLast);
195
- }
196
- } else if (status.status === "done" || status.status === "download") {
197
- const file = status.file;
198
- const fileSize = fileSizes.get(file) || 0;
199
- if (fileSize >= substantialFileThreshold) {
200
- hasSeenSubstantialFile = true;
201
- }
202
- if (!fileCompleted.has(file)) {
203
- fileCompleted.add(file);
204
- fileProgress.set(file, 100);
205
- let actualLoadedSize = 0;
206
- let actualTotalSize = 0;
207
- const tinyThreshold = 102400;
208
- const mediumThreshold = 104857600;
209
- let seenTinyCount = 0;
210
- let seenMediumCount = 0;
211
- let seenLargeCount = 0;
212
- for (const [trackedFile, size] of fileSizes.entries()) {
213
- actualTotalSize += size;
214
- const progress = fileProgress.get(trackedFile) || 0;
215
- actualLoadedSize += size * progress / 100;
216
- if (size < tinyThreshold) {
217
- seenTinyCount++;
218
- } else if (size < mediumThreshold) {
219
- seenMediumCount++;
220
- } else {
221
- seenLargeCount++;
122
+ if (status.status === "progress_total") {
123
+ const totalStatus = status;
124
+ const scaledProgress = totalStatus.progress * progressScaleMax / 100;
125
+ let activeFile = "";
126
+ let activeFileProgress = 0;
127
+ const files = totalStatus.files;
128
+ if (files) {
129
+ for (const [file, info] of Object.entries(files)) {
130
+ if (info.loaded < info.total) {
131
+ activeFile = file;
132
+ activeFileProgress = info.total > 0 ? info.loaded / info.total * 100 : 0;
133
+ break;
222
134
  }
223
135
  }
224
- const unseenTinyFiles = Math.max(0, estimatedTinyFiles - seenTinyCount);
225
- const unseenMediumFiles = Math.max(0, estimatedMediumFiles - seenMediumCount);
226
- let estimatedLargeFiles;
227
- if (seenLargeCount > 0) {
228
- estimatedLargeFiles = 2;
229
- } else {
230
- estimatedLargeFiles = 1;
231
- }
232
- const unseenLargeFiles = Math.max(0, estimatedLargeFiles - seenLargeCount);
233
- const adjustedTotalSize = actualTotalSize + unseenTinyFiles * estimatedTinySize + unseenMediumFiles * estimatedMediumSize + unseenLargeFiles * estimatedLargeSize;
234
- const rawProgress = adjustedTotalSize > 0 ? actualLoadedSize / adjustedTotalSize * 100 : 0;
235
- const overallProgress = rawProgress * progressScaleMax / 100;
236
- const isLast = !fileLastSent.has(file);
237
- if (isLast) {
238
- fileLastSent.add(file);
239
- if (hasSeenSubstantialFile) {
240
- sendProgress(overallProgress, file, 100, false, true);
136
+ if (!activeFile) {
137
+ const fileNames = Object.keys(files);
138
+ if (fileNames.length > 0) {
139
+ activeFile = fileNames[fileNames.length - 1];
140
+ activeFileProgress = 100;
241
141
  }
242
142
  }
243
143
  }
144
+ sendProgress(scaledProgress, activeFile, activeFileProgress);
244
145
  }
245
146
  };
246
147
  const pipelineOptions = {
@@ -261,6 +162,18 @@ var doGetPipeline = async (model, onProgress, options, progressScaleMax, cacheKe
261
162
  logger.time(pipelineTimerLabel, { pipelineType, modelPath });
262
163
  try {
263
164
  const result = await pipeline(pipelineType, model.provider_config.model_path, pipelineOptions);
165
+ if (throttleTimer) {
166
+ clearTimeout(throttleTimer);
167
+ throttleTimer = null;
168
+ }
169
+ const finalPending = pendingProgress;
170
+ if (finalPending) {
171
+ onProgress(Math.round(finalPending.progress), "Downloading model", {
172
+ file: finalPending.file,
173
+ progress: finalPending.fileProgress
174
+ });
175
+ pendingProgress = null;
176
+ }
264
177
  if (abortSignal?.aborted) {
265
178
  logger.timeEnd(pipelineTimerLabel, { status: "aborted" });
266
179
  throw new Error("Operation aborted after pipeline creation");
@@ -365,12 +278,22 @@ var HFT_TextEmbedding = async (input, model, onProgress, signal) => {
365
278
  return { vector: hfVector.data };
366
279
  };
367
280
  var HFT_TextClassification = async (input, model, onProgress, signal) => {
281
+ const isArrayInput = Array.isArray(input.text);
368
282
  if (model?.provider_config?.pipeline === "zero-shot-classification") {
369
283
  if (!input.candidateLabels || !Array.isArray(input.candidateLabels) || input.candidateLabels.length === 0) {
370
284
  throw new Error("Zero-shot text classification requires candidate labels");
371
285
  }
372
286
  const zeroShotClassifier = await getPipeline(model, onProgress, {}, signal);
373
287
  const result2 = await zeroShotClassifier(input.text, input.candidateLabels, {});
288
+ if (isArrayInput) {
289
+ const results = Array.isArray(result2) && Array.isArray(result2[0]?.labels) ? result2 : [result2];
290
+ return {
291
+ categories: results.map((r) => r.labels.map((label, idx) => ({
292
+ label,
293
+ score: r.scores[idx]
294
+ })))
295
+ };
296
+ }
374
297
  return {
375
298
  categories: result2.labels.map((label, idx) => ({
376
299
  label,
@@ -382,6 +305,17 @@ var HFT_TextClassification = async (input, model, onProgress, signal) => {
382
305
  const result = await TextClassification(input.text, {
383
306
  top_k: input.maxCategories || undefined
384
307
  });
308
+ if (isArrayInput) {
309
+ return {
310
+ categories: result.map((perInput) => {
311
+ const items = Array.isArray(perInput) ? perInput : [perInput];
312
+ return items.map((category) => ({
313
+ label: category.label,
314
+ score: category.score
315
+ }));
316
+ })
317
+ };
318
+ }
385
319
  if (Array.isArray(result[0])) {
386
320
  return {
387
321
  categories: result[0].map((category) => ({
@@ -398,10 +332,22 @@ var HFT_TextClassification = async (input, model, onProgress, signal) => {
398
332
  };
399
333
  };
400
334
  var HFT_TextLanguageDetection = async (input, model, onProgress, signal) => {
335
+ const isArrayInput = Array.isArray(input.text);
401
336
  const TextClassification = await getPipeline(model, onProgress, {}, signal);
402
337
  const result = await TextClassification(input.text, {
403
338
  top_k: input.maxLanguages || undefined
404
339
  });
340
+ if (isArrayInput) {
341
+ return {
342
+ languages: result.map((perInput) => {
343
+ const items = Array.isArray(perInput) ? perInput : [perInput];
344
+ return items.map((category) => ({
345
+ language: category.label,
346
+ score: category.score
347
+ }));
348
+ })
349
+ };
350
+ }
405
351
  if (Array.isArray(result[0])) {
406
352
  return {
407
353
  languages: result[0].map((category) => ({
@@ -418,10 +364,23 @@ var HFT_TextLanguageDetection = async (input, model, onProgress, signal) => {
418
364
  };
419
365
  };
420
366
  var HFT_TextNamedEntityRecognition = async (input, model, onProgress, signal) => {
367
+ const isArrayInput = Array.isArray(input.text);
421
368
  const textNamedEntityRecognition = await getPipeline(model, onProgress, {}, signal);
422
- let results = await textNamedEntityRecognition(input.text, {
369
+ const results = await textNamedEntityRecognition(input.text, {
423
370
  ignore_labels: input.blockList
424
371
  });
372
+ if (isArrayInput) {
373
+ return {
374
+ entities: results.map((perInput) => {
375
+ const items = Array.isArray(perInput) ? perInput : [perInput];
376
+ return items.map((entity) => ({
377
+ entity: entity.entity,
378
+ score: entity.score,
379
+ word: entity.word
380
+ }));
381
+ })
382
+ };
383
+ }
425
384
  let entities = [];
426
385
  if (!Array.isArray(results)) {
427
386
  entities = [results];
@@ -437,8 +396,21 @@ var HFT_TextNamedEntityRecognition = async (input, model, onProgress, signal) =>
437
396
  };
438
397
  };
439
398
  var HFT_TextFillMask = async (input, model, onProgress, signal) => {
399
+ const isArrayInput = Array.isArray(input.text);
440
400
  const unmasker = await getPipeline(model, onProgress, {}, signal);
441
- let results = await unmasker(input.text);
401
+ const results = await unmasker(input.text);
402
+ if (isArrayInput) {
403
+ return {
404
+ predictions: results.map((perInput) => {
405
+ const items = Array.isArray(perInput) ? perInput : [perInput];
406
+ return items.map((prediction) => ({
407
+ entity: prediction.token_str,
408
+ score: prediction.score,
409
+ sequence: prediction.sequence
410
+ }));
411
+ })
412
+ };
413
+ }
442
414
  let predictions = [];
443
415
  if (!Array.isArray(results)) {
444
416
  predictions = [results];
@@ -457,35 +429,50 @@ var HFT_TextGeneration = async (input, model, onProgress, signal) => {
457
429
  const logger = getLogger();
458
430
  const timerLabel = `hft:TextGeneration:${model?.provider_config.model_path}`;
459
431
  logger.time(timerLabel, { model: model?.provider_config.model_path });
432
+ const isArrayInput = Array.isArray(input.prompt);
460
433
  const generateText = await getPipeline(model, onProgress, {}, signal);
461
434
  logger.debug("HFT TextGeneration: pipeline ready, generating text", {
462
435
  model: model?.provider_config.model_path,
463
- promptLength: input.prompt?.length
436
+ promptLength: isArrayInput ? input.prompt.length : input.prompt?.length
464
437
  });
465
- const streamer = createTextStreamer(generateText.tokenizer, onProgress);
438
+ const streamer = isArrayInput ? undefined : createTextStreamer(generateText.tokenizer, onProgress);
466
439
  let results = await generateText(input.prompt, {
467
- streamer
440
+ ...streamer ? { streamer } : {}
468
441
  });
442
+ if (isArrayInput) {
443
+ const batchResults = Array.isArray(results) ? results : [results];
444
+ const texts = batchResults.map((r) => {
445
+ const seqs = Array.isArray(r) ? r : [r];
446
+ return extractGeneratedText(seqs[0]?.generated_text);
447
+ });
448
+ logger.timeEnd(timerLabel, { batchSize: texts.length });
449
+ return { text: texts };
450
+ }
469
451
  if (!Array.isArray(results)) {
470
452
  results = [results];
471
453
  }
472
- let text = results[0]?.generated_text;
473
- if (Array.isArray(text)) {
474
- text = text[text.length - 1]?.content;
475
- }
454
+ const text = extractGeneratedText(results[0]?.generated_text);
476
455
  logger.timeEnd(timerLabel, { outputLength: text?.length });
477
456
  return {
478
457
  text
479
458
  };
480
459
  };
481
460
  var HFT_TextTranslation = async (input, model, onProgress, signal) => {
461
+ const isArrayInput = Array.isArray(input.text);
482
462
  const translate = await getPipeline(model, onProgress, {}, signal);
483
- const streamer = createTextStreamer(translate.tokenizer, onProgress);
463
+ const streamer = isArrayInput ? undefined : createTextStreamer(translate.tokenizer, onProgress);
484
464
  const result = await translate(input.text, {
485
465
  src_lang: input.source_lang,
486
466
  tgt_lang: input.target_lang,
487
- streamer
467
+ ...streamer ? { streamer } : {}
488
468
  });
469
+ if (isArrayInput) {
470
+ const batchResults = Array.isArray(result) ? result : [result];
471
+ return {
472
+ text: batchResults.map((r) => r?.translation_text || ""),
473
+ target_lang: input.target_lang
474
+ };
475
+ }
489
476
  const translatedText = Array.isArray(result) ? result[0]?.translation_text || "" : result?.translation_text || "";
490
477
  return {
491
478
  text: translatedText,
@@ -493,20 +480,34 @@ var HFT_TextTranslation = async (input, model, onProgress, signal) => {
493
480
  };
494
481
  };
495
482
  var HFT_TextRewriter = async (input, model, onProgress, signal) => {
483
+ const isArrayInput = Array.isArray(input.text);
496
484
  const generateText = await getPipeline(model, onProgress, {}, signal);
497
- const streamer = createTextStreamer(generateText.tokenizer, onProgress);
485
+ const streamer = isArrayInput ? undefined : createTextStreamer(generateText.tokenizer, onProgress);
486
+ if (isArrayInput) {
487
+ const texts = input.text;
488
+ const promptedTexts = texts.map((t) => (input.prompt ? input.prompt + `
489
+ ` : "") + t);
490
+ let results2 = await generateText(promptedTexts, {});
491
+ const batchResults = Array.isArray(results2) ? results2 : [results2];
492
+ const outputTexts = batchResults.map((r, i) => {
493
+ const seqs = Array.isArray(r) ? r : [r];
494
+ const text2 = extractGeneratedText(seqs[0]?.generated_text);
495
+ if (text2 === promptedTexts[i]) {
496
+ throw new Error("Rewriter failed to generate new text");
497
+ }
498
+ return text2;
499
+ });
500
+ return { text: outputTexts };
501
+ }
498
502
  const promptedText = (input.prompt ? input.prompt + `
499
503
  ` : "") + input.text;
500
504
  let results = await generateText(promptedText, {
501
- streamer
505
+ ...streamer ? { streamer } : {}
502
506
  });
503
507
  if (!Array.isArray(results)) {
504
508
  results = [results];
505
509
  }
506
- let text = results[0]?.generated_text;
507
- if (Array.isArray(text)) {
508
- text = text[text.length - 1]?.content;
509
- }
510
+ const text = extractGeneratedText(results[0]?.generated_text);
510
511
  if (text === promptedText) {
511
512
  throw new Error("Rewriter failed to generate new text");
512
513
  }
@@ -515,11 +516,18 @@ var HFT_TextRewriter = async (input, model, onProgress, signal) => {
515
516
  };
516
517
  };
517
518
  var HFT_TextSummary = async (input, model, onProgress, signal) => {
519
+ const isArrayInput = Array.isArray(input.text);
518
520
  const generateSummary = await getPipeline(model, onProgress, {}, signal);
519
- const streamer = createTextStreamer(generateSummary.tokenizer, onProgress);
520
- let result = await generateSummary(input.text, {
521
- streamer
521
+ const streamer = isArrayInput ? undefined : createTextStreamer(generateSummary.tokenizer, onProgress);
522
+ const result = await generateSummary(input.text, {
523
+ ...streamer ? { streamer } : {}
522
524
  });
525
+ if (isArrayInput) {
526
+ const batchResults = Array.isArray(result) ? result : [result];
527
+ return {
528
+ text: batchResults.map((r) => r?.summary_text || "")
529
+ };
530
+ }
523
531
  let summaryText = "";
524
532
  if (Array.isArray(result)) {
525
533
  summaryText = result[0]?.summary_text || "";
@@ -531,7 +539,27 @@ var HFT_TextSummary = async (input, model, onProgress, signal) => {
531
539
  };
532
540
  };
533
541
  var HFT_TextQuestionAnswer = async (input, model, onProgress, signal) => {
542
+ const isArrayInput = Array.isArray(input.question);
534
543
  const generateAnswer = await getPipeline(model, onProgress, {}, signal);
544
+ if (isArrayInput) {
545
+ const questions = input.question;
546
+ const contexts = input.context;
547
+ if (questions.length !== contexts.length) {
548
+ throw new Error(`question[] and context[] must have the same length: ${questions.length} != ${contexts.length}`);
549
+ }
550
+ const answers = [];
551
+ for (let i = 0;i < questions.length; i++) {
552
+ const result2 = await generateAnswer(questions[i], contexts[i], {});
553
+ let answerText2 = "";
554
+ if (Array.isArray(result2)) {
555
+ answerText2 = result2[0]?.answer || "";
556
+ } else {
557
+ answerText2 = result2?.answer || "";
558
+ }
559
+ answers.push(answerText2);
560
+ }
561
+ return { text: answers };
562
+ }
535
563
  const streamer = createTextStreamer(generateAnswer.tokenizer, onProgress);
536
564
  const result = await generateAnswer(input.question, input.context, {
537
565
  streamer
@@ -670,6 +698,24 @@ function createTextStreamer(tokenizer, updateProgress) {
670
698
  }
671
699
  });
672
700
  }
701
+ function extractGeneratedText(generatedText) {
702
+ if (generatedText == null)
703
+ return "";
704
+ if (typeof generatedText === "string")
705
+ return generatedText;
706
+ const lastMessage = generatedText[generatedText.length - 1];
707
+ if (!lastMessage)
708
+ return "";
709
+ const content = lastMessage.content;
710
+ if (typeof content === "string")
711
+ return content;
712
+ for (const part of content) {
713
+ if (part.type === "text" && "text" in part) {
714
+ return part.text;
715
+ }
716
+ }
717
+ return "";
718
+ }
673
719
  function createStreamEventQueue() {
674
720
  const buffer = [];
675
721
  let resolve = null;
@@ -872,10 +918,16 @@ var HFT_TextTranslation_Stream = async function* (input, model, signal) {
872
918
  yield { type: "finish", data: { target_lang: input.target_lang } };
873
919
  };
874
920
  var HFT_CountTokens = async (input, model, onProgress, signal) => {
921
+ const isArrayInput = Array.isArray(input.text);
875
922
  const { AutoTokenizer } = _transformersSdk;
876
923
  const tokenizer = await AutoTokenizer.from_pretrained(model.provider_config.model_path, {
877
924
  progress_callback: (progress) => onProgress(progress?.progress ?? 0)
878
925
  });
926
+ if (isArrayInput) {
927
+ const texts = input.text;
928
+ const counts = texts.map((t) => tokenizer.encode(t).length);
929
+ return { count: counts };
930
+ }
879
931
  const tokenIds = tokenizer.encode(input.text);
880
932
  return { count: tokenIds.length };
881
933
  };
@@ -1022,7 +1074,40 @@ ${requiredInstruction}` };
1022
1074
  return mapHFTTools(input.tools);
1023
1075
  }
1024
1076
  var HFT_ToolCalling = async (input, model, onProgress, signal) => {
1077
+ const isArrayInput = Array.isArray(input.prompt);
1025
1078
  const generateText = await getPipeline(model, onProgress, {}, signal);
1079
+ if (isArrayInput) {
1080
+ const prompts = input.prompt;
1081
+ const texts = [];
1082
+ const allToolCalls = [];
1083
+ for (const promptText of prompts) {
1084
+ const messages2 = [];
1085
+ if (input.systemPrompt) {
1086
+ messages2.push({ role: "system", content: input.systemPrompt });
1087
+ }
1088
+ messages2.push({ role: "user", content: promptText });
1089
+ const singleInput = { ...input, prompt: promptText };
1090
+ const tools2 = resolveHFTToolsAndMessages(singleInput, messages2);
1091
+ const prompt2 = generateText.tokenizer.apply_chat_template(messages2, {
1092
+ tools: tools2,
1093
+ tokenize: false,
1094
+ add_generation_prompt: true
1095
+ });
1096
+ let results2 = await generateText(prompt2, {
1097
+ max_new_tokens: input.maxTokens ?? 1024,
1098
+ temperature: input.temperature ?? undefined,
1099
+ return_full_text: false
1100
+ });
1101
+ if (!Array.isArray(results2)) {
1102
+ results2 = [results2];
1103
+ }
1104
+ const responseText2 = extractGeneratedText(results2[0]?.generated_text).trim();
1105
+ const parsed = parseToolCallsFromText(responseText2);
1106
+ texts.push(parsed.text);
1107
+ allToolCalls.push(filterValidToolCalls(parsed.toolCalls, input.tools));
1108
+ }
1109
+ return { text: texts, toolCalls: allToolCalls };
1110
+ }
1026
1111
  const messages = [];
1027
1112
  if (input.systemPrompt) {
1028
1113
  messages.push({ role: "system", content: input.systemPrompt });
@@ -1044,11 +1129,7 @@ var HFT_ToolCalling = async (input, model, onProgress, signal) => {
1044
1129
  if (!Array.isArray(results)) {
1045
1130
  results = [results];
1046
1131
  }
1047
- let responseText = results[0]?.generated_text;
1048
- if (Array.isArray(responseText)) {
1049
- responseText = responseText[responseText.length - 1]?.content;
1050
- }
1051
- responseText = (responseText ?? "").trim();
1132
+ const responseText = extractGeneratedText(results[0]?.generated_text).trim();
1052
1133
  const { text, toolCalls } = parseToolCallsFromText(responseText);
1053
1134
  return { text, toolCalls: filterValidToolCalls(toolCalls, input.tools) };
1054
1135
  };
@@ -1113,9 +1194,55 @@ var HFT_ToolCalling_Stream = async function* (input, model, signal) {
1113
1194
  data: { text: cleanedText, toolCalls: validToolCalls }
1114
1195
  };
1115
1196
  };
1197
+ var HFT_ModelInfo = async (input, model) => {
1198
+ const logger = getLogger();
1199
+ const { ModelRegistry } = await loadTransformersSDK();
1200
+ const timerLabel = `hft:ModelInfo:${model?.provider_config.model_path}`;
1201
+ logger.time(timerLabel, { model: model?.provider_config.model_path });
1202
+ const detail = input.detail;
1203
+ const is_loaded = pipelines.has(getPipelineCacheKey(model));
1204
+ const { pipeline: pipelineType, model_path, dtype, device } = model.provider_config;
1205
+ const cacheStatus = await ModelRegistry.is_pipeline_cached(pipelineType, model_path, {
1206
+ ...dtype ? { dtype } : {},
1207
+ ...device ? { device } : {}
1208
+ });
1209
+ logger.error("cacheStatus", cacheStatus);
1210
+ const is_cached = is_loaded || cacheStatus.allCached;
1211
+ let file_sizes = null;
1212
+ if (detail === "files" && cacheStatus.files.length > 0) {
1213
+ const sizes = {};
1214
+ for (const { file } of cacheStatus.files) {
1215
+ sizes[file] = 0;
1216
+ }
1217
+ file_sizes = sizes;
1218
+ } else if (detail === "files_with_metadata" && cacheStatus.files.length > 0) {
1219
+ const sizes = {};
1220
+ await Promise.all(cacheStatus.files.map(async ({ file }) => {
1221
+ const metadata = await ModelRegistry.get_file_metadata(model_path, file);
1222
+ if (metadata.exists && metadata.size !== undefined) {
1223
+ sizes[file] = metadata.size;
1224
+ }
1225
+ }));
1226
+ if (Object.keys(sizes).length > 0) {
1227
+ file_sizes = sizes;
1228
+ }
1229
+ }
1230
+ logger.timeEnd(timerLabel, { model: model?.provider_config.model_path });
1231
+ return {
1232
+ model: input.model,
1233
+ is_local: true,
1234
+ is_remote: false,
1235
+ supports_browser: true,
1236
+ supports_node: true,
1237
+ is_cached,
1238
+ is_loaded,
1239
+ file_sizes
1240
+ };
1241
+ };
1116
1242
  var HFT_TASKS = {
1117
1243
  DownloadModelTask: HFT_Download,
1118
1244
  UnloadModelTask: HFT_Unload,
1245
+ ModelInfoTask: HFT_ModelInfo,
1119
1246
  CountTokensTask: HFT_CountTokens,
1120
1247
  TextEmbeddingTask: HFT_TextEmbedding,
1121
1248
  TextGenerationTask: HFT_TextGeneration,
@@ -1147,6 +1274,6 @@ var HFT_REACTIVE_TASKS = {
1147
1274
  CountTokensTask: HFT_CountTokens_Reactive
1148
1275
  };
1149
1276
 
1150
- export { clearPipelineCache, HFT_Download, HFT_Unload, HFT_TextEmbedding, HFT_TextClassification, HFT_TextLanguageDetection, HFT_TextNamedEntityRecognition, HFT_TextFillMask, HFT_TextGeneration, HFT_TextTranslation, HFT_TextRewriter, HFT_TextSummary, HFT_TextQuestionAnswer, HFT_ImageSegmentation, HFT_ImageToText, HFT_BackgroundRemoval, HFT_ImageEmbedding, HFT_ImageClassification, HFT_ObjectDetection, createToolCallMarkupFilter, HFT_TextGeneration_Stream, HFT_TextRewriter_Stream, HFT_TextSummary_Stream, HFT_TextQuestionAnswer_Stream, HFT_TextTranslation_Stream, HFT_CountTokens, HFT_CountTokens_Reactive, parseToolCallsFromText, HFT_ToolCalling, HFT_ToolCalling_Stream, HFT_TASKS, HFT_STREAM_TASKS, HFT_REACTIVE_TASKS };
1277
+ export { clearPipelineCache, HFT_Download, HFT_Unload, HFT_TextEmbedding, HFT_TextClassification, HFT_TextLanguageDetection, HFT_TextNamedEntityRecognition, HFT_TextFillMask, HFT_TextGeneration, HFT_TextTranslation, HFT_TextRewriter, HFT_TextSummary, HFT_TextQuestionAnswer, HFT_ImageSegmentation, HFT_ImageToText, HFT_BackgroundRemoval, HFT_ImageEmbedding, HFT_ImageClassification, HFT_ObjectDetection, createToolCallMarkupFilter, HFT_TextGeneration_Stream, HFT_TextRewriter_Stream, HFT_TextSummary_Stream, HFT_TextQuestionAnswer_Stream, HFT_TextTranslation_Stream, HFT_CountTokens, HFT_CountTokens_Reactive, parseToolCallsFromText, HFT_ToolCalling, HFT_ToolCalling_Stream, HFT_ModelInfo, HFT_TASKS, HFT_STREAM_TASKS, HFT_REACTIVE_TASKS };
1151
1278
 
1152
- //# debugId=55B90A6AAE9C20DF64756E2164756E21
1279
+ //# debugId=FF878FE45BB6B2A664756E2164756E21