@workglow/ai 0.0.125 → 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +218 -47
- package/dist/browser.d.ts +7 -0
- package/dist/browser.d.ts.map +1 -0
- package/dist/browser.js +565 -1277
- package/dist/browser.js.map +56 -61
- package/dist/{types.d.ts → bun.d.ts} +1 -1
- package/dist/bun.d.ts.map +1 -0
- package/dist/bun.js +565 -1277
- package/dist/bun.js.map +56 -61
- package/dist/common.d.ts +3 -1
- package/dist/common.d.ts.map +1 -1
- package/dist/execution/DirectExecutionStrategy.d.ts +20 -0
- package/dist/execution/DirectExecutionStrategy.d.ts.map +1 -0
- package/dist/execution/IAiExecutionStrategy.d.ts +33 -0
- package/dist/execution/IAiExecutionStrategy.d.ts.map +1 -0
- package/dist/execution/QueuedExecutionStrategy.d.ts +50 -0
- package/dist/execution/QueuedExecutionStrategy.d.ts.map +1 -0
- package/dist/job/AiJob.d.ts +6 -0
- package/dist/job/AiJob.d.ts.map +1 -1
- package/dist/node.d.ts +7 -0
- package/dist/node.d.ts.map +1 -0
- package/dist/node.js +565 -1277
- package/dist/node.js.map +56 -61
- package/dist/provider/AiProvider.d.ts +16 -2
- package/dist/provider/AiProvider.d.ts.map +1 -1
- package/dist/provider/AiProviderRegistry.d.ts +20 -0
- package/dist/provider/AiProviderRegistry.d.ts.map +1 -1
- package/dist/provider/QueuedAiProvider.d.ts +23 -2
- package/dist/provider/QueuedAiProvider.d.ts.map +1 -1
- package/dist/task/BackgroundRemovalTask.d.ts +3 -3
- package/dist/task/BackgroundRemovalTask.d.ts.map +1 -1
- package/dist/task/ChunkRetrievalTask.d.ts +4 -4
- package/dist/task/ChunkRetrievalTask.d.ts.map +1 -1
- package/dist/task/ChunkToVectorTask.d.ts +4 -4
- package/dist/task/ChunkToVectorTask.d.ts.map +1 -1
- package/dist/task/ChunkVectorHybridSearchTask.d.ts +4 -4
- package/dist/task/ChunkVectorHybridSearchTask.d.ts.map +1 -1
- package/dist/task/ChunkVectorSearchTask.d.ts +4 -4
- package/dist/task/ChunkVectorSearchTask.d.ts.map +1 -1
- package/dist/task/ChunkVectorUpsertTask.d.ts +4 -4
- package/dist/task/ChunkVectorUpsertTask.d.ts.map +1 -1
- package/dist/task/ContextBuilderTask.d.ts +4 -4
- package/dist/task/ContextBuilderTask.d.ts.map +1 -1
- package/dist/task/CountTokensTask.d.ts +11 -29
- package/dist/task/CountTokensTask.d.ts.map +1 -1
- package/dist/task/DocumentEnricherTask.d.ts +4 -4
- package/dist/task/DocumentEnricherTask.d.ts.map +1 -1
- package/dist/task/DownloadModelTask.d.ts +5 -5
- package/dist/task/DownloadModelTask.d.ts.map +1 -1
- package/dist/task/FaceDetectorTask.d.ts +4 -4
- package/dist/task/FaceDetectorTask.d.ts.map +1 -1
- package/dist/task/FaceLandmarkerTask.d.ts +4 -4
- package/dist/task/FaceLandmarkerTask.d.ts.map +1 -1
- package/dist/task/GestureRecognizerTask.d.ts +4 -4
- package/dist/task/GestureRecognizerTask.d.ts.map +1 -1
- package/dist/task/HandLandmarkerTask.d.ts +4 -4
- package/dist/task/HandLandmarkerTask.d.ts.map +1 -1
- package/dist/task/HierarchicalChunkerTask.d.ts +4 -4
- package/dist/task/HierarchicalChunkerTask.d.ts.map +1 -1
- package/dist/task/HierarchyJoinTask.d.ts +4 -4
- package/dist/task/HierarchyJoinTask.d.ts.map +1 -1
- package/dist/task/ImageClassificationTask.d.ts +4 -4
- package/dist/task/ImageClassificationTask.d.ts.map +1 -1
- package/dist/task/ImageEmbeddingTask.d.ts +203 -89
- package/dist/task/ImageEmbeddingTask.d.ts.map +1 -1
- package/dist/task/ImageSegmentationTask.d.ts +4 -4
- package/dist/task/ImageSegmentationTask.d.ts.map +1 -1
- package/dist/task/ImageToTextTask.d.ts +4 -4
- package/dist/task/ImageToTextTask.d.ts.map +1 -1
- package/dist/task/ModelInfoTask.d.ts +5 -5
- package/dist/task/ModelInfoTask.d.ts.map +1 -1
- package/dist/task/ModelSearchTask.d.ts +3 -3
- package/dist/task/ModelSearchTask.d.ts.map +1 -1
- package/dist/task/ObjectDetectionTask.d.ts +4 -4
- package/dist/task/ObjectDetectionTask.d.ts.map +1 -1
- package/dist/task/PoseLandmarkerTask.d.ts +4 -4
- package/dist/task/PoseLandmarkerTask.d.ts.map +1 -1
- package/dist/task/QueryExpanderTask.d.ts +4 -4
- package/dist/task/QueryExpanderTask.d.ts.map +1 -1
- package/dist/task/RerankerTask.d.ts +4 -4
- package/dist/task/RerankerTask.d.ts.map +1 -1
- package/dist/task/StructuralParserTask.d.ts +4 -4
- package/dist/task/StructuralParserTask.d.ts.map +1 -1
- package/dist/task/StructuredGenerationTask.d.ts +4 -4
- package/dist/task/StructuredGenerationTask.d.ts.map +1 -1
- package/dist/task/TextChunkerTask.d.ts +4 -4
- package/dist/task/TextChunkerTask.d.ts.map +1 -1
- package/dist/task/TextClassificationTask.d.ts +24 -62
- package/dist/task/TextClassificationTask.d.ts.map +1 -1
- package/dist/task/TextEmbeddingTask.d.ts +3 -3
- package/dist/task/TextEmbeddingTask.d.ts.map +1 -1
- package/dist/task/TextFillMaskTask.d.ts +29 -73
- package/dist/task/TextFillMaskTask.d.ts.map +1 -1
- package/dist/task/TextGenerationTask.d.ts +13 -32
- package/dist/task/TextGenerationTask.d.ts.map +1 -1
- package/dist/task/TextLanguageDetectionTask.d.ts +24 -62
- package/dist/task/TextLanguageDetectionTask.d.ts.map +1 -1
- package/dist/task/TextNamedEntityRecognitionTask.d.ts +29 -73
- package/dist/task/TextNamedEntityRecognitionTask.d.ts.map +1 -1
- package/dist/task/TextQuestionAnswerTask.d.ts +17 -45
- package/dist/task/TextQuestionAnswerTask.d.ts.map +1 -1
- package/dist/task/TextRewriterTask.d.ts +12 -31
- package/dist/task/TextRewriterTask.d.ts.map +1 -1
- package/dist/task/TextSummaryTask.d.ts +12 -31
- package/dist/task/TextSummaryTask.d.ts.map +1 -1
- package/dist/task/TextTranslationTask.d.ts +12 -31
- package/dist/task/TextTranslationTask.d.ts.map +1 -1
- package/dist/task/TopicSegmenterTask.d.ts +4 -4
- package/dist/task/TopicSegmenterTask.d.ts.map +1 -1
- package/dist/task/UnloadModelTask.d.ts +4 -4
- package/dist/task/UnloadModelTask.d.ts.map +1 -1
- package/dist/task/VectorQuantizeTask.d.ts +4 -4
- package/dist/task/VectorQuantizeTask.d.ts.map +1 -1
- package/dist/task/VectorSimilarityTask.d.ts +4 -4
- package/dist/task/VectorSimilarityTask.d.ts.map +1 -1
- package/dist/task/base/AiTask.d.ts +12 -31
- package/dist/task/base/AiTask.d.ts.map +1 -1
- package/dist/task/base/AiVisionTask.d.ts +7 -12
- package/dist/task/base/AiVisionTask.d.ts.map +1 -1
- package/dist/task/base/StreamingAiTask.d.ts +7 -4
- package/dist/task/base/StreamingAiTask.d.ts.map +1 -1
- package/dist/task/index.d.ts +1 -13
- package/dist/task/index.d.ts.map +1 -1
- package/dist/worker.d.ts +0 -2
- package/dist/worker.d.ts.map +1 -1
- package/dist/worker.js +217 -233
- package/dist/worker.js.map +7 -7
- package/package.json +24 -15
- package/dist/queue/createDefaultQueue.d.ts +0 -17
- package/dist/queue/createDefaultQueue.d.ts.map +0 -1
- package/dist/task/AgentTask.d.ts +0 -524
- package/dist/task/AgentTask.d.ts.map +0 -1
- package/dist/task/AgentTypes.d.ts +0 -181
- package/dist/task/AgentTypes.d.ts.map +0 -1
- package/dist/task/AgentUtils.d.ts +0 -50
- package/dist/task/AgentUtils.d.ts.map +0 -1
- package/dist/task/MessageConversion.d.ts +0 -52
- package/dist/task/MessageConversion.d.ts.map +0 -1
- package/dist/task/ToolCallingTask.d.ts +0 -385
- package/dist/task/ToolCallingTask.d.ts.map +0 -1
- package/dist/task/ToolCallingUtils.d.ts +0 -65
- package/dist/task/ToolCallingUtils.d.ts.map +0 -1
- package/dist/types.d.ts.map +0 -1
package/dist/worker.js
CHANGED
|
@@ -7,14 +7,196 @@ import {
|
|
|
7
7
|
// src/provider/AiProviderRegistry.ts
|
|
8
8
|
import { globalServiceRegistry, WORKER_MANAGER } from "@workglow/util/worker";
|
|
9
9
|
|
|
10
|
+
// src/job/AiJob.ts
|
|
11
|
+
import {
|
|
12
|
+
AbortSignalJobError,
|
|
13
|
+
Job,
|
|
14
|
+
JobStatus,
|
|
15
|
+
PermanentJobError,
|
|
16
|
+
RetryableJobError
|
|
17
|
+
} from "@workglow/job-queue";
|
|
18
|
+
import { getLogger } from "@workglow/util/worker";
|
|
19
|
+
/** Default per-call timeout for AI provider jobs: 120000 ms (2 minutes). */
const DEFAULT_AI_TIMEOUT_MS = 120000;
/** Default per-call timeout for the "LOCAL_LLAMACPP" provider: 600000 ms (10 minutes). */
const LOCAL_LLAMACPP_DEFAULT_TIMEOUT_MS = 600000;

/**
 * Resolves the timeout (in milliseconds) applied to one AI job's provider call.
 *
 * @param aiProvider - provider name; "LOCAL_LLAMACPP" gets a longer default.
 * @param explicitMs - caller-supplied timeout. `undefined` and `null` both mean
 *   "not set". `null` is treated as unset because it commonly appears after
 *   JSON/queue serialization; passing it to `AbortSignal.timeout(null)` would
 *   coerce it to 0 and abort the call immediately. Any other value, including
 *   0, is returned unchanged.
 * @returns the timeout to pass to `AbortSignal.timeout`.
 */
function resolveAiJobTimeoutMs(aiProvider, explicitMs) {
  if (explicitMs != null) {
    return explicitMs;
  }
  if (aiProvider === "LOCAL_LLAMACPP") {
    return LOCAL_LLAMACPP_DEFAULT_TIMEOUT_MS;
  }
  return DEFAULT_AI_TIMEOUT_MS;
}
|
|
30
|
+
/**
 * Maps a value thrown by a provider run/stream function onto the job-queue
 * error classes, so the queue can decide whether to retry.
 *
 * Rules are applied in this order:
 * 1. Already-classified job errors are returned unchanged.
 * 2. AbortError / TimeoutError DOMExceptions and known abort messages become
 *    AbortSignalJobError.
 * 3. The "HFT_NULL_PROCESSOR:" sentinel is retryable.
 * 4. HTTP status rules: 429 is retryable with a delay (Retry-After seconds
 *    from the message, else 30 s); 401/403 and 400/404 are permanent; >= 500
 *    is retryable.
 * 5. Network-looking and timeout-looking messages are retryable.
 * 6. Everything else is permanent.
 *
 * @param err - the thrown value (need not be an Error instance).
 * @param taskType - task type, only used in the resulting message.
 * @param provider - provider name, only used in the resulting message.
 * @returns a PermanentJobError, RetryableJobError or AbortSignalJobError.
 */
function classifyProviderError(err, taskType, provider) {
  const isJobError =
    err instanceof PermanentJobError ||
    err instanceof RetryableJobError ||
    err instanceof AbortSignalJobError;
  if (isJobError) {
    return err;
  }

  const message = err instanceof Error ? err.message : String(err);

  let status;
  if (typeof err?.status === "number") {
    status = err.status;
  } else if (typeof err?.statusCode === "number") {
    status = err.statusCode;
  } else {
    // NOTE(review): this scans the message for any standalone 4xx/5xx-looking
    // number, so unrelated numbers (e.g. "512 dimensions") are read as an HTTP
    // status — consider anchoring it to "status"/"HTTP"/leading position.
    const statusMatch = /\b([45]\d{2})\b/.exec(message);
    status = statusMatch ? Number.parseInt(statusMatch[1], 10) : undefined;
  }

  if (err instanceof DOMException) {
    if (err.name === "AbortError") {
      return new AbortSignalJobError(`Provider call aborted for ${taskType} (${provider})`);
    }
    if (err.name === "TimeoutError") {
      return new AbortSignalJobError(`Provider call timed out for ${taskType} (${provider})`);
    }
  }

  const abortMarkers = [
    "Pipeline download aborted",
    "Operation aborted",
    "operation was aborted",
    "The operation was aborted",
  ];
  if (abortMarkers.some((marker) => message.includes(marker))) {
    return new AbortSignalJobError(`Provider call aborted for ${taskType} (${provider}): ${message}`);
  }

  if (message.startsWith("HFT_NULL_PROCESSOR:")) {
    return new RetryableJobError(message);
  }

  switch (status) {
    case 429: {
      const retryAfter = /retry.after[:\s]*(\d+)/i.exec(message);
      const retryMs = retryAfter ? Number.parseInt(retryAfter[1], 10) * 1000 : 30000;
      return new RetryableJobError(
        `Rate limited by ${provider} for ${taskType}: ${message}`,
        new Date(Date.now() + retryMs),
      );
    }
    case 401:
    case 403:
      return new PermanentJobError(`Authentication failed for ${provider} (${taskType}): ${message}`);
    case 400:
    case 404:
      return new PermanentJobError(`Invalid request to ${provider} for ${taskType}: ${message}`);
    default:
      break;
  }
  if (status !== undefined && status >= 500) {
    return new RetryableJobError(`Server error from ${provider} for ${taskType} (HTTP ${status}): ${message}`);
  }

  const networkMarkers = ["ECONNREFUSED", "ECONNRESET", "ETIMEDOUT", "fetch failed", "network"];
  const looksLikeNetworkFailure =
    networkMarkers.some((marker) => message.includes(marker)) ||
    (err instanceof TypeError && message.includes("fetch"));
  if (looksLikeNetworkFailure) {
    return new RetryableJobError(`Network error calling ${provider} for ${taskType}: ${message}`);
  }

  if (["timed out", "timeout"].some((marker) => message.includes(marker))) {
    return new RetryableJobError(`Timeout calling ${provider} for ${taskType}: ${message}`);
  }

  return new PermanentJobError(`Provider ${provider} failed for ${taskType}: ${message}`);
}
|
|
73
|
+
|
|
74
|
+
/**
 * Job that performs one AI provider call for `input.taskType` on
 * `input.aiProvider`, using the run/stream functions held by the registry
 * returned from getAiProviderRegistry(). Provider failures are normalized
 * with classifyProviderError so the job queue can decide whether to retry.
 */
class AiJob extends Job {
  /**
   * Runs the provider's direct run function and resolves with its result.
   *
   * The provider receives a signal that aborts when either the caller's
   * `context.signal` aborts or the per-job timeout (resolveAiJobTimeoutMs)
   * elapses. The call is also raced against that combined signal, so a
   * provider that ignores its signal can no longer keep the job running past
   * the timeout. Previously only the caller's abort was raced.
   *
   * @param input - reads `aiProvider`, `taskType`, `taskInput` (and
   *   `taskInput.model`), optional `timeoutMs`, and `outputSchema`.
   * @param context - reads `signal` and `updateProgress`.
   * @returns the run function's result.
   * @throws AbortSignalJobError if aborted before the call starts; otherwise
   *   the error produced by classifyProviderError. A caller abort during the
   *   call stays an AbortSignalJobError; a timeout goes through the
   *   TimeoutError path.
   */
  async execute(input, context) {
    if (context.signal.aborted || this.status === JobStatus.ABORTING) {
      throw new AbortSignalJobError("Abort signal aborted before execution of job");
    }
    let removeAbortListener;
    try {
      const timeoutMs = resolveAiJobTimeoutMs(input.aiProvider, input.timeoutMs);
      const timeoutSignal = AbortSignal.timeout(timeoutMs);
      const combinedSignal = AbortSignal.any([context.signal, timeoutSignal]);

      // Never resolves; rejects as soon as the combined signal aborts.
      const abortPromise = new Promise((_resolve, reject) => {
        const onAbort = () => {
          if (context.signal.aborted) {
            reject(new AbortSignalJobError("Abort signal seen, ending job"));
          } else {
            // Timeout: reject with the signal's TimeoutError reason so that
            // classifyProviderError reports "Provider call timed out".
            reject(timeoutSignal.reason);
          }
        };
        combinedSignal.addEventListener("abort", onAbort, { once: true });
        removeAbortListener = () => combinedSignal.removeEventListener("abort", onAbort);
      });

      const runFn = async () => {
        const fn = getAiProviderRegistry().getDirectRunFn(input.aiProvider, input.taskType);
        const model = input.taskInput.model;
        if (context.signal.aborted) {
          throw new AbortSignalJobError("Job aborted");
        }
        return await fn(input.taskInput, model, context.updateProgress, combinedSignal, input.outputSchema);
      };

      return await Promise.race([runFn(), abortPromise]);
    } catch (err) {
      throw classifyProviderError(err, input.taskType, input.aiProvider);
    } finally {
      // Detach from the (possibly long-lived) caller signal once settled.
      removeAbortListener?.();
    }
  }

  /**
   * Streams provider events for this job.
   *
   * When no stream function is registered for the provider/task pair, this
   * falls back to execute() and yields its result as a single "finish" event.
   * Otherwise events are re-yielded as received.
   *
   * On a stream error the error is logged. If no "finish" event has been seen
   * yet, an empty `{ type: "finish", data: {} }` is yielded first. The error
   * is then rethrown after classifyProviderError.
   * NOTE(review): unlike execute(), the timeout signal is only handed to the
   * stream function here, not raced, so a stream that ignores it is not cut off.
   *
   * @param input - same fields as execute().
   * @param context - reads `signal` (and `updateProgress` in the fallback path).
   */
  async *executeStream(input, context) {
    if (context.signal.aborted || this.status === JobStatus.ABORTING) {
      throw new AbortSignalJobError("Abort signal aborted before streaming execution of job");
    }
    const streamFn = getAiProviderRegistry().getStreamFn(input.aiProvider, input.taskType);
    if (!streamFn) {
      const result = await this.execute(input, context);
      yield { type: "finish", data: result };
      return;
    }
    const model = input.taskInput.model;
    let lastFinishData;
    const timeoutMs = resolveAiJobTimeoutMs(input.aiProvider, input.timeoutMs);
    const timeoutSignal = AbortSignal.timeout(timeoutMs);
    const combinedSignal = AbortSignal.any([context.signal, timeoutSignal]);
    try {
      for await (const event of streamFn(input.taskInput, model, combinedSignal, input.outputSchema)) {
        if (event.type === "finish") {
          lastFinishData = event.data;
        }
        yield event;
      }
    } catch (err) {
      const logger = getLogger();
      logger.warn(`AiJob: Stream error for ${input.taskType} (${input.aiProvider}): ${err instanceof Error ? err.message : String(err)}`);
      if (lastFinishData === undefined) {
        yield { type: "finish", data: {} };
      }
      throw classifyProviderError(err, input.taskType, input.aiProvider);
    }
  }
}
|
|
141
|
+
|
|
142
|
+
// src/execution/DirectExecutionStrategy.ts
|
|
143
|
+
/**
 * Execution strategy that runs AI jobs in-process. Each call constructs an
 * AiJob and invokes it directly instead of enqueueing it.
 */
class DirectExecutionStrategy {
  /**
   * Runs one job to completion.
   *
   * @param jobInput - job input; `aiProvider` is also used as the queue name.
   * @param context - supplies `signal` and `updateProgress`.
   * @param runnerId - stored as the job's `jobRunId`.
   * @returns whatever AiJob.execute resolves with.
   */
  async execute(jobInput, context, runnerId) {
    const job = new AiJob({
      queueName: jobInput.aiProvider,
      jobRunId: runnerId,
      input: jobInput,
    });
    // Relay progress reported on the job object back to the caller; the
    // returned function detaches the listener once the job settles.
    const stopForwarding = job.onJobProgress((progress, message, details) => {
      context.updateProgress(progress, message, details);
    });
    try {
      const jobContext = {
        signal: context.signal,
        updateProgress: context.updateProgress,
      };
      return await job.execute(jobInput, jobContext);
    } finally {
      stopForwarding();
    }
  }

  /**
   * Streams one job's events by delegating to AiJob.executeStream.
   *
   * @param jobInput - job input; `aiProvider` is also used as the queue name.
   * @param context - supplies `signal` and `updateProgress`.
   * @param runnerId - stored as the job's `jobRunId`.
   */
  async *executeStream(jobInput, context, runnerId) {
    const streamingJob = new AiJob({
      queueName: jobInput.aiProvider,
      jobRunId: runnerId,
      input: jobInput,
    });
    const jobContext = {
      signal: context.signal,
      updateProgress: context.updateProgress,
    };
    yield* streamingJob.executeStream(jobInput, jobContext);
  }

  /** No-op: cancellation goes through `context.signal`, which is forwarded to the job. */
  abort() {}
}
|
|
175
|
+
|
|
176
|
+
// src/provider/AiProviderRegistry.ts
|
|
10
177
|
class AiProviderRegistry {
|
|
11
178
|
runFnRegistry = new Map;
|
|
12
179
|
streamFnRegistry = new Map;
|
|
13
180
|
reactiveRunFnRegistry = new Map;
|
|
14
181
|
providers = new Map;
|
|
182
|
+
strategyResolvers = new Map;
|
|
183
|
+
defaultStrategy;
|
|
15
184
|
registerProvider(provider) {
|
|
16
185
|
this.providers.set(provider.name, provider);
|
|
17
186
|
}
|
|
187
|
+
unregisterProvider(name) {
|
|
188
|
+
this.providers.delete(name);
|
|
189
|
+
this.strategyResolvers.delete(name);
|
|
190
|
+
for (const [, providerMap] of this.runFnRegistry) {
|
|
191
|
+
providerMap.delete(name);
|
|
192
|
+
}
|
|
193
|
+
for (const [, providerMap] of this.streamFnRegistry) {
|
|
194
|
+
providerMap.delete(name);
|
|
195
|
+
}
|
|
196
|
+
for (const [, providerMap] of this.reactiveRunFnRegistry) {
|
|
197
|
+
providerMap.delete(name);
|
|
198
|
+
}
|
|
199
|
+
}
|
|
18
200
|
getProvider(name) {
|
|
19
201
|
return this.providers.get(name);
|
|
20
202
|
}
|
|
@@ -24,6 +206,18 @@ class AiProviderRegistry {
|
|
|
24
206
|
getInstalledProviderIds() {
|
|
25
207
|
return [...this.providers.keys()].sort();
|
|
26
208
|
}
|
|
209
|
+
registerStrategyResolver(providerName, resolver) {
|
|
210
|
+
this.strategyResolvers.set(providerName, resolver);
|
|
211
|
+
}
|
|
212
|
+
getStrategy(model) {
|
|
213
|
+
const resolver = this.strategyResolvers.get(model.provider);
|
|
214
|
+
if (resolver)
|
|
215
|
+
return resolver(model);
|
|
216
|
+
if (!this.defaultStrategy) {
|
|
217
|
+
this.defaultStrategy = new DirectExecutionStrategy;
|
|
218
|
+
}
|
|
219
|
+
return this.defaultStrategy;
|
|
220
|
+
}
|
|
27
221
|
getProviderIdsForTask(taskType) {
|
|
28
222
|
const taskMap = this.runFnRegistry.get(taskType);
|
|
29
223
|
if (!taskMap)
|
|
@@ -89,15 +283,16 @@ class AiProviderRegistry {
|
|
|
89
283
|
const taskTypeMap = this.runFnRegistry.get(taskType);
|
|
90
284
|
const runFn = taskTypeMap?.get(modelProvider);
|
|
91
285
|
if (!runFn) {
|
|
92
|
-
|
|
286
|
+
const installedProviders = this.getInstalledProviderIds();
|
|
287
|
+
const providersForTask = this.getProviderIdsForTask(taskType);
|
|
288
|
+
const hint = providersForTask.length > 0 ? ` Providers supporting "${taskType}": [${providersForTask.join(", ")}].` : installedProviders.length > 0 ? ` Installed providers: [${installedProviders.join(", ")}] (none support "${taskType}").` : " No providers are registered. Call provider.register() before running AI tasks.";
|
|
289
|
+
throw new Error(`No run function found for task type "${taskType}" and provider "${modelProvider}".${hint}`);
|
|
93
290
|
}
|
|
94
291
|
return runFn;
|
|
95
292
|
}
|
|
96
293
|
}
|
|
97
|
-
// Module-level registry, created eagerly at load time.
// NOTE(review): kept as `var` because setAiProviderRegistry (body not visible
// here) presumably reassigns it — confirm before converting to `let`.
var providerRegistry = new AiProviderRegistry();

/** Returns the module's current AI provider registry instance. */
function getAiProviderRegistry() {
  const registry = providerRegistry;
  return registry;
}
|
|
103
298
|
function setAiProviderRegistry(pr) {
|
|
@@ -105,6 +300,16 @@ function setAiProviderRegistry(pr) {
|
|
|
105
300
|
}
|
|
106
301
|
|
|
107
302
|
// src/provider/AiProvider.ts
|
|
303
|
+
/**
 * Resolves how many GPU-bound jobs an AI provider's queue may run at once.
 *
 * @param concurrency - `undefined`/`null` (use the default), a number (used
 *   as-is), or an object whose optional `gpu` field holds the GPU concurrency.
 * @returns the GPU concurrency; 1 when nothing usable is specified.
 */
function resolveAiProviderGpuQueueConcurrency(concurrency) {
  // `== null` also covers null, which previously reached `null.gpu` and threw.
  if (concurrency == null) {
    return 1;
  }
  if (typeof concurrency === "number") {
    return concurrency;
  }
  return concurrency.gpu ?? 1;
}
|
|
312
|
+
|
|
108
313
|
class AiProvider {
|
|
109
314
|
tasks;
|
|
110
315
|
streamTasks;
|
|
@@ -164,7 +369,12 @@ class AiProvider {
|
|
|
164
369
|
}
|
|
165
370
|
}
|
|
166
371
|
registry.registerProvider(this);
|
|
167
|
-
|
|
372
|
+
try {
|
|
373
|
+
await this.afterRegister(options);
|
|
374
|
+
} catch (err) {
|
|
375
|
+
registry.unregisterProvider(this.name);
|
|
376
|
+
throw err;
|
|
377
|
+
}
|
|
168
378
|
}
|
|
169
379
|
registerOnWorkerServer(workerServer) {
|
|
170
380
|
if (!this.tasks) {
|
|
@@ -188,228 +398,6 @@ class AiProvider {
|
|
|
188
398
|
async dispose() {}
|
|
189
399
|
async afterRegister(_options) {}
|
|
190
400
|
}
|
|
191
|
-
// src/task/ToolCallingUtils.ts
|
|
192
|
-
import { getLogger } from "@workglow/util/worker";
|
|
193
|
-
/**
 * Returns the tool's description. When the tool has an object-valued
 * `outputSchema`, the JSON-serialized schema is appended after a blank line
 * as "Returns: ...".
 */
function buildToolDescription(tool) {
  const { description, outputSchema } = tool;
  const hasObjectSchema = Boolean(outputSchema) && typeof outputSchema === "object";
  if (!hasObjectSchema) {
    return description;
  }
  return description + `\n\nReturns: ${JSON.stringify(outputSchema)}`;
}
|
|
202
|
-
/** True when some entry of `allowedTools` has exactly this `name`. */
function isAllowedToolName(name, allowedTools) {
  return allowedTools.some(({ name: toolName }) => toolName === name);
}
|
|
205
|
-
/**
 * Keeps only the tool calls whose `name` is truthy and matches an entry of
 * `allowedTools`. Each dropped call is logged as a warning with its id and name.
 *
 * @returns a new array; the input is not modified.
 */
function filterValidToolCalls(toolCalls, allowedTools) {
  const isKnownTool = (toolName) => allowedTools.some((tool) => tool.name === toolName);
  return toolCalls.filter((call) => {
    const keep = Boolean(call.name) && isKnownTool(call.name);
    if (!keep) {
      getLogger().warn(`Filtered out tool call with unknown name "${call.name ?? "(missing)"}"`, {
        callId: call.id,
        toolName: call.name,
      });
    }
    return keep;
  });
}
|
|
217
|
-
// src/task/MessageConversion.ts
|
|
218
|
-
/** Returns `input.messages` when it is present and non-empty, otherwise `undefined`. */
function getInputMessages(input) {
  const { messages } = input;
  const hasMessages = Boolean(messages) && messages.length !== 0;
  return hasMessages ? messages : undefined;
}
|
|
224
|
-
/**
 * Converts a task input into an OpenAI-style chat message list.
 *
 * A truthy `input.systemPrompt` is always emitted first as a system message.
 * When `input.messages` is non-empty, each message is translated by role:
 * - user: strings are kept; arrays of content blocks become content parts;
 *   anything else is JSON-stringified.
 * - assistant: text is joined, and `tool_use` blocks become `tool_calls`.
 * - tool: one message is emitted per result block.
 * Other roles are dropped. Without messages, one user message is built from
 * `input.prompt`.
 *
 * @param input - task input with optional `systemPrompt`, `messages`, `prompt`.
 * @returns array of `{ role, content, ... }` message objects.
 */
function toOpenAIMessages(input) {
  // Converts a text/image/audio block to an OpenAI content part, or returns
  // undefined for unsupported blocks (audio too, when `allowAudio` is false).
  const toContentPart = (block, allowAudio) => {
    switch (block.type) {
      case "text":
        return { type: "text", text: block.text };
      case "image":
        return {
          type: "image_url",
          image_url: { url: `data:${block.mimeType};base64,${block.data}` },
        };
      case "audio": {
        if (!allowAudio) {
          return undefined;
        }
        const format = block.mimeType.replace(/^audio\//, "");
        return { type: "input_audio", input_audio: { data: block.data, format } };
      }
      default:
        return undefined;
    }
  };

  // Applies `convert` to each block and keeps the defined results, in order.
  const collectParts = (blocks, convert) => {
    const parts = [];
    for (const block of blocks) {
      const part = convert(block);
      if (part !== undefined) {
        parts.push(part);
      }
    }
    return parts;
  };

  const result = [];
  if (input.systemPrompt) {
    result.push({ role: "system", content: input.systemPrompt });
  }

  const inputMessages = getInputMessages(input);
  if (!inputMessages) {
    const { prompt } = input;
    let promptContent;
    if (!Array.isArray(prompt)) {
      promptContent = prompt;
    } else if (prompt.every((item) => typeof item === "string")) {
      promptContent = prompt.join("\n");
    } else {
      promptContent = collectParts(prompt, (item) =>
        typeof item === "string" ? { type: "text", text: item } : toContentPart(item, true),
      );
    }
    result.push({ role: "user", content: promptContent });
    return result;
  }

  for (const msg of inputMessages) {
    switch (msg.role) {
      case "user": {
        const { content } = msg;
        if (typeof content === "string") {
          result.push({ role: "user", content });
        } else if (Array.isArray(content) && content.length > 0 && typeof content[0]?.type === "string") {
          result.push({ role: "user", content: collectParts(content, (block) => toContentPart(block, true)) });
        } else {
          let serialized;
          try {
            serialized = JSON.stringify(content);
          } catch {
            serialized = String(content);
          }
          result.push({ role: "user", content: serialized });
        }
        break;
      }
      case "assistant": {
        const { content } = msg;
        if (typeof content === "string") {
          result.push({ role: "assistant", content: content.length > 0 ? content : null });
        } else if (Array.isArray(content)) {
          const text = content
            .filter((block) => block.type === "text")
            .map((block) => block.text)
            .join("");
          const toolCalls = content
            .filter((block) => block.type === "tool_use")
            .map((block) => ({
              id: block.id,
              type: "function",
              function: { name: block.name, arguments: JSON.stringify(block.input) },
            }));
          const entry = { role: "assistant", content: text.length > 0 ? text : null };
          if (toolCalls.length > 0) {
            entry.tool_calls = toolCalls;
          }
          result.push(entry);
        }
        break;
      }
      case "tool": {
        if (!Array.isArray(msg.content)) {
          break;
        }
        for (const toolResult of msg.content) {
          let toolContent;
          if (typeof toolResult.content === "string") {
            toolContent = toolResult.content;
          } else if (Array.isArray(toolResult.content)) {
            // Tool results only carry text and image parts.
            toolContent = collectParts(toolResult.content, (inner) => toContentPart(inner, false));
          } else {
            toolContent = "";
          }
          result.push({ role: "tool", content: toolContent, tool_call_id: toolResult.tool_use_id });
        }
        break;
      }
      default:
        break;
    }
  }
  return result;
}
|
|
348
|
-
/**
 * Converts a task input into a text-only chat message list, for providers
 * that accept plain-string content.
 *
 * A truthy `input.systemPrompt` is emitted first. With non-empty
 * `input.messages`:
 * - user messages become text, with blocks reduced to their joined text
 *   parts and other values JSON-stringified;
 * - assistant messages are kept only when their text is non-empty;
 * - each tool result becomes a `{ role: "tool", content }` message without a
 *   call id.
 * Without messages, one user message is built from `input.prompt`; in an
 * array prompt, non-empty strings and text blocks are joined with newlines.
 *
 * @param input - task input with optional `systemPrompt`, `messages`, `prompt`.
 * @returns array of `{ role, content }` objects.
 */
function toTextFlatMessages(input) {
  // Joins the `text` of every block whose type is "text".
  const textOf = (blocks) =>
    blocks
      .filter((block) => block.type === "text")
      .map((block) => block.text)
      .join("");

  const out = [];
  if (input.systemPrompt) {
    out.push({ role: "system", content: input.systemPrompt });
  }

  const inputMessages = getInputMessages(input);
  if (!inputMessages) {
    const { prompt } = input;
    const promptContent = Array.isArray(prompt)
      ? prompt
          .map((item) => {
            if (typeof item === "string") {
              return item;
            }
            return item.type === "text" ? item.text : "";
          })
          .filter((piece) => piece !== "")
          .join("\n")
      : prompt;
    out.push({ role: "user", content: promptContent });
    return out;
  }

  for (const msg of inputMessages) {
    const { role } = msg;
    if (role === "user") {
      const { content } = msg;
      let text = "";
      if (typeof content === "string") {
        text = content;
      } else if (Array.isArray(content) && content.length > 0 && typeof content[0]?.type === "string") {
        text = textOf(content);
      } else if (content != null) {
        try {
          text = JSON.stringify(content);
        } catch {
          text = String(content);
        }
      }
      out.push({ role: "user", content: text });
    } else if (role === "assistant") {
      const { content } = msg;
      let text = "";
      if (typeof content === "string") {
        text = content;
      } else if (Array.isArray(content)) {
        text = textOf(content);
      }
      if (text) {
        out.push({ role: "assistant", content: text });
      }
    } else if (role === "tool" && Array.isArray(msg.content)) {
      for (const toolResult of msg.content) {
        const inner = toolResult.content;
        let text = "";
        if (typeof inner === "string") {
          text = inner;
        } else if (Array.isArray(inner)) {
          text = textOf(inner);
        }
        out.push({ role: "tool", content: text });
      }
    }
  }
  return out;
}
|
|
413
401
|
// src/model/ModelSchema.ts
|
|
414
402
|
var ModelConfigSchema = {
|
|
415
403
|
type: "object",
|
|
@@ -452,13 +440,9 @@ var ModelRecordSchema = {
|
|
|
452
440
|
};
|
|
453
441
|
// Primary-key field name(s) for model records: just "model_id".
// NOTE(review): presumably the key of ModelRecordSchema used by model storage — confirm against its consumers.
var ModelPrimaryKeyNames = ["model_id"];
|
|
454
442
|
export {
|
|
455
|
-
toTextFlatMessages,
|
|
456
|
-
toOpenAIMessages,
|
|
457
443
|
setAiProviderRegistry,
|
|
458
|
-
|
|
444
|
+
resolveAiProviderGpuQueueConcurrency,
|
|
459
445
|
getAiProviderRegistry,
|
|
460
|
-
filterValidToolCalls,
|
|
461
|
-
buildToolDescription,
|
|
462
446
|
ModelRecordSchema,
|
|
463
447
|
ModelPrimaryKeyNames,
|
|
464
448
|
ModelConfigSchema,
|
|
@@ -466,4 +450,4 @@ export {
|
|
|
466
450
|
AiProvider
|
|
467
451
|
};
|
|
468
452
|
|
|
469
|
-
//# debugId=
|
|
453
|
+
//# debugId=CC1A9EA8E04D69B564756E2164756E21
|