@huggingface/inference 3.7.0 → 3.7.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.cjs +1152 -839
- package/dist/index.js +1154 -841
- package/dist/src/lib/getProviderHelper.d.ts +37 -0
- package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
- package/dist/src/lib/makeRequestOptions.d.ts +0 -2
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/black-forest-labs.d.ts +14 -18
- package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
- package/dist/src/providers/cerebras.d.ts +4 -2
- package/dist/src/providers/cerebras.d.ts.map +1 -1
- package/dist/src/providers/cohere.d.ts +5 -2
- package/dist/src/providers/cohere.d.ts.map +1 -1
- package/dist/src/providers/fal-ai.d.ts +50 -3
- package/dist/src/providers/fal-ai.d.ts.map +1 -1
- package/dist/src/providers/fireworks-ai.d.ts +5 -2
- package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
- package/dist/src/providers/hf-inference.d.ts +125 -2
- package/dist/src/providers/hf-inference.d.ts.map +1 -1
- package/dist/src/providers/hyperbolic.d.ts +31 -2
- package/dist/src/providers/hyperbolic.d.ts.map +1 -1
- package/dist/src/providers/nebius.d.ts +20 -18
- package/dist/src/providers/nebius.d.ts.map +1 -1
- package/dist/src/providers/novita.d.ts +21 -18
- package/dist/src/providers/novita.d.ts.map +1 -1
- package/dist/src/providers/openai.d.ts +4 -2
- package/dist/src/providers/openai.d.ts.map +1 -1
- package/dist/src/providers/providerHelper.d.ts +182 -0
- package/dist/src/providers/providerHelper.d.ts.map +1 -0
- package/dist/src/providers/replicate.d.ts +23 -19
- package/dist/src/providers/replicate.d.ts.map +1 -1
- package/dist/src/providers/sambanova.d.ts +4 -2
- package/dist/src/providers/sambanova.d.ts.map +1 -1
- package/dist/src/providers/together.d.ts +32 -2
- package/dist/src/providers/together.d.ts.map +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/audio/utils.d.ts +2 -1
- package/dist/src/tasks/audio/utils.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts +0 -2
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts +0 -2
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/index.d.ts +6 -6
- package/dist/src/tasks/index.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
- package/dist/src/types.d.ts +3 -13
- package/dist/src/types.d.ts.map +1 -1
- package/package.json +3 -3
- package/src/lib/getProviderHelper.ts +270 -0
- package/src/lib/makeRequestOptions.ts +34 -91
- package/src/providers/black-forest-labs.ts +73 -22
- package/src/providers/cerebras.ts +6 -27
- package/src/providers/cohere.ts +9 -28
- package/src/providers/fal-ai.ts +195 -77
- package/src/providers/fireworks-ai.ts +8 -29
- package/src/providers/hf-inference.ts +555 -34
- package/src/providers/hyperbolic.ts +107 -29
- package/src/providers/nebius.ts +65 -29
- package/src/providers/novita.ts +68 -32
- package/src/providers/openai.ts +6 -32
- package/src/providers/providerHelper.ts +354 -0
- package/src/providers/replicate.ts +124 -34
- package/src/providers/sambanova.ts +5 -30
- package/src/providers/together.ts +92 -28
- package/src/snippets/getInferenceSnippets.ts +16 -9
- package/src/snippets/templates.exported.ts +1 -1
- package/src/tasks/audio/audioClassification.ts +4 -7
- package/src/tasks/audio/audioToAudio.ts +3 -26
- package/src/tasks/audio/automaticSpeechRecognition.ts +4 -3
- package/src/tasks/audio/textToSpeech.ts +5 -29
- package/src/tasks/audio/utils.ts +2 -1
- package/src/tasks/custom/request.ts +0 -2
- package/src/tasks/custom/streamingRequest.ts +0 -2
- package/src/tasks/cv/imageClassification.ts +3 -7
- package/src/tasks/cv/imageSegmentation.ts +3 -8
- package/src/tasks/cv/imageToImage.ts +3 -6
- package/src/tasks/cv/imageToText.ts +3 -6
- package/src/tasks/cv/objectDetection.ts +3 -18
- package/src/tasks/cv/textToImage.ts +9 -137
- package/src/tasks/cv/textToVideo.ts +11 -62
- package/src/tasks/cv/zeroShotImageClassification.ts +3 -7
- package/src/tasks/index.ts +6 -6
- package/src/tasks/multimodal/documentQuestionAnswering.ts +3 -19
- package/src/tasks/multimodal/visualQuestionAnswering.ts +3 -11
- package/src/tasks/nlp/chatCompletion.ts +5 -20
- package/src/tasks/nlp/chatCompletionStream.ts +1 -2
- package/src/tasks/nlp/featureExtraction.ts +3 -18
- package/src/tasks/nlp/fillMask.ts +3 -16
- package/src/tasks/nlp/questionAnswering.ts +3 -22
- package/src/tasks/nlp/sentenceSimilarity.ts +3 -7
- package/src/tasks/nlp/summarization.ts +3 -6
- package/src/tasks/nlp/tableQuestionAnswering.ts +3 -27
- package/src/tasks/nlp/textClassification.ts +3 -8
- package/src/tasks/nlp/textGeneration.ts +12 -79
- package/src/tasks/nlp/tokenClassification.ts +3 -18
- package/src/tasks/nlp/translation.ts +3 -6
- package/src/tasks/nlp/zeroShotClassification.ts +3 -16
- package/src/tasks/tabular/tabularClassification.ts +3 -6
- package/src/tasks/tabular/tabularRegression.ts +3 -6
- package/src/types.ts +3 -14
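The headline change in this release is a refactor of the provider layer: the per-provider config objects of 3.7.0 (`makeBaseUrl`/`makeBody`/`makeHeaders`/`makeUrl` functions collected into `FAL_AI_CONFIG`-style records) are replaced by task helper classes extending a shared `TaskProviderHelper` base, resolved through the new `src/lib/getProviderHelper.ts`. Below is a minimal sketch of that pattern with simplified types; the shipped code in the diff that follows carries richer `params` objects and more methods, and `example-provider` is a hypothetical name used only for illustration.

```ts
// Sketch of the TaskProviderHelper pattern introduced in 3.7.1 (types
// simplified; see src/providers/providerHelper.ts in the diff below for
// the real shape).
const HF_ROUTER_URL = "https://router.huggingface.co";

interface HelperParams {
  authMethod: "provider-key" | "hf-token";
  accessToken: string;
  model: string;
  args: Record<string, unknown>;
}

abstract class TaskProviderHelper {
  constructor(
    readonly provider: string,
    readonly baseUrl: string,
    readonly clientSideRoutingOnly = false
  ) {}

  // Calls authenticated with a Hugging Face token are routed through
  // router.huggingface.co; direct provider keys hit the provider's own API.
  makeBaseUrl(params: HelperParams): string {
    return params.authMethod !== "provider-key"
      ? `${HF_ROUTER_URL}/${this.provider}`
      : this.baseUrl;
  }

  makeUrl(params: HelperParams): string {
    const route = this.makeRoute(params).replace(/^\/+/, "");
    return `${this.makeBaseUrl(params)}/${route}`;
  }

  // Each provider/task pair only has to say where requests go and what
  // the payload looks like.
  abstract makeRoute(params: HelperParams): string;
  abstract preparePayload(params: HelperParams): unknown;
}

class ExampleConversationalTask extends TaskProviderHelper {
  constructor() {
    // "example-provider" is hypothetical; the real subclasses pass
    // "cerebras", "cohere", "together", etc.
    super("example-provider", "https://api.example.com");
  }
  makeRoute(): string {
    return "v1/chat/completions";
  }
  preparePayload(params: HelperParams): unknown {
    return { ...params.args, model: params.model };
  }
}
```

This is why most per-provider diffs below shrink to a constructor plus one or two overrides.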
package/dist/index.cjs
CHANGED
@@ -98,442 +98,1088 @@ __export(tasks_exports, {
   zeroShotImageClassification: () => zeroShotImageClassification
 });
 
+// package.json
+var name = "@huggingface/inference";
+var version = "3.7.1";
+
 // src/config.ts
 var HF_HUB_URL = "https://huggingface.co";
 var HF_ROUTER_URL = "https://router.huggingface.co";
 var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
 
-// src/
-var
-
-
+// src/lib/InferenceOutputError.ts
+var InferenceOutputError = class extends TypeError {
+  constructor(message) {
+    super(
+      `Invalid inference output: ${message}. Use the 'request' method with the same parameters to do a custom call with no type checking.`
+    );
+    this.name = "InferenceOutputError";
+  }
 };
-
-
+
+// src/utils/delay.ts
+function delay(ms) {
+  return new Promise((resolve) => {
+    setTimeout(() => resolve(), ms);
+  });
+}
+
+// src/utils/pick.ts
+function pick(o, props) {
+  return Object.assign(
+    {},
+    ...props.map((prop) => {
+      if (o[prop] !== void 0) {
+        return { [prop]: o[prop] };
+      }
+    })
+  );
+}
+
+// src/utils/typedInclude.ts
+function typedInclude(arr, v) {
+  return arr.includes(v);
+}
+
+// src/utils/omit.ts
+function omit(o, props) {
+  const propsArr = Array.isArray(props) ? props : [props];
+  const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
+  return pick(o, letsKeep);
+}
+
+// src/utils/toArray.ts
+function toArray(obj) {
+  if (Array.isArray(obj)) {
+    return obj;
+  }
+  return [obj];
+}
+
+// src/providers/providerHelper.ts
+var TaskProviderHelper = class {
+  constructor(provider, baseUrl, clientSideRoutingOnly = false) {
+    this.provider = provider;
+    this.baseUrl = baseUrl;
+    this.clientSideRoutingOnly = clientSideRoutingOnly;
+  }
+  /**
+   * Prepare the base URL for the request
+   */
+  makeBaseUrl(params) {
+    return params.authMethod !== "provider-key" ? `${HF_ROUTER_URL}/${this.provider}` : this.baseUrl;
+  }
+  /**
+   * Prepare the body for the request
+   */
+  makeBody(params) {
+    if ("data" in params.args && !!params.args.data) {
+      return params.args.data;
+    }
+    return JSON.stringify(this.preparePayload(params));
+  }
+  /**
+   * Prepare the URL for the request
+   */
+  makeUrl(params) {
+    const baseUrl = this.makeBaseUrl(params);
+    const route = this.makeRoute(params).replace(/^\/+/, "");
+    return `${baseUrl}/${route}`;
+  }
+  /**
+   * Prepare the headers for the request
+   */
+  prepareHeaders(params, isBinary) {
+    const headers = { Authorization: `Bearer ${params.accessToken}` };
+    if (!isBinary) {
+      headers["Content-Type"] = "application/json";
+    }
+    return headers;
+  }
 };
-var
-
-
-}
-
+var BaseConversationalTask = class extends TaskProviderHelper {
+  constructor(provider, baseUrl, clientSideRoutingOnly = false) {
+    super(provider, baseUrl, clientSideRoutingOnly);
+  }
+  makeRoute() {
+    return "v1/chat/completions";
+  }
+  preparePayload(params) {
+    return {
+      ...params.args,
+      model: params.model
+    };
+  }
+  async getResponse(response) {
+    if (typeof response === "object" && Array.isArray(response?.choices) && typeof response?.created === "number" && typeof response?.id === "string" && typeof response?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
+    (response.system_fingerprint === void 0 || response.system_fingerprint === null || typeof response.system_fingerprint === "string") && typeof response?.usage === "object") {
+      return response;
+    }
+    throw new InferenceOutputError("Expected ChatCompletionOutput");
   }
 };
-var
-
+var BaseTextGenerationTask = class extends TaskProviderHelper {
+  constructor(provider, baseUrl, clientSideRoutingOnly = false) {
+    super(provider, baseUrl, clientSideRoutingOnly);
+  }
+  preparePayload(params) {
+    return {
+      ...params.args,
+      model: params.model
+    };
+  }
+  makeRoute() {
+    return "v1/completions";
+  }
+  async getResponse(response) {
+    const res = toArray(response);
+    if (Array.isArray(res) && res.length > 0 && res.every(
+      (x) => typeof x === "object" && !!x && "generated_text" in x && typeof x.generated_text === "string"
+    )) {
+      return res[0];
+    }
+    throw new InferenceOutputError("Expected Array<{generated_text: string}>");
+  }
 };
-
-
-
-
-
+
+// src/providers/black-forest-labs.ts
+var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
+var BlackForestLabsTextToImageTask = class extends TaskProviderHelper {
+  constructor() {
+    super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
+  }
+  preparePayload(params) {
+    return {
+      ...omit(params.args, ["inputs", "parameters"]),
+      ...params.args.parameters,
+      prompt: params.args.inputs
+    };
+  }
+  prepareHeaders(params, binary) {
+    const headers = {
+      Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`
+    };
+    if (!binary) {
+      headers["Content-Type"] = "application/json";
+    }
+    return headers;
+  }
+  makeRoute(params) {
+    if (!params) {
+      throw new Error("Params are required");
+    }
+    return `/v1/${params.model}`;
+  }
+  async getResponse(response, url, headers, outputType) {
+    const urlObj = new URL(response.polling_url);
+    for (let step = 0; step < 5; step++) {
+      await delay(1e3);
+      console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
+      urlObj.searchParams.set("attempt", step.toString(10));
+      const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
+      if (!resp.ok) {
+        throw new InferenceOutputError("Failed to fetch result from black forest labs API");
+      }
+      const payload = await resp.json();
+      if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
+        if (outputType === "url") {
+          return payload.result.sample;
+        }
+        const image = await fetch(payload.result.sample);
+        return await image.blob();
+      }
+    }
+    throw new InferenceOutputError("Failed to fetch result from black forest labs API");
+  }
 };
 
 // src/providers/cerebras.ts
-var
-
-
+var CerebrasConversationalTask = class extends BaseConversationalTask {
+  constructor() {
+    super("cerebras", "https://api.cerebras.ai");
+  }
 };
-
-
-
-
-
+
+// src/providers/cohere.ts
+var CohereConversationalTask = class extends BaseConversationalTask {
+  constructor() {
+    super("cohere", "https://api.cohere.com");
+  }
+  makeRoute() {
+    return "/compatibility/v1/chat/completions";
+  }
+};
+
+// src/lib/isUrl.ts
+function isUrl(modelOrUrl) {
+  return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
+}
+
+// src/providers/fal-ai.ts
+var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
+var FalAITask = class extends TaskProviderHelper {
+  constructor(url) {
+    super("fal-ai", url || "https://fal.run");
+  }
+  preparePayload(params) {
+    return params.args;
+  }
+  makeRoute(params) {
+    return `/${params.model}`;
+  }
+  prepareHeaders(params, binary) {
+    const headers = {
+      Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}`
+    };
+    if (!binary) {
+      headers["Content-Type"] = "application/json";
+    }
+    return headers;
+  }
+};
+var FalAITextToImageTask = class extends FalAITask {
+  preparePayload(params) {
+    return {
+      ...omit(params.args, ["inputs", "parameters"]),
+      ...params.args.parameters,
+      sync_mode: true,
+      prompt: params.args.inputs
+    };
+  }
+  async getResponse(response, outputType) {
+    if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images.length > 0 && "url" in response.images[0] && typeof response.images[0].url === "string") {
+      if (outputType === "url") {
+        return response.images[0].url;
+      }
+      const urlResponse = await fetch(response.images[0].url);
+      return await urlResponse.blob();
+    }
+    throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
+  }
+};
+var FalAITextToVideoTask = class extends FalAITask {
+  constructor() {
+    super("https://queue.fal.run");
+  }
+  makeRoute(params) {
+    if (params.authMethod !== "provider-key") {
+      return `/${params.model}?_subdomain=queue`;
+    }
+    return `/${params.model}`;
+  }
+  preparePayload(params) {
+    return {
+      ...omit(params.args, ["inputs", "parameters"]),
+      ...params.args.parameters,
+      prompt: params.args.inputs
+    };
+  }
+  async getResponse(response, url, headers) {
+    if (!url || !headers) {
+      throw new InferenceOutputError("URL and headers are required for text-to-video task");
+    }
+    const requestId = response.request_id;
+    if (!requestId) {
+      throw new InferenceOutputError("No request ID found in the response");
+    }
+    let status = response.status;
+    const parsedUrl = new URL(url);
+    const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
+    const modelId = new URL(response.response_url).pathname;
+    const queryParams = parsedUrl.search;
+    const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
+    const resultUrl = `${baseUrl}${modelId}${queryParams}`;
+    while (status !== "COMPLETED") {
+      await delay(500);
+      const statusResponse = await fetch(statusUrl, { headers });
+      if (!statusResponse.ok) {
+        throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
+      }
+      try {
+        status = (await statusResponse.json()).status;
+      } catch (error) {
+        throw new InferenceOutputError("Failed to parse status response from fal-ai API");
+      }
+    }
+    const resultResponse = await fetch(resultUrl, { headers });
+    let result;
+    try {
+      result = await resultResponse.json();
+    } catch (error) {
+      throw new InferenceOutputError("Failed to parse result response from fal-ai API");
+    }
+    if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
+      const urlResponse = await fetch(result.video.url);
+      return await urlResponse.blob();
+    } else {
+      throw new InferenceOutputError(
+        "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
+      );
+    }
+  }
 };
-var
-
+var FalAIAutomaticSpeechRecognitionTask = class extends FalAITask {
+  prepareHeaders(params, binary) {
+    const headers = super.prepareHeaders(params, binary);
+    headers["Content-Type"] = "application/json";
+    return headers;
+  }
+  async getResponse(response) {
+    const res = response;
+    if (typeof res?.text !== "string") {
+      throw new InferenceOutputError(
+        `Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
+      );
+    }
+    return { text: res.text };
+  }
 };
-var
-
+var FalAITextToSpeechTask = class extends FalAITask {
+  preparePayload(params) {
+    return {
+      ...omit(params.args, ["inputs", "parameters"]),
+      ...params.args.parameters,
+      lyrics: params.args.inputs
+    };
+  }
+  async getResponse(response) {
+    const res = response;
+    if (typeof res?.audio?.url !== "string") {
+      throw new InferenceOutputError(
+        `Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
+      );
+    }
+    try {
+      const urlResponse = await fetch(res.audio.url);
+      if (!urlResponse.ok) {
+        throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`);
+      }
+      return await urlResponse.blob();
+    } catch (error) {
+      throw new InferenceOutputError(
+        `Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${error instanceof Error ? error.message : String(error)}`
+      );
+    }
+  }
 };
-
-
-
-
-
+
+// src/providers/fireworks-ai.ts
+var FireworksConversationalTask = class extends BaseConversationalTask {
+  constructor() {
+    super("fireworks-ai", "https://api.fireworks.ai");
+  }
+  makeRoute() {
+    return "/inference/v1/chat/completions";
+  }
 };
 
-// src/providers/
-var
-
-
+// src/providers/hf-inference.ts
+var HFInferenceTask = class extends TaskProviderHelper {
+  constructor() {
+    super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
+  }
+  preparePayload(params) {
+    return params.args;
+  }
+  makeUrl(params) {
+    if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
+      return params.model;
+    }
+    return super.makeUrl(params);
+  }
+  makeRoute(params) {
+    if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
+      return `pipeline/${params.task}/${params.model}`;
+    }
+    return `models/${params.model}`;
+  }
+  async getResponse(response) {
+    return response;
+  }
 };
-var
-
-
-
-
+var HFInferenceTextToImageTask = class extends HFInferenceTask {
+  async getResponse(response, url, headers, outputType) {
+    if (!response) {
+      throw new InferenceOutputError("response is undefined");
+    }
+    if (typeof response == "object") {
+      if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
+        const base64Data = response.data[0].b64_json;
+        if (outputType === "url") {
+          return `data:image/jpeg;base64,${base64Data}`;
+        }
+        const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
+        return await base64Response.blob();
+      }
+      if ("output" in response && Array.isArray(response.output)) {
+        if (outputType === "url") {
+          return response.output[0];
+        }
+        const urlResponse = await fetch(response.output[0]);
+        const blob = await urlResponse.blob();
+        return blob;
+      }
+    }
+    if (response instanceof Blob) {
+      if (outputType === "url") {
+        const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
+        return `data:image/jpeg;base64,${b64}`;
+      }
+      return response;
+    }
+    throw new InferenceOutputError("Expected a Blob ");
+  }
+};
+var HFInferenceConversationalTask = class extends HFInferenceTask {
+  makeUrl(params) {
+    let url;
+    if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
+      url = params.model.trim();
+    } else {
+      url = `${this.makeBaseUrl(params)}/models/${params.model}`;
+    }
+    url = url.replace(/\/+$/, "");
+    if (url.endsWith("/v1")) {
+      url += "/chat/completions";
+    } else if (!url.endsWith("/chat/completions")) {
+      url += "/v1/chat/completions";
+    }
+    return url;
+  }
+  preparePayload(params) {
+    return {
+      ...params.args,
+      model: params.model
+    };
+  }
+  async getResponse(response) {
+    return response;
+  }
 };
-var
-
+var HFInferenceTextGenerationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    const res = toArray(response);
+    if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) {
+      return res?.[0];
+    }
+    throw new InferenceOutputError("Expected Array<{generated_text: string}>");
+  }
 };
-var
-
+var HFInferenceAudioClassificationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every(
+      (x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
+    )) {
+      return response;
+    }
+    throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
+  }
 };
-var
-
-
-
-  makeUrl: makeUrl3
+var HFInferenceAutomaticSpeechRecognitionTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    return response;
+  }
 };
-
-
-
-
-
-
+var HFInferenceAudioToAudioTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (!Array.isArray(response)) {
+      throw new InferenceOutputError("Expected Array");
+    }
+    if (!response.every((elem) => {
+      return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
+    })) {
+      throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
+    }
+    return response;
+  }
+};
+var HFInferenceDocumentQuestionAnsweringTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every(
+      (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
+    )) {
+      return response[0];
+    }
+    throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
+  }
+};
+var HFInferenceFeatureExtractionTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
+      if (curDepth > maxDepth)
+        return false;
+      if (arr.every((x) => Array.isArray(x))) {
+        return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
+      } else {
+        return arr.every((x) => typeof x === "number");
+      }
+    };
+    if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) {
+      return response;
+    }
+    throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
+  }
+};
+var HFInferenceImageClassificationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
+      return response;
+    }
+    throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
+  }
+};
+var HFInferenceImageSegmentationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number")) {
+      return response;
+    }
+    throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
+  }
+};
+var HFInferenceImageToTextTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (typeof response?.generated_text !== "string") {
+      throw new InferenceOutputError("Expected {generated_text: string}");
+    }
+    return response;
+  }
+};
+var HFInferenceImageToImageTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (response instanceof Blob) {
+      return response;
+    }
+    throw new InferenceOutputError("Expected Blob");
+  }
+};
+var HFInferenceObjectDetectionTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every(
+      (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
+    )) {
+      return response;
+    }
+    throw new InferenceOutputError(
+      "Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
     );
-    this.name = "InferenceOutputError";
   }
 };
-
-
-
-
-}
-
-
-
-
-
-
-
-
-
-
-
-var makeBaseUrl4 = (task) => {
-  return task === "text-to-video" ? FAL_AI_API_BASE_URL_QUEUE : FAL_AI_API_BASE_URL;
+var HFInferenceZeroShotImageClassificationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
+      return response;
+    }
+    throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
+  }
+};
+var HFInferenceTextClassificationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    const output = response?.[0];
+    if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
+      return output;
+    }
+    throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
+  }
 };
-var
-
+var HFInferenceQuestionAnsweringTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) ? response.every(
+      (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
+    ) : typeof response === "object" && !!response && typeof response.answer === "string" && typeof response.end === "number" && typeof response.score === "number" && typeof response.start === "number") {
+      return Array.isArray(response) ? response[0] : response;
+    }
+    throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
+  }
 };
-var
-
-
-
+var HFInferenceFillMaskTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every(
+      (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
+    )) {
+      return response;
+    }
+    throw new InferenceOutputError(
+      "Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
+    );
+  }
 };
-var
-
-
-
-
-
-};
-var FAL_AI_CONFIG = {
-  makeBaseUrl: makeBaseUrl4,
-  makeBody: makeBody4,
-  makeHeaders: makeHeaders4,
-  makeUrl: makeUrl4
-};
-async function pollFalResponse(res, url, headers) {
-  const requestId = res.request_id;
-  if (!requestId) {
-    throw new InferenceOutputError("No request ID found in the response");
-  }
-  let status = res.status;
-  const parsedUrl = new URL(url);
-  const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
-  const modelId = new URL(res.response_url).pathname;
-  const queryParams = parsedUrl.search;
-  const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
-  const resultUrl = `${baseUrl}${modelId}${queryParams}`;
-  while (status !== "COMPLETED") {
-    await delay(500);
-    const statusResponse = await fetch(statusUrl, { headers });
-    if (!statusResponse.ok) {
-      throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
+var HFInferenceZeroShotClassificationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every(
+      (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
+    )) {
+      return response;
     }
-
-
-
-
+    throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
+  }
+};
+var HFInferenceSentenceSimilarityTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
+      return response;
    }
+    throw new InferenceOutputError("Expected Array<number>");
  }
-
-
-
-
-
-
+};
+var HFInferenceTableQuestionAnsweringTask = class extends HFInferenceTask {
+  static validate(elem) {
+    return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
+      (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
+    );
   }
-
-
-
-
+  async getResponse(response) {
+    if (Array.isArray(response) && Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response)) {
+      return Array.isArray(response) ? response[0] : response;
+    }
     throw new InferenceOutputError(
-      "Expected {
+      "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
     );
   }
-}
-
-// src/providers/fireworks-ai.ts
-var FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai";
-var makeBaseUrl5 = () => {
-  return FIREWORKS_AI_API_BASE_URL;
-};
-var makeBody5 = (params) => {
-  return {
-    ...params.args,
-    ...params.chatCompletion ? { model: params.model } : void 0
-  };
 };
-var
-
-
-
-
-
+var HFInferenceTokenClassificationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every(
+      (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
+    )) {
+      return response;
+    }
+    throw new InferenceOutputError(
+      "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
+    );
   }
-  return `${params.baseUrl}/inference`;
 };
-var
-
-
-
-
+var HFInferenceTranslationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) {
+      return response?.length === 1 ? response?.[0] : response;
+    }
+    throw new InferenceOutputError("Expected Array<{translation_text: string}>");
+  }
 };
-
-
-
-
+var HFInferenceSummarizationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) {
+      return response?.[0];
+    }
+    throw new InferenceOutputError("Expected Array<{summary_text: string}>");
+  }
 };
-var
-
-
-
-};
+var HFInferenceTextToSpeechTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    return response;
+  }
 };
-var
-
+var HFInferenceTabularClassificationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
+      return response;
+    }
+    throw new InferenceOutputError("Expected Array<number>");
+  }
 };
-var
-
-
+var HFInferenceVisualQuestionAnsweringTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every(
+      (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
+    )) {
+      return response[0];
+    }
+    throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
   }
-
-
+};
+var HFInferenceTabularRegressionTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
+      return response;
+    }
+    throw new InferenceOutputError("Expected Array<number>");
  }
-  return `${params.baseUrl}/models/${params.model}`;
 };
-var
-
-
-
-  makeUrl: makeUrl6
+var HFInferenceTextToAudioTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    return response;
+  }
 };
 
 // src/providers/hyperbolic.ts
 var HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
-var
-
-
-
-  return {
-    ...params.args,
-    ...params.task === "text-to-image" ? { model_name: params.model } : { model: params.model }
-  };
-};
-var makeHeaders7 = (params) => {
-  return { Authorization: `Bearer ${params.accessToken}` };
+var HyperbolicConversationalTask = class extends BaseConversationalTask {
+  constructor() {
+    super("hyperbolic", HYPERBOLIC_API_BASE_URL);
+  }
 };
-var
-
-
+var HyperbolicTextGenerationTask = class extends BaseTextGenerationTask {
+  constructor() {
+    super("hyperbolic", HYPERBOLIC_API_BASE_URL);
+  }
+  makeRoute() {
+    return "v1/chat/completions";
+  }
+  preparePayload(params) {
+    return {
+      messages: [{ content: params.args.inputs, role: "user" }],
+      ...params.args.parameters ? {
+        max_tokens: params.args.parameters.max_new_tokens,
+        ...omit(params.args.parameters, "max_new_tokens")
+      } : void 0,
+      ...omit(params.args, ["inputs", "parameters"]),
+      model: params.model
+    };
+  }
+  async getResponse(response) {
+    if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
+      const completion = response.choices[0];
+      return {
+        generated_text: completion.message.content
+      };
+    }
+    throw new InferenceOutputError("Expected Hyperbolic text generation response format");
   }
-  return `${params.baseUrl}/v1/chat/completions`;
 };
-var
-
-
-
-
+var HyperbolicTextToImageTask = class extends TaskProviderHelper {
+  constructor() {
+    super("hyperbolic", HYPERBOLIC_API_BASE_URL);
+  }
+  makeRoute(params) {
+    return `/v1/images/generations`;
+  }
+  preparePayload(params) {
+    return {
+      ...omit(params.args, ["inputs", "parameters"]),
+      ...params.args.parameters,
+      prompt: params.args.inputs,
+      model_name: params.model
+    };
+  }
+  async getResponse(response, url, headers, outputType) {
+    if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images[0] && typeof response.images[0].image === "string") {
+      if (outputType === "url") {
+        return `data:image/jpeg;base64,${response.images[0].image}`;
+      }
+      return fetch(`data:image/jpeg;base64,${response.images[0].image}`).then((res) => res.blob());
+    }
+    throw new InferenceOutputError("Expected Hyperbolic text-to-image response format");
+  }
 };
 
 // src/providers/nebius.ts
 var NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
-var
-
-
-
-  return {
-    ...params.args,
-    model: params.model
-  };
+var NebiusConversationalTask = class extends BaseConversationalTask {
+  constructor() {
+    super("nebius", NEBIUS_API_BASE_URL);
+  }
 };
-var
-
+var NebiusTextGenerationTask = class extends BaseTextGenerationTask {
+  constructor() {
+    super("nebius", NEBIUS_API_BASE_URL);
+  }
 };
-var
-
-
+var NebiusTextToImageTask = class extends TaskProviderHelper {
+  constructor() {
+    super("nebius", NEBIUS_API_BASE_URL);
+  }
+  preparePayload(params) {
+    return {
+      ...omit(params.args, ["inputs", "parameters"]),
+      ...params.args.parameters,
+      response_format: "b64_json",
+      prompt: params.args.inputs,
+      model: params.model
+    };
   }
-
-  return
+  makeRoute(params) {
+    return "v1/images/generations";
   }
-
-
+  async getResponse(response, url, headers, outputType) {
+    if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
+      const base64Data = response.data[0].b64_json;
+      if (outputType === "url") {
+        return `data:image/jpeg;base64,${base64Data}`;
+      }
+      return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
+    }
+    throw new InferenceOutputError("Expected Nebius text-to-image response format");
  }
-  return params.baseUrl;
-};
-var NEBIUS_CONFIG = {
-  makeBaseUrl: makeBaseUrl8,
-  makeBody: makeBody8,
-  makeHeaders: makeHeaders8,
-  makeUrl: makeUrl8
 };
 
 // src/providers/novita.ts
var NOVITA_API_BASE_URL = "https://api.novita.ai";
-var
-
-
-
-
-
-
-  };
-};
-var makeHeaders9 = (params) => {
-  return { Authorization: `Bearer ${params.accessToken}` };
+var NovitaTextGenerationTask = class extends BaseTextGenerationTask {
+  constructor() {
+    super("novita", NOVITA_API_BASE_URL);
+  }
+  makeRoute() {
+    return "/v3/openai/chat/completions";
+  }
 };
-var
-
-
-}
-
-
-  return `${params.baseUrl}/v3/hf/${params.model}`;
+var NovitaConversationalTask = class extends BaseConversationalTask {
+  constructor() {
+    super("novita", NOVITA_API_BASE_URL);
+  }
+  makeRoute() {
+    return "/v3/openai/chat/completions";
  }
-  return params.baseUrl;
 };
-
-
-
-
-
+
+// src/providers/openai.ts
+var OPENAI_API_BASE_URL = "https://api.openai.com";
+var OpenAIConversationalTask = class extends BaseConversationalTask {
+  constructor() {
+    super("openai", OPENAI_API_BASE_URL, true);
+  }
 };
 
 // src/providers/replicate.ts
-var
-
-
-}
-
-
-
-
-
+var ReplicateTask = class extends TaskProviderHelper {
+  constructor(url) {
+    super("replicate", url || "https://api.replicate.com");
+  }
+  makeRoute(params) {
+    if (params.model.includes(":")) {
+      return "v1/predictions";
+    }
+    return `v1/models/${params.model}/predictions`;
+  }
+  preparePayload(params) {
+    return {
+      input: {
+        ...omit(params.args, ["inputs", "parameters"]),
+        ...params.args.parameters,
+        prompt: params.args.inputs
+      },
+      version: params.model.includes(":") ? params.model.split(":")[1] : void 0
+    };
+  }
+  prepareHeaders(params, binary) {
+    const headers = { Authorization: `Bearer ${params.accessToken}`, Prefer: "wait" };
+    if (!binary) {
+      headers["Content-Type"] = "application/json";
+    }
+    return headers;
+  }
+  makeUrl(params) {
+    const baseUrl = this.makeBaseUrl(params);
+    if (params.model.includes(":")) {
+      return `${baseUrl}/v1/predictions`;
+    }
+    return `${baseUrl}/v1/models/${params.model}/predictions`;
+  }
 };
-var
-
+var ReplicateTextToImageTask = class extends ReplicateTask {
+  async getResponse(res, url, headers, outputType) {
+    if (typeof res === "object" && "output" in res && Array.isArray(res.output) && res.output.length > 0 && typeof res.output[0] === "string") {
+      if (outputType === "url") {
+        return res.output[0];
+      }
+      const urlResponse = await fetch(res.output[0]);
+      return await urlResponse.blob();
+    }
+    throw new InferenceOutputError("Expected Replicate text-to-image response format");
+  }
 };
-var
-
-
+var ReplicateTextToSpeechTask = class extends ReplicateTask {
+  preparePayload(params) {
+    const payload = super.preparePayload(params);
+    const input = payload["input"];
+    if (typeof input === "object" && input !== null && "prompt" in input) {
+      const inputObj = input;
+      inputObj["text"] = inputObj["prompt"];
+      delete inputObj["prompt"];
+    }
+    return payload;
+  }
+  async getResponse(response) {
+    if (response instanceof Blob) {
+      return response;
+    }
+    if (response && typeof response === "object") {
+      if ("output" in response) {
+        if (typeof response.output === "string") {
+          const urlResponse = await fetch(response.output);
+          return await urlResponse.blob();
+        } else if (Array.isArray(response.output)) {
+          const urlResponse = await fetch(response.output[0]);
+          return await urlResponse.blob();
+        }
+      }
+    }
+    throw new InferenceOutputError("Expected Blob or object with output");
  }
-  return `${params.baseUrl}/v1/models/${params.model}/predictions`;
 };
-var
-
-
-
-
+var ReplicateTextToVideoTask = class extends ReplicateTask {
+  async getResponse(response) {
+    if (typeof response === "object" && !!response && "output" in response && typeof response.output === "string" && isUrl(response.output)) {
+      const urlResponse = await fetch(response.output);
+      return await urlResponse.blob();
+    }
+    throw new InferenceOutputError("Expected { output: string }");
+  }
 };
 
 // src/providers/sambanova.ts
-var
-
-
-};
-var makeBody11 = (params) => {
-  return {
-    ...params.args,
-    ...params.chatCompletion ? { model: params.model } : void 0
-  };
-};
-var makeHeaders11 = (params) => {
-  return { Authorization: `Bearer ${params.accessToken}` };
-};
-var makeUrl11 = (params) => {
-  if (params.chatCompletion) {
-    return `${params.baseUrl}/v1/chat/completions`;
+var SambanovaConversationalTask = class extends BaseConversationalTask {
+  constructor() {
+    super("sambanova", "https://api.sambanova.ai");
  }
-  return params.baseUrl;
-};
-var SAMBANOVA_CONFIG = {
-  makeBaseUrl: makeBaseUrl11,
-  makeBody: makeBody11,
-  makeHeaders: makeHeaders11,
-  makeUrl: makeUrl11
 };
 
 // src/providers/together.ts
 var TOGETHER_API_BASE_URL = "https://api.together.xyz";
-var
-
-
-
-  return {
-    ...params.args,
-    model: params.model
-  };
-};
-var makeHeaders12 = (params) => {
-  return { Authorization: `Bearer ${params.accessToken}` };
+var TogetherConversationalTask = class extends BaseConversationalTask {
+  constructor() {
+    super("together", TOGETHER_API_BASE_URL);
+  }
 };
-var
-
-
+var TogetherTextGenerationTask = class extends BaseTextGenerationTask {
+  constructor() {
+    super("together", TOGETHER_API_BASE_URL);
  }
-
-  return
+  preparePayload(params) {
+    return {
+      model: params.model,
+      ...params.args,
+      prompt: params.args.inputs
+    };
  }
-
-
+  async getResponse(response) {
+    if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
+      const completion = response.choices[0];
+      return {
+        generated_text: completion.text
+      };
+    }
+    throw new InferenceOutputError("Expected Together text generation response format");
  }
-  return params.baseUrl;
 };
-var
-
-
-
-
+var TogetherTextToImageTask = class extends TaskProviderHelper {
+  constructor() {
+    super("together", TOGETHER_API_BASE_URL);
+  }
+  makeRoute() {
+    return "v1/images/generations";
+  }
+  preparePayload(params) {
+    return {
+      ...omit(params.args, ["inputs", "parameters"]),
+      ...params.args.parameters,
+      prompt: params.args.inputs,
+      response_format: "base64",
+      model: params.model
+    };
+  }
+  async getResponse(response, outputType) {
+    if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
+      const base64Data = response.data[0].b64_json;
+      if (outputType === "url") {
+        return `data:image/jpeg;base64,${base64Data}`;
+      }
+      return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
+    }
+    throw new InferenceOutputError("Expected Together text-to-image response format");
+  }
 };
 
-// src/
-var
-
-
-}
-
-
-
+// src/lib/getProviderHelper.ts
+var PROVIDERS = {
+  "black-forest-labs": {
+    "text-to-image": new BlackForestLabsTextToImageTask()
+  },
+  cerebras: {
+    conversational: new CerebrasConversationalTask()
+  },
+  cohere: {
+    conversational: new CohereConversationalTask()
+  },
+  "fal-ai": {
+    "text-to-image": new FalAITextToImageTask(),
+    "text-to-speech": new FalAITextToSpeechTask(),
+    "text-to-video": new FalAITextToVideoTask(),
+    "automatic-speech-recognition": new FalAIAutomaticSpeechRecognitionTask()
+  },
+  "hf-inference": {
+    "text-to-image": new HFInferenceTextToImageTask(),
+    conversational: new HFInferenceConversationalTask(),
+    "text-generation": new HFInferenceTextGenerationTask(),
+    "text-classification": new HFInferenceTextClassificationTask(),
+    "question-answering": new HFInferenceQuestionAnsweringTask(),
+    "audio-classification": new HFInferenceAudioClassificationTask(),
+    "automatic-speech-recognition": new HFInferenceAutomaticSpeechRecognitionTask(),
+    "fill-mask": new HFInferenceFillMaskTask(),
+    "feature-extraction": new HFInferenceFeatureExtractionTask(),
+    "image-classification": new HFInferenceImageClassificationTask(),
+    "image-segmentation": new HFInferenceImageSegmentationTask(),
+    "document-question-answering": new HFInferenceDocumentQuestionAnsweringTask(),
+    "image-to-text": new HFInferenceImageToTextTask(),
+    "object-detection": new HFInferenceObjectDetectionTask(),
+    "audio-to-audio": new HFInferenceAudioToAudioTask(),
+    "zero-shot-image-classification": new HFInferenceZeroShotImageClassificationTask(),
+    "zero-shot-classification": new HFInferenceZeroShotClassificationTask(),
+    "image-to-image": new HFInferenceImageToImageTask(),
+    "sentence-similarity": new HFInferenceSentenceSimilarityTask(),
+    "table-question-answering": new HFInferenceTableQuestionAnsweringTask(),
+    "tabular-classification": new HFInferenceTabularClassificationTask(),
+    "text-to-speech": new HFInferenceTextToSpeechTask(),
+    "token-classification": new HFInferenceTokenClassificationTask(),
+    translation: new HFInferenceTranslationTask(),
+    summarization: new HFInferenceSummarizationTask(),
+    "visual-question-answering": new HFInferenceVisualQuestionAnsweringTask(),
+    "tabular-regression": new HFInferenceTabularRegressionTask(),
+    "text-to-audio": new HFInferenceTextToAudioTask()
+  },
+  "fireworks-ai": {
+    conversational: new FireworksConversationalTask()
+  },
+  hyperbolic: {
+    "text-to-image": new HyperbolicTextToImageTask(),
|
|
1134
|
+
conversational: new HyperbolicConversationalTask(),
|
|
1135
|
+
"text-generation": new HyperbolicTextGenerationTask()
|
|
1136
|
+
},
|
|
1137
|
+
nebius: {
|
|
1138
|
+
"text-to-image": new NebiusTextToImageTask(),
|
|
1139
|
+
conversational: new NebiusConversationalTask(),
|
|
1140
|
+
"text-generation": new NebiusTextGenerationTask()
|
|
1141
|
+
},
|
|
1142
|
+
novita: {
|
|
1143
|
+
conversational: new NovitaConversationalTask(),
|
|
1144
|
+
"text-generation": new NovitaTextGenerationTask()
|
|
1145
|
+
},
|
|
1146
|
+
openai: {
|
|
1147
|
+
conversational: new OpenAIConversationalTask()
|
|
1148
|
+
},
|
|
1149
|
+
replicate: {
|
|
1150
|
+
"text-to-image": new ReplicateTextToImageTask(),
|
|
1151
|
+
"text-to-speech": new ReplicateTextToSpeechTask(),
|
|
1152
|
+
"text-to-video": new ReplicateTextToVideoTask()
|
|
1153
|
+
},
|
|
1154
|
+
sambanova: {
|
|
1155
|
+
conversational: new SambanovaConversationalTask()
|
|
1156
|
+
},
|
|
1157
|
+
together: {
|
|
1158
|
+
"text-to-image": new TogetherTextToImageTask(),
|
|
1159
|
+
conversational: new TogetherConversationalTask(),
|
|
1160
|
+
"text-generation": new TogetherTextGenerationTask()
|
|
511
1161
|
}
|
|
512
|
-
return {
|
|
513
|
-
...params.args,
|
|
514
|
-
model: params.model
|
|
515
|
-
};
|
|
516
1162
|
};
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
throw new Error("OpenAI only supports chat completions.");
|
|
1163
|
+
function getProviderHelper(provider, task) {
|
|
1164
|
+
if (provider === "hf-inference") {
|
|
1165
|
+
if (!task) {
|
|
1166
|
+
return new HFInferenceTask();
|
|
1167
|
+
}
|
|
523
1168
|
}
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
1169
|
+
if (!task) {
|
|
1170
|
+
throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'");
|
|
1171
|
+
}
|
|
1172
|
+
if (!(provider in PROVIDERS)) {
|
|
1173
|
+
throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`);
|
|
1174
|
+
}
|
|
1175
|
+
const providerTasks = PROVIDERS[provider];
|
|
1176
|
+
if (!providerTasks || !(task in providerTasks)) {
|
|
1177
|
+
throw new Error(
|
|
1178
|
+
`Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(providerTasks ?? {})}`
|
|
1179
|
+
);
|
|
1180
|
+
}
|
|
1181
|
+
return providerTasks[task];
|
|
1182
|
+
}
|
|
537
1183
|
|
|
538
1184
|
// src/providers/consts.ts
|
|
539
1185
|
var HARDCODED_MODEL_ID_MAPPING = {
|
|
@@ -603,28 +1249,11 @@ async function getProviderModelId(params, args, options = {}) {
 603 1249 | }
 604 1250 |
 605 1251 | // src/lib/makeRequestOptions.ts
 606      | - var HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
 607 1252 | var tasks = null;
 608      | - var providerConfigs = {
 609      | -   "black-forest-labs": BLACK_FOREST_LABS_CONFIG,
 610      | -   cerebras: CEREBRAS_CONFIG,
 611      | -   cohere: COHERE_CONFIG,
 612      | -   "fal-ai": FAL_AI_CONFIG,
 613      | -   "fireworks-ai": FIREWORKS_AI_CONFIG,
 614      | -   "hf-inference": HF_INFERENCE_CONFIG,
 615      | -   hyperbolic: HYPERBOLIC_CONFIG,
 616      | -   openai: OPENAI_CONFIG,
 617      | -   nebius: NEBIUS_CONFIG,
 618      | -   novita: NOVITA_CONFIG,
 619      | -   replicate: REPLICATE_CONFIG,
 620      | -   sambanova: SAMBANOVA_CONFIG,
 621      | -   together: TOGETHER_CONFIG
 622      | - };
 623 1253 | async function makeRequestOptions(args, options) {
 624 1254 |   const { provider: maybeProvider, model: maybeModel } = args;
 625 1255 |   const provider = maybeProvider ?? "hf-inference";
 626      | -   const
 627      | -   const { task, chatCompletion: chatCompletion2 } = options ?? {};
     1256 | +   const { task } = options ?? {};
 628 1257 |   if (args.endpointUrl && provider !== "hf-inference") {
 629 1258 |     throw new Error(`Cannot use endpointUrl with a third-party provider.`);
 630 1259 |   }
@@ -634,19 +1263,16 @@ async function makeRequestOptions(args, options) {
 634 1263 |   if (!maybeModel && !task) {
 635 1264 |     throw new Error("No model provided, and no task has been specified.");
 636 1265 |   }
 637      | -
 638      | -
 639      | -
 640      | -   if (providerConfig.clientSideRoutingOnly && !maybeModel) {
     1266 | +   const hfModel = maybeModel ?? await loadDefaultModel(task);
     1267 | +   const providerHelper = getProviderHelper(provider, task);
     1268 | +   if (providerHelper.clientSideRoutingOnly && !maybeModel) {
 641 1269 |     throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
 642 1270 |   }
 643      | -   const
 644      | -   const resolvedModel = providerConfig.clientSideRoutingOnly ? (
     1271 | +   const resolvedModel = providerHelper.clientSideRoutingOnly ? (
 645 1272 |     // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
 646 1273 |     removeProviderPrefix(maybeModel, provider)
 647 1274 |   ) : await getProviderModelId({ model: hfModel, provider }, args, {
 648 1275 |     task,
 649      | -     chatCompletion: chatCompletion2,
 650 1276 |     fetch: options?.fetch
 651 1277 |   });
 652 1278 |   return makeRequestOptionsFromResolvedModel(resolvedModel, args, options);
@@ -654,10 +1280,10 @@ async function makeRequestOptions(args, options) {
 654 1280 | function makeRequestOptionsFromResolvedModel(resolvedModel, args, options) {
 655 1281 |   const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
 656 1282 |   const provider = maybeProvider ?? "hf-inference";
 657      | -   const
 658      | -   const
     1283 | +   const { includeCredentials, task, signal, billTo } = options ?? {};
     1284 | +   const providerHelper = getProviderHelper(provider, task);
 659 1285 |   const authMethod = (() => {
 660      | -     if (
     1286 | +     if (providerHelper.clientSideRoutingOnly) {
 661 1287 |       if (accessToken && accessToken.startsWith("hf_")) {
 662 1288 |         throw new Error(`Provider ${provider} is closed-source and does not support HF tokens.`);
 663 1289 |       }
@@ -671,35 +1297,30 @@ function makeRequestOptionsFromResolvedModel(resolvedModel, args, options) {
 671 1297 |     }
 672 1298 |     return "none";
 673 1299 |   })();
 674      | -   const
     1300 | +   const modelId = endpointUrl ?? resolvedModel;
     1301 | +   const url = providerHelper.makeUrl({
 675 1302 |     authMethod,
 676      | -
 677      | -     model: resolvedModel,
 678      | -     chatCompletion: chatCompletion2,
     1303 | +     model: modelId,
 679 1304 |     task
 680 1305 |   });
 681      | -   const
 682      | -
 683      | -
 684      | -
 685      | -
     1306 | +   const headers = providerHelper.prepareHeaders(
     1307 | +     {
     1308 | +       accessToken,
     1309 | +       authMethod
     1310 | +     },
     1311 | +     "data" in args && !!args.data
     1312 | +   );
 686 1313 |   if (billTo) {
 687 1314 |     headers[HF_HEADER_X_BILL_TO] = billTo;
 688 1315 |   }
 689      | -   if (!binary) {
 690      | -     headers["Content-Type"] = "application/json";
 691      | -   }
 692 1316 |   const ownUserAgent = `${name}/${version}`;
 693 1317 |   const userAgent = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : void 0].filter((x) => x !== void 0).join(" ");
 694 1318 |   headers["User-Agent"] = userAgent;
 695      | -   const body =
 696      | -
 697      | -
 698      | -
 699      | -
 700      | -       chatCompletion: chatCompletion2
 701      | -     })
 702      | -   );
     1319 | +   const body = providerHelper.makeBody({
     1320 | +     args: remainingArgs,
     1321 | +     model: resolvedModel,
     1322 | +     task
     1323 | +   });
 703 1324 |   let credentials;
 704 1325 |   if (typeof includeCredentials === "string") {
 705 1326 |     credentials = includeCredentials;
@@ -961,30 +1582,6 @@ async function* streamingRequest(args, options) {
 961 1582 |   yield* innerStreamingRequest(args, options);
 962 1583 | }
 963 1584 |
 964      | - // src/utils/pick.ts
 965      | - function pick(o, props) {
 966      | -   return Object.assign(
 967      | -     {},
 968      | -     ...props.map((prop) => {
 969      | -       if (o[prop] !== void 0) {
 970      | -         return { [prop]: o[prop] };
 971      | -       }
 972      | -     })
 973      | -   );
 974      | - }
 975      | -
 976      | - // src/utils/typedInclude.ts
 977      | - function typedInclude(arr, v) {
 978      | -   return arr.includes(v);
 979      | - }
 980      | -
 981      | - // src/utils/omit.ts
 982      | - function omit(o, props) {
 983      | -   const propsArr = Array.isArray(props) ? props : [props];
 984      | -   const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
 985      | -   return pick(o, letsKeep);
 986      | - }
 987      | -
 988 1585 | // src/tasks/audio/utils.ts
 989 1586 | function preparePayload(args) {
 990 1587 |   return "data" in args ? args : {
@@ -995,16 +1592,24 @@ function preparePayload(args) {
 995 1592 |
 996 1593 | // src/tasks/audio/audioClassification.ts
 997 1594 | async function audioClassification(args, options) {
     1595 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
 998 1596 |   const payload = preparePayload(args);
 999 1597 |   const { data: res } = await innerRequest(payload, {
1000 1598 |     ...options,
1001 1599 |     task: "audio-classification"
1002 1600 |   });
1003      | -
1004      | -
1005      | -
1006      | -
1007      | -
     1601 | +   return providerHelper.getResponse(res);
     1602 | + }
     1603 | +
     1604 | + // src/tasks/audio/audioToAudio.ts
     1605 | + async function audioToAudio(args, options) {
     1606 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
     1607 | +   const payload = preparePayload(args);
     1608 | +   const { data: res } = await innerRequest(payload, {
     1609 | +     ...options,
     1610 | +     task: "audio-to-audio"
     1611 | +   });
     1612 | +   return providerHelper.getResponse(res);
1008 1613 | }
1009 1614 |
1010 1615 | // src/utils/base64FromBytes.ts
@@ -1022,6 +1627,7 @@ function base64FromBytes(arr) {
1022 1627 |
1023 1628 | // src/tasks/audio/automaticSpeechRecognition.ts
1024 1629 | async function automaticSpeechRecognition(args, options) {
     1630 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
1025 1631 |   const payload = await buildPayload(args);
1026 1632 |   const { data: res } = await innerRequest(payload, {
1027 1633 |     ...options,
@@ -1031,9 +1637,8 @@ async function automaticSpeechRecognition(args, options) {
1031 1637 |   if (!isValidOutput) {
1032 1638 |     throw new InferenceOutputError("Expected {text: string}");
1033 1639 |   }
1034      | -   return res;
     1640 | +   return providerHelper.getResponse(res);
1035 1641 | }
1036      | - var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
1037 1642 | async function buildPayload(args) {
1038 1643 |   if (args.provider === "fal-ai") {
1039 1644 |     const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : void 0;
@@ -1052,63 +1657,23 @@ async function buildPayload(args) {
1052 1657 |     }
1053 1658 |     const base64audio = base64FromBytes(new Uint8Array(await blob.arrayBuffer()));
1054 1659 |     return {
1055      | -       ..."data" in args ? omit(args, "data") : omit(args, "inputs"),
1056      | -       audio_url: `data:${contentType};base64,${base64audio}`
1057      | -     };
1058      | -   } else {
1059      | -     return preparePayload(args);
1060      | -   }
1061      | - }
1062      | -
1063      | - // src/tasks/audio/textToSpeech.ts
1064      | - async function textToSpeech(args, options) {
1065      | -   const payload = args.provider === "replicate" ? {
1066      | -     ...omit(args, ["inputs", "parameters"]),
1067      | -     ...args.parameters,
1068      | -     text: args.inputs
1069      | -   } : args;
1070      | -   const { data: res } = await innerRequest(payload, {
1071      | -     ...options,
1072      | -     task: "text-to-speech"
1073      | -   });
1074      | -   if (res instanceof Blob) {
1075      | -     return res;
1076      | -   }
1077      | -   if (res && typeof res === "object") {
1078      | -     if ("output" in res) {
1079      | -       if (typeof res.output === "string") {
1080      | -         const urlResponse = await fetch(res.output);
1081      | -         const blob = await urlResponse.blob();
1082      | -         return blob;
1083      | -       } else if (Array.isArray(res.output)) {
1084      | -         const urlResponse = await fetch(res.output[0]);
1085      | -         const blob = await urlResponse.blob();
1086      | -         return blob;
1087      | -       }
1088      | -     }
     1660 | +       ..."data" in args ? omit(args, "data") : omit(args, "inputs"),
     1661 | +       audio_url: `data:${contentType};base64,${base64audio}`
     1662 | +     };
     1663 | +   } else {
     1664 | +     return preparePayload(args);
1089 1665 |   }
1090      | -   throw new InferenceOutputError("Expected Blob or object with output");
1091 1666 | }
1092 1667 |
1093      | - // src/tasks/audio/
1094      | - async function
1095      | -   const
1096      | -   const
     1668 | + // src/tasks/audio/textToSpeech.ts
     1669 | + async function textToSpeech(args, options) {
     1670 | +   const provider = args.provider ?? "hf-inference";
     1671 | +   const providerHelper = getProviderHelper(provider, "text-to-speech");
     1672 | +   const { data: res } = await innerRequest(args, {
1097 1673 |     ...options,
1098      | -     task: "
     1674 | +     task: "text-to-speech"
1099 1675 |   });
1100      | -   return
1101      | - }
1102      | - function validateOutput(output) {
1103      | -   if (!Array.isArray(output)) {
1104      | -     throw new InferenceOutputError("Expected Array");
1105      | -   }
1106      | -   if (!output.every((elem) => {
1107      | -     return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
1108      | -   })) {
1109      | -     throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
1110      | -   }
1111      | -   return output;
     1676 | +   return providerHelper.getResponse(res);
1112 1677 | }
1113 1678 |
1114 1679 | // src/tasks/cv/utils.ts
@@ -1118,183 +1683,95 @@ function preparePayload2(args) {
1118 1683 |
1119 1684 | // src/tasks/cv/imageClassification.ts
1120 1685 | async function imageClassification(args, options) {
     1686 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
1121 1687 |   const payload = preparePayload2(args);
1122 1688 |   const { data: res } = await innerRequest(payload, {
1123 1689 |     ...options,
1124 1690 |     task: "image-classification"
1125 1691 |   });
1126      | -
1127      | -   if (!isValidOutput) {
1128      | -     throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
1129      | -   }
1130      | -   return res;
     1692 | +   return providerHelper.getResponse(res);
1131 1693 | }
1132 1694 |
1133 1695 | // src/tasks/cv/imageSegmentation.ts
1134 1696 | async function imageSegmentation(args, options) {
     1697 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
1135 1698 |   const payload = preparePayload2(args);
1136 1699 |   const { data: res } = await innerRequest(payload, {
1137 1700 |     ...options,
1138 1701 |     task: "image-segmentation"
1139 1702 |   });
1140      | -
1141      | -
1142      | -
     1703 | +   return providerHelper.getResponse(res);
     1704 | + }
     1705 | +
     1706 | + // src/tasks/cv/imageToImage.ts
     1707 | + async function imageToImage(args, options) {
     1708 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
     1709 | +   let reqArgs;
     1710 | +   if (!args.parameters) {
     1711 | +     reqArgs = {
     1712 | +       accessToken: args.accessToken,
     1713 | +       model: args.model,
     1714 | +       data: args.inputs
     1715 | +     };
     1716 | +   } else {
     1717 | +     reqArgs = {
     1718 | +       ...args,
     1719 | +       inputs: base64FromBytes(
     1720 | +         new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await args.inputs.arrayBuffer())
     1721 | +       )
     1722 | +     };
1143 1723 |   }
1144      | -
     1724 | +   const { data: res } = await innerRequest(reqArgs, {
     1725 | +     ...options,
     1726 | +     task: "image-to-image"
     1727 | +   });
     1728 | +   return providerHelper.getResponse(res);
1145 1729 | }
1146 1730 |
1147 1731 | // src/tasks/cv/imageToText.ts
1148 1732 | async function imageToText(args, options) {
     1733 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
1149 1734 |   const payload = preparePayload2(args);
1150 1735 |   const { data: res } = await innerRequest(payload, {
1151 1736 |     ...options,
1152 1737 |     task: "image-to-text"
1153 1738 |   });
1154      | -
1155      | -     throw new InferenceOutputError("Expected {generated_text: string}");
1156      | -   }
1157      | -   return res?.[0];
     1739 | +   return providerHelper.getResponse(res[0]);
1158 1740 | }
1159 1741 |
1160 1742 | // src/tasks/cv/objectDetection.ts
1161 1743 | async function objectDetection(args, options) {
     1744 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
1162 1745 |   const payload = preparePayload2(args);
1163 1746 |   const { data: res } = await innerRequest(payload, {
1164 1747 |     ...options,
1165 1748 |     task: "object-detection"
1166 1749 |   });
1167      | -
1168      | -     (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
1169      | -   );
1170      | -   if (!isValidOutput) {
1171      | -     throw new InferenceOutputError(
1172      | -       "Expected Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>"
1173      | -     );
1174      | -   }
1175      | -   return res;
     1750 | +   return providerHelper.getResponse(res);
1176 1751 | }
1177 1752 |
1178 1753 | // src/tasks/cv/textToImage.ts
1179      | - function getResponseFormatArg(provider) {
1180      | -   switch (provider) {
1181      | -     case "fal-ai":
1182      | -       return { sync_mode: true };
1183      | -     case "nebius":
1184      | -       return { response_format: "b64_json" };
1185      | -     case "replicate":
1186      | -       return void 0;
1187      | -     case "together":
1188      | -       return { response_format: "base64" };
1189      | -     default:
1190      | -       return void 0;
1191      | -   }
1192      | - }
1193 1754 | async function textToImage(args, options) {
1194      | -   const
1195      | -
1196      | -
1197      | -     ...getResponseFormatArg(args.provider),
1198      | -     prompt: args.inputs
1199      | -   };
1200      | -   const { data: res } = await innerRequest(payload, {
     1755 | +   const provider = args.provider ?? "hf-inference";
     1756 | +   const providerHelper = getProviderHelper(provider, "text-to-image");
     1757 | +   const { data: res } = await innerRequest(args, {
1201 1758 |     ...options,
1202 1759 |     task: "text-to-image"
1203 1760 |   });
1204      | -
1205      | -
1206      | -       return await pollBflResponse(res.polling_url, options?.outputType);
1207      | -     }
1208      | -     if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
1209      | -       if (options?.outputType === "url") {
1210      | -         return res.images[0].url;
1211      | -       } else {
1212      | -         const image = await fetch(res.images[0].url);
1213      | -         return await image.blob();
1214      | -       }
1215      | -     }
1216      | -     if (args.provider === "hyperbolic" && "images" in res && Array.isArray(res.images) && res.images[0] && typeof res.images[0].image === "string") {
1217      | -       if (options?.outputType === "url") {
1218      | -         return `data:image/jpeg;base64,${res.images[0].image}`;
1219      | -       }
1220      | -       const base64Response = await fetch(`data:image/jpeg;base64,${res.images[0].image}`);
1221      | -       return await base64Response.blob();
1222      | -     }
1223      | -     if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
1224      | -       const base64Data = res.data[0].b64_json;
1225      | -       if (options?.outputType === "url") {
1226      | -         return `data:image/jpeg;base64,${base64Data}`;
1227      | -       }
1228      | -       const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
1229      | -       return await base64Response.blob();
1230      | -     }
1231      | -     if ("output" in res && Array.isArray(res.output)) {
1232      | -       if (options?.outputType === "url") {
1233      | -         return res.output[0];
1234      | -       }
1235      | -       const urlResponse = await fetch(res.output[0]);
1236      | -       const blob = await urlResponse.blob();
1237      | -       return blob;
1238      | -     }
1239      | -   }
1240      | -   const isValidOutput = res && res instanceof Blob;
1241      | -   if (!isValidOutput) {
1242      | -     throw new InferenceOutputError("Expected Blob");
1243      | -   }
1244      | -   if (options?.outputType === "url") {
1245      | -     const b64 = await res.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
1246      | -     return `data:image/jpeg;base64,${b64}`;
1247      | -   }
1248      | -   return res;
1249      | - }
1250      | - async function pollBflResponse(url, outputType) {
1251      | -   const urlObj = new URL(url);
1252      | -   for (let step = 0; step < 5; step++) {
1253      | -     await delay(1e3);
1254      | -     console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
1255      | -     urlObj.searchParams.set("attempt", step.toString(10));
1256      | -     const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
1257      | -     if (!resp.ok) {
1258      | -       throw new InferenceOutputError("Failed to fetch result from black forest labs API");
1259      | -     }
1260      | -     const payload = await resp.json();
1261      | -     if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
1262      | -       if (outputType === "url") {
1263      | -         return payload.result.sample;
1264      | -       }
1265      | -       const image = await fetch(payload.result.sample);
1266      | -       return await image.blob();
1267      | -     }
1268      | -   }
1269      | -   throw new InferenceOutputError("Failed to fetch result from black forest labs API");
     1761 | +   const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-image" });
     1762 | +   return providerHelper.getResponse(res, url, info.headers, options?.outputType);
1270 1763 | }
1271 1764 |
1272      | - // src/tasks/cv/
1273      | - async function
1274      | -
1275      | -
1276      | -
1277      | -       accessToken: args.accessToken,
1278      | -       model: args.model,
1279      | -       data: args.inputs
1280      | -     };
1281      | -   } else {
1282      | -     reqArgs = {
1283      | -       ...args,
1284      | -       inputs: base64FromBytes(
1285      | -         new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await args.inputs.arrayBuffer())
1286      | -       )
1287      | -     };
1288      | -   }
1289      | -   const { data: res } = await innerRequest(reqArgs, {
     1765 | + // src/tasks/cv/textToVideo.ts
     1766 | + async function textToVideo(args, options) {
     1767 | +   const provider = args.provider ?? "hf-inference";
     1768 | +   const providerHelper = getProviderHelper(provider, "text-to-video");
     1769 | +   const { data: response } = await innerRequest(args, {
1290 1770 |     ...options,
1291      | -     task: "
     1771 | +     task: "text-to-video"
1292 1772 |   });
1293      | -   const
1294      | -
1295      | -     throw new InferenceOutputError("Expected Blob");
1296      | -   }
1297      | -   return res;
     1773 | +   const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-video" });
     1774 | +   return providerHelper.getResponse(response, url, info.headers);
1298 1775 | }
1299 1776 |
1300 1777 | // src/tasks/cv/zeroShotImageClassification.ts
|
@@ -1320,226 +1797,112 @@ async function preparePayload3(args) {
|
|
|
1320
1797
|
}
|
|
1321
1798
|
}
|
|
1322
1799
|
async function zeroShotImageClassification(args, options) {
|
|
1800
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
|
|
1323
1801
|
const payload = await preparePayload3(args);
|
|
1324
1802
|
const { data: res } = await innerRequest(payload, {
|
|
1325
1803
|
...options,
|
|
1326
1804
|
task: "zero-shot-image-classification"
|
|
1327
1805
|
});
|
|
1328
|
-
|
|
1329
|
-
if (!isValidOutput) {
|
|
1330
|
-
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
1331
|
-
}
|
|
1332
|
-
return res;
|
|
1806
|
+
return providerHelper.getResponse(res);
|
|
1333
1807
|
}
|
|
1334
1808
|
|
|
1335
|
-
// src/tasks/
|
|
1336
|
-
|
|
1337
|
-
|
|
1338
|
-
|
|
1339
|
-
throw new Error(
|
|
1340
|
-
`textToVideo inference is only supported for the following providers: ${SUPPORTED_PROVIDERS.join(", ")}`
|
|
1341
|
-
);
|
|
1342
|
-
}
|
|
1343
|
-
const payload = args.provider === "fal-ai" || args.provider === "replicate" || args.provider === "novita" ? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs } : args;
|
|
1344
|
-
const { data, requestContext } = await innerRequest(payload, {
|
|
1809
|
+
// src/tasks/nlp/chatCompletion.ts
|
|
1810
|
+
async function chatCompletion(args, options) {
|
|
1811
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
|
|
1812
|
+
const { data: response } = await innerRequest(args, {
|
|
1345
1813
|
...options,
|
|
1346
|
-
task: "
|
|
1814
|
+
task: "conversational"
|
|
1815
|
+
});
|
|
1816
|
+
return providerHelper.getResponse(response);
|
|
1817
|
+
}
|
|
1818
|
+
|
|
1819
|
+
// src/tasks/nlp/chatCompletionStream.ts
|
|
1820
|
+
async function* chatCompletionStream(args, options) {
|
|
1821
|
+
yield* innerStreamingRequest(args, {
|
|
1822
|
+
...options,
|
|
1823
|
+
task: "conversational"
|
|
1347
1824
|
});
|
|
1348
|
-
if (args.provider === "fal-ai") {
|
|
1349
|
-
return await pollFalResponse(
|
|
1350
|
-
data,
|
|
1351
|
-
requestContext.url,
|
|
1352
|
-
requestContext.info.headers
|
|
1353
|
-
);
|
|
1354
|
-
} else if (args.provider === "novita") {
|
|
1355
|
-
const isValidOutput = typeof data === "object" && !!data && "video" in data && typeof data.video === "object" && !!data.video && "video_url" in data.video && typeof data.video.video_url === "string" && isUrl(data.video.video_url);
|
|
1356
|
-
if (!isValidOutput) {
|
|
1357
|
-
throw new InferenceOutputError("Expected { video: { video_url: string } }");
|
|
1358
|
-
}
|
|
1359
|
-
const urlResponse = await fetch(data.video.video_url);
|
|
1360
|
-
return await urlResponse.blob();
|
|
1361
|
-
} else {
|
|
1362
|
-
const isValidOutput = typeof data === "object" && !!data && "output" in data && typeof data.output === "string" && isUrl(data.output);
|
|
1363
|
-
if (!isValidOutput) {
|
|
1364
|
-
throw new InferenceOutputError("Expected { output: string }");
|
|
1365
|
-
}
|
|
1366
|
-
const urlResponse = await fetch(data.output);
|
|
1367
|
-
return await urlResponse.blob();
|
|
1368
|
-
}
|
|
1369
1825
|
}
|
|
1370
1826
|
|
|
1371
1827
|
// src/tasks/nlp/featureExtraction.ts
|
|
1372
1828
|
async function featureExtraction(args, options) {
|
|
1829
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "feature-extraction");
|
|
1373
1830
|
const { data: res } = await innerRequest(args, {
|
|
1374
1831
|
...options,
|
|
1375
1832
|
task: "feature-extraction"
|
|
1376
1833
|
});
|
|
1377
|
-
|
|
1378
|
-
const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
|
|
1379
|
-
if (curDepth > maxDepth)
|
|
1380
|
-
return false;
|
|
1381
|
-
if (arr.every((x) => Array.isArray(x))) {
|
|
1382
|
-
return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
|
|
1383
|
-
} else {
|
|
1384
|
-
return arr.every((x) => typeof x === "number");
|
|
1385
|
-
}
|
|
1386
|
-
};
|
|
1387
|
-
isValidOutput = Array.isArray(res) && isNumArrayRec(res, 3, 0);
|
|
1388
|
-
if (!isValidOutput) {
|
|
1389
|
-
throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
|
|
1390
|
-
}
|
|
1391
|
-
return res;
|
|
1834
|
+
return providerHelper.getResponse(res);
|
|
1392
1835
|
}
|
|
1393
1836
|
|
|
1394
1837
|
// src/tasks/nlp/fillMask.ts
|
|
1395
1838
|
async function fillMask(args, options) {
|
|
1839
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "fill-mask");
|
|
1396
1840
|
const { data: res } = await innerRequest(args, {
|
|
1397
1841
|
...options,
|
|
1398
1842
|
task: "fill-mask"
|
|
1399
1843
|
});
|
|
1400
|
-
|
|
1401
|
-
(x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
|
|
1402
|
-
);
|
|
1403
|
-
if (!isValidOutput) {
|
|
1404
|
-
throw new InferenceOutputError(
|
|
1405
|
-
"Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
|
|
1406
|
-
);
|
|
1407
|
-
}
|
|
1408
|
-
return res;
|
|
1844
|
+
return providerHelper.getResponse(res);
|
|
1409
1845
|
}
|
|
1410
1846
|
|
|
1411
1847
|
// src/tasks/nlp/questionAnswering.ts
|
|
1412
1848
|
async function questionAnswering(args, options) {
|
|
1849
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering");
|
|
1413
1850
|
const { data: res } = await innerRequest(args, {
|
|
1414
1851
|
...options,
|
|
1415
1852
|
task: "question-answering"
|
|
1416
1853
|
});
|
|
1417
|
-
|
|
1418
|
-
(elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
|
|
1419
|
-
) : typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";
|
|
1420
|
-
if (!isValidOutput) {
|
|
1421
|
-
throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
|
|
1422
|
-
}
|
|
1423
|
-
return Array.isArray(res) ? res[0] : res;
|
|
1854
|
+
return providerHelper.getResponse(res);
|
|
1424
1855
|
}
|
|
1425
1856
|
|
|
1426
1857
|
// src/tasks/nlp/sentenceSimilarity.ts
|
|
1427
1858
|
async function sentenceSimilarity(args, options) {
|
|
1859
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "sentence-similarity");
|
|
1428
1860
|
const { data: res } = await innerRequest(args, {
|
|
1429
1861
|
...options,
|
|
1430
1862
|
task: "sentence-similarity"
|
|
1431
1863
|
});
|
|
1432
|
-
|
|
1433
|
-
if (!isValidOutput) {
|
|
1434
|
-
throw new InferenceOutputError("Expected number[]");
|
|
1435
|
-
}
|
|
1436
|
-
return res;
|
|
1864
|
+
return providerHelper.getResponse(res);
|
|
1437
1865
|
}
|
|
1438
1866
|
|
|
1439
1867
|
// src/tasks/nlp/summarization.ts
|
|
1440
1868
|
async function summarization(args, options) {
|
|
1869
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization");
|
|
1441
1870
|
const { data: res } = await innerRequest(args, {
|
|
1442
1871
|
...options,
|
|
1443
1872
|
task: "summarization"
|
|
1444
1873
|
});
|
|
1445
|
-
|
|
1446
|
-
if (!isValidOutput) {
|
|
1447
|
-
throw new InferenceOutputError("Expected Array<{summary_text: string}>");
|
|
1448
|
-
}
|
|
1449
|
-
return res?.[0];
|
|
1874
|
+
return providerHelper.getResponse(res);
|
|
1450
1875
|
}
|
|
1451
1876
|
|
|
1452
1877
|
// src/tasks/nlp/tableQuestionAnswering.ts
|
|
1453
1878
|
async function tableQuestionAnswering(args, options) {
|
|
1879
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
|
|
1454
1880
|
const { data: res } = await innerRequest(args, {
|
|
1455
1881
|
...options,
|
|
1456
1882
|
task: "table-question-answering"
|
|
1457
1883
|
});
|
|
1458
|
-
|
|
1459
|
-
if (!isValidOutput) {
|
|
1460
|
-
throw new InferenceOutputError(
|
|
1461
|
-
"Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
|
|
1462
|
-
);
|
|
1463
|
-
}
|
|
1464
|
-
return Array.isArray(res) ? res[0] : res;
|
|
1465
|
-
}
|
|
1466
|
-
function validate(elem) {
|
|
1467
|
-
return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
|
|
1468
|
-
(coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
|
|
1469
|
-
);
|
|
1884
|
+
return providerHelper.getResponse(res);
|
|
1470
1885
|
}
|
|
1471
1886
|
|
|
1472
1887
|
// src/tasks/nlp/textClassification.ts
|
|
1473
1888
|
async function textClassification(args, options) {
|
|
1889
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
|
|
1474
1890
|
const { data: res } = await innerRequest(args, {
|
|
1475
1891
|
...options,
|
|
1476
1892
|
task: "text-classification"
|
|
1477
1893
|
});
|
|
1478
|
-
|
|
1479
|
-
const isValidOutput = Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number");
|
|
1480
|
-
if (!isValidOutput) {
|
|
1481
|
-
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
1482
|
-
}
|
|
1483
|
-
return output;
|
|
1484
|
-
}
|
|
1485
|
-
|
|
1486
|
-
// src/utils/toArray.ts
|
|
1487
|
-
function toArray(obj) {
|
|
1488
|
-
if (Array.isArray(obj)) {
|
|
1489
|
-
return obj;
|
|
1490
|
-
}
|
|
1491
|
-
return [obj];
|
|
1894
|
+
return providerHelper.getResponse(res);
|
|
1492
1895
|
}
|
|
1493
1896
|
|
|
1494
1897
|
// src/tasks/nlp/textGeneration.ts
|
|
1495
1898
|
async function textGeneration(args, options) {
|
|
1496
|
-
|
|
1497
|
-
|
|
1498
|
-
|
|
1499
|
-
|
|
1500
|
-
|
|
1501
|
-
|
|
1502
|
-
|
|
1503
|
-
if (!isValidOutput) {
|
|
1504
|
-
throw new InferenceOutputError("Expected ChatCompletionOutput");
|
|
1505
|
-
}
|
|
1506
|
-
const completion = raw.choices[0];
|
|
1507
|
-
return {
|
|
1508
|
-
generated_text: completion.text
|
|
1509
|
-
};
|
|
1510
|
-
} else if (args.provider === "hyperbolic") {
|
|
1511
|
-
const payload = {
|
|
1512
|
-
messages: [{ content: args.inputs, role: "user" }],
|
|
1513
|
-
...args.parameters ? {
|
|
1514
|
-
max_tokens: args.parameters.max_new_tokens,
|
|
1515
|
-
...omit(args.parameters, "max_new_tokens")
|
|
1516
|
-
} : void 0,
|
|
1517
|
-
...omit(args, ["inputs", "parameters"])
|
|
1518
|
-
};
|
|
1519
|
-
const raw = (await innerRequest(payload, {
|
|
1520
|
-
...options,
|
|
1521
|
-
task: "text-generation"
|
|
1522
|
-
})).data;
|
|
1523
|
-
const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
|
|
1524
|
-
if (!isValidOutput) {
|
|
1525
|
-
throw new InferenceOutputError("Expected ChatCompletionOutput");
|
|
1526
|
-
}
|
|
1527
|
-
const completion = raw.choices[0];
|
|
1528
|
-
return {
|
|
1529
|
-
generated_text: completion.message.content
|
|
1530
|
-
};
|
|
1531
|
-
} else {
|
|
1532
|
-
const { data: res } = await innerRequest(args, {
|
|
1533
|
-
...options,
|
|
1534
|
-
task: "text-generation"
|
|
1535
|
-
});
|
|
1536
|
-
const output = toArray(res);
|
|
1537
|
-
const isValidOutput = Array.isArray(output) && output.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
|
|
1538
|
-
if (!isValidOutput) {
|
|
1539
|
-
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
|
|
1540
|
-
}
|
|
1541
|
-
return output?.[0];
|
|
1542
|
-
}
|
|
1899
|
+
const provider = args.provider ?? "hf-inference";
|
|
1900
|
+
const providerHelper = getProviderHelper(provider, "text-generation");
|
|
1901
|
+
const { data: response } = await innerRequest(args, {
|
|
1902
|
+
...options,
|
|
1903
|
+
task: "text-generation"
|
|
1904
|
+
});
|
|
1905
|
+
return providerHelper.getResponse(response);
|
|
1543
1906
|
}
|
|
1544
1907
|
|
|
1545
1908
|
// src/tasks/nlp/textGenerationStream.ts
|
|
@@ -1552,77 +1915,37 @@ async function* textGenerationStream(args, options) {
1552 1915 |
1553 1916 | // src/tasks/nlp/tokenClassification.ts
1554 1917 | async function tokenClassification(args, options) {
     1918 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "token-classification");
1555 1919 |   const { data: res } = await innerRequest(args, {
1556 1920 |     ...options,
1557 1921 |     task: "token-classification"
1558 1922 |   });
1559      | -
1560      | -   const isValidOutput = Array.isArray(output) && output.every(
1561      | -     (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
1562      | -   );
1563      | -   if (!isValidOutput) {
1564      | -     throw new InferenceOutputError(
1565      | -       "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
1566      | -     );
1567      | -   }
1568      | -   return output;
     1923 | +   return providerHelper.getResponse(res);
1569 1924 | }
1570 1925 |
1571 1926 | // src/tasks/nlp/translation.ts
1572 1927 | async function translation(args, options) {
     1928 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation");
1573 1929 |   const { data: res } = await innerRequest(args, {
1574 1930 |     ...options,
1575 1931 |     task: "translation"
1576 1932 |   });
1577      | -
1578      | -   if (!isValidOutput) {
1579      | -     throw new InferenceOutputError("Expected type Array<{translation_text: string}>");
1580      | -   }
1581      | -   return res?.length === 1 ? res?.[0] : res;
     1933 | +   return providerHelper.getResponse(res);
1582 1934 | }
1583 1935 |
1584 1936 | // src/tasks/nlp/zeroShotClassification.ts
1585 1937 | async function zeroShotClassification(args, options) {
     1938 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification");
1586 1939 |   const { data: res } = await innerRequest(args, {
1587 1940 |     ...options,
1588 1941 |     task: "zero-shot-classification"
1589 1942 |   });
1590      | -
1591      | -   const isValidOutput = Array.isArray(output) && output.every(
1592      | -     (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
1593      | -   );
1594      | -   if (!isValidOutput) {
1595      | -     throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
1596      | -   }
1597      | -   return output;
1598      | - }
1599      | -
1600      | - // src/tasks/nlp/chatCompletion.ts
1601      | - async function chatCompletion(args, options) {
1602      | -   const { data: res } = await innerRequest(args, {
1603      | -     ...options,
1604      | -     task: "text-generation",
1605      | -     chatCompletion: true
1606      | -   });
1607      | -   const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
1608      | -   (res.system_fingerprint === void 0 || res.system_fingerprint === null || typeof res.system_fingerprint === "string") && typeof res?.usage === "object";
1609      | -   if (!isValidOutput) {
1610      | -     throw new InferenceOutputError("Expected ChatCompletionOutput");
1611      | -   }
1612      | -   return res;
1613      | - }
1614      | -
1615      | - // src/tasks/nlp/chatCompletionStream.ts
1616      | - async function* chatCompletionStream(args, options) {
1617      | -   yield* innerStreamingRequest(args, {
1618      | -     ...options,
1619      | -     task: "text-generation",
1620      | -     chatCompletion: true
1621      | -   });
     1943 | +   return providerHelper.getResponse(res);
1622 1944 | }
1623 1945 |
1624 1946 | // src/tasks/multimodal/documentQuestionAnswering.ts
1625 1947 | async function documentQuestionAnswering(args, options) {
     1948 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "document-question-answering");
1626 1949 |   const reqArgs = {
1627 1950 |     ...args,
1628 1951 |     inputs: {
@@ -1638,18 +1961,12 @@ async function documentQuestionAnswering(args, options) {
1638 1961 |       task: "document-question-answering"
1639 1962 |     }
1640 1963 |   );
1641      | -
1642      | -   const isValidOutput = Array.isArray(output) && output.every(
1643      | -     (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
1644      | -   );
1645      | -   if (!isValidOutput) {
1646      | -     throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>");
1647      | -   }
1648      | -   return output[0];
     1964 | +   return providerHelper.getResponse(res);
1649 1965 | }
1650 1966 |
1651 1967 | // src/tasks/multimodal/visualQuestionAnswering.ts
1652 1968 | async function visualQuestionAnswering(args, options) {
     1969 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "visual-question-answering");
1653 1970 |   const reqArgs = {
1654 1971 |     ...args,
1655 1972 |     inputs: {
@@ -1662,39 +1979,27 @@ async function visualQuestionAnswering(args, options) {
1662 1979 |     ...options,
1663 1980 |     task: "visual-question-answering"
1664 1981 |   });
1665      | -
1666      | -     (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
1667      | -   );
1668      | -   if (!isValidOutput) {
1669      | -     throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
1670      | -   }
1671      | -   return res[0];
     1982 | +   return providerHelper.getResponse(res);
1672 1983 | }
1673 1984 |
1674      | - // src/tasks/tabular/
1675      | - async function
     1985 | + // src/tasks/tabular/tabularClassification.ts
     1986 | + async function tabularClassification(args, options) {
     1987 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification");
1676 1988 |   const { data: res } = await innerRequest(args, {
1677 1989 |     ...options,
1678      | -     task: "tabular-
     1990 | +     task: "tabular-classification"
1679 1991 |   });
1680      | -
1681      | -   if (!isValidOutput) {
1682      | -     throw new InferenceOutputError("Expected number[]");
1683      | -   }
1684      | -   return res;
     1992 | +   return providerHelper.getResponse(res);
1685 1993 | }
1686 1994 |
1687      | - // src/tasks/tabular/
1688      | - async function
     1995 | + // src/tasks/tabular/tabularRegression.ts
     1996 | + async function tabularRegression(args, options) {
     1997 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression");
1689 1998 |   const { data: res } = await innerRequest(args, {
1690 1999 |     ...options,
1691      | -     task: "tabular-
     2000 | +     task: "tabular-regression"
1692 2001 |   });
1693      | -
1694      | -   if (!isValidOutput) {
1695      | -     throw new InferenceOutputError("Expected number[]");
1696      | -   }
1697      | -   return res;
     2002 | +   return providerHelper.getResponse(res);
1698 2003 | }
1699 2004 |
1700 2005 | // src/InferenceClient.ts
@@ -1763,8 +2068,8 @@ __export(snippets_exports, {
1763 2068 | });
1764 2069 |
1765 2070 | // src/snippets/getInferenceSnippets.ts
1766      | - var import_tasks = require("@huggingface/tasks");
1767 2071 | var import_jinja = require("@huggingface/jinja");
     2072 | + var import_tasks = require("@huggingface/tasks");
1768 2073 |
1769 2074 | // src/snippets/templates.exported.ts
1770 2075 | var templates = {
|
|
|
1807
2112
|
},
|
|
1808
2113
|
"openai": {
|
|
1809
2114
|
"conversational": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\nconst chatCompletion = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n});\n\nconsole.log(chatCompletion.choices[0].message);',
|
|
1810
|
-
"conversationalStream": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\
|
|
2115
|
+
"conversationalStream": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\nconst stream = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n stream: true,\n});\n\nfor await (const chunk of stream) {\n process.stdout.write(chunk.choices[0]?.delta?.content || "");\n}'
|
|
1811
2116
|
}
|
|
1812
2117
|
},
|
|
1813
2118
|
"python": {
|
|
@@ -1939,15 +2244,23 @@ var HF_JS_METHODS = {
|
|
|
1939
2244
|
};
|
|
1940
2245
|
var snippetGenerator = (templateName, inputPreparationFn) => {
|
|
1941
2246
|
return (model, accessToken, provider, providerModelId, opts) => {
|
|
2247
|
+
let task = model.pipeline_tag;
|
|
1942
2248
|
if (model.pipeline_tag && ["text-generation", "image-text-to-text"].includes(model.pipeline_tag) && model.tags.includes("conversational")) {
|
|
1943
2249
|
templateName = opts?.streaming ? "conversationalStream" : "conversational";
|
|
1944
2250
|
inputPreparationFn = prepareConversationalInput;
|
|
2251
|
+
task = "conversational";
|
|
1945
2252
|
}
|
|
1946
2253
|
const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: (0, import_tasks.getModelInputSnippet)(model) };
|
|
1947
2254
|
const request2 = makeRequestOptionsFromResolvedModel(
|
|
1948
2255
|
providerModelId ?? model.id,
|
|
1949
|
-
{
|
|
1950
|
-
|
|
2256
|
+
{
|
|
2257
|
+
accessToken,
|
|
2258
|
+
provider,
|
|
2259
|
+
...inputs
|
|
2260
|
+
},
|
|
2261
|
+
{
|
|
2262
|
+
task
|
|
2263
|
+
}
|
|
1951
2264
|
);
|
|
1952
2265
|
let providerInputs = inputs;
|
|
1953
2266
|
const bodyAsObj = request2.info.body;
|
|
@@ -2034,7 +2347,7 @@ var prepareConversationalInput = (model, opts) => {
|
|
|
2034
2347
|
return {
|
|
2035
2348
|
messages: opts?.messages ?? (0, import_tasks.getModelInputSnippet)(model),
|
|
2036
2349
|
...opts?.temperature ? { temperature: opts?.temperature } : void 0,
|
|
2037
|
-
max_tokens: opts?.max_tokens ??
|
|
2350
|
+
max_tokens: opts?.max_tokens ?? 512,
|
|
2038
2351
|
...opts?.top_p ? { top_p: opts?.top_p } : void 0
|
|
2039
2352
|
};
|
|
2040
2353
|
};
|