@huggingface/inference 3.9.2 → 3.11.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +9 -7
- package/dist/index.cjs +771 -646
- package/dist/index.js +770 -646
- package/dist/src/InferenceClient.d.ts +16 -17
- package/dist/src/InferenceClient.d.ts.map +1 -1
- package/dist/src/lib/getInferenceProviderMapping.d.ts +6 -2
- package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -1
- package/dist/src/lib/getProviderHelper.d.ts.map +1 -1
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/consts.d.ts.map +1 -1
- package/dist/src/providers/ovhcloud.d.ts +38 -0
- package/dist/src/providers/ovhcloud.d.ts.map +1 -0
- package/dist/src/providers/providerHelper.d.ts +1 -1
- package/dist/src/providers/providerHelper.d.ts.map +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
- package/dist/src/snippets/templates.exported.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioToAudio.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
- package/dist/src/types.d.ts +7 -5
- package/dist/src/types.d.ts.map +1 -1
- package/dist/src/utils/typedEntries.d.ts +4 -0
- package/dist/src/utils/typedEntries.d.ts.map +1 -0
- package/package.json +3 -3
- package/src/InferenceClient.ts +32 -43
- package/src/lib/getInferenceProviderMapping.ts +68 -19
- package/src/lib/getProviderHelper.ts +5 -0
- package/src/lib/makeRequestOptions.ts +4 -3
- package/src/providers/consts.ts +1 -0
- package/src/providers/ovhcloud.ts +75 -0
- package/src/providers/providerHelper.ts +1 -1
- package/src/snippets/getInferenceSnippets.ts +5 -4
- package/src/snippets/templates.exported.ts +7 -3
- package/src/tasks/audio/audioClassification.ts +3 -1
- package/src/tasks/audio/audioToAudio.ts +4 -1
- package/src/tasks/audio/automaticSpeechRecognition.ts +3 -1
- package/src/tasks/audio/textToSpeech.ts +2 -1
- package/src/tasks/custom/request.ts +3 -1
- package/src/tasks/custom/streamingRequest.ts +3 -1
- package/src/tasks/cv/imageClassification.ts +3 -1
- package/src/tasks/cv/imageSegmentation.ts +3 -1
- package/src/tasks/cv/imageToImage.ts +3 -1
- package/src/tasks/cv/imageToText.ts +3 -1
- package/src/tasks/cv/objectDetection.ts +3 -1
- package/src/tasks/cv/textToImage.ts +2 -1
- package/src/tasks/cv/textToVideo.ts +2 -1
- package/src/tasks/cv/zeroShotImageClassification.ts +3 -1
- package/src/tasks/multimodal/documentQuestionAnswering.ts +3 -1
- package/src/tasks/multimodal/visualQuestionAnswering.ts +3 -1
- package/src/tasks/nlp/chatCompletion.ts +3 -1
- package/src/tasks/nlp/chatCompletionStream.ts +3 -1
- package/src/tasks/nlp/featureExtraction.ts +3 -1
- package/src/tasks/nlp/fillMask.ts +3 -1
- package/src/tasks/nlp/questionAnswering.ts +4 -1
- package/src/tasks/nlp/sentenceSimilarity.ts +3 -1
- package/src/tasks/nlp/summarization.ts +3 -1
- package/src/tasks/nlp/tableQuestionAnswering.ts +3 -1
- package/src/tasks/nlp/textClassification.ts +3 -1
- package/src/tasks/nlp/textGeneration.ts +3 -1
- package/src/tasks/nlp/textGenerationStream.ts +3 -1
- package/src/tasks/nlp/tokenClassification.ts +3 -1
- package/src/tasks/nlp/translation.ts +3 -1
- package/src/tasks/nlp/zeroShotClassification.ts +3 -1
- package/src/tasks/tabular/tabularClassification.ts +3 -1
- package/src/tasks/tabular/tabularRegression.ts +3 -1
- package/src/types.ts +9 -4
- package/src/utils/typedEntries.ts +5 -0
package/dist/index.cjs
CHANGED
|
@@ -25,6 +25,7 @@ __export(src_exports, {
|
|
|
25
25
|
InferenceClient: () => InferenceClient,
|
|
26
26
|
InferenceClientEndpoint: () => InferenceClientEndpoint,
|
|
27
27
|
InferenceOutputError: () => InferenceOutputError,
|
|
28
|
+
PROVIDERS_OR_POLICIES: () => PROVIDERS_OR_POLICIES,
|
|
28
29
|
audioClassification: () => audioClassification,
|
|
29
30
|
audioToAudio: () => audioToAudio,
|
|
30
31
|
automaticSpeechRecognition: () => automaticSpeechRecognition,
|
|
@@ -98,6 +99,38 @@ __export(tasks_exports, {
|
|
|
98
99
|
zeroShotImageClassification: () => zeroShotImageClassification
|
|
99
100
|
});
|
|
100
101
|
|
|
102
|
+
// src/config.ts
|
|
103
|
+
var HF_HUB_URL = "https://huggingface.co";
|
|
104
|
+
var HF_ROUTER_URL = "https://router.huggingface.co";
|
|
105
|
+
var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
|
|
106
|
+
|
|
107
|
+
// src/providers/consts.ts
|
|
108
|
+
var HARDCODED_MODEL_INFERENCE_MAPPING = {
|
|
109
|
+
/**
|
|
110
|
+
* "HF model ID" => "Model ID on Inference Provider's side"
|
|
111
|
+
*
|
|
112
|
+
* Example:
|
|
113
|
+
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
|
|
114
|
+
*/
|
|
115
|
+
"black-forest-labs": {},
|
|
116
|
+
cerebras: {},
|
|
117
|
+
cohere: {},
|
|
118
|
+
"fal-ai": {},
|
|
119
|
+
"featherless-ai": {},
|
|
120
|
+
"fireworks-ai": {},
|
|
121
|
+
groq: {},
|
|
122
|
+
"hf-inference": {},
|
|
123
|
+
hyperbolic: {},
|
|
124
|
+
nebius: {},
|
|
125
|
+
novita: {},
|
|
126
|
+
nscale: {},
|
|
127
|
+
openai: {},
|
|
128
|
+
ovhcloud: {},
|
|
129
|
+
replicate: {},
|
|
130
|
+
sambanova: {},
|
|
131
|
+
together: {}
|
|
132
|
+
};
|
|
133
|
+
|
|
101
134
|
// src/lib/InferenceOutputError.ts
|
|
102
135
|
var InferenceOutputError = class extends TypeError {
|
|
103
136
|
constructor(message) {
|
|
@@ -108,42 +141,6 @@ var InferenceOutputError = class extends TypeError {
|
|
|
108
141
|
}
|
|
109
142
|
};
|
|
110
143
|
|
|
111
|
-
// src/utils/delay.ts
|
|
112
|
-
function delay(ms) {
|
|
113
|
-
return new Promise((resolve) => {
|
|
114
|
-
setTimeout(() => resolve(), ms);
|
|
115
|
-
});
|
|
116
|
-
}
|
|
117
|
-
|
|
118
|
-
// src/utils/pick.ts
|
|
119
|
-
function pick(o, props) {
|
|
120
|
-
return Object.assign(
|
|
121
|
-
{},
|
|
122
|
-
...props.map((prop) => {
|
|
123
|
-
if (o[prop] !== void 0) {
|
|
124
|
-
return { [prop]: o[prop] };
|
|
125
|
-
}
|
|
126
|
-
})
|
|
127
|
-
);
|
|
128
|
-
}
|
|
129
|
-
|
|
130
|
-
// src/utils/typedInclude.ts
|
|
131
|
-
function typedInclude(arr, v) {
|
|
132
|
-
return arr.includes(v);
|
|
133
|
-
}
|
|
134
|
-
|
|
135
|
-
// src/utils/omit.ts
|
|
136
|
-
function omit(o, props) {
|
|
137
|
-
const propsArr = Array.isArray(props) ? props : [props];
|
|
138
|
-
const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
|
|
139
|
-
return pick(o, letsKeep);
|
|
140
|
-
}
|
|
141
|
-
|
|
142
|
-
// src/config.ts
|
|
143
|
-
var HF_HUB_URL = "https://huggingface.co";
|
|
144
|
-
var HF_ROUTER_URL = "https://router.huggingface.co";
|
|
145
|
-
var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
|
|
146
|
-
|
|
147
144
|
// src/utils/toArray.ts
|
|
148
145
|
function toArray(obj) {
|
|
149
146
|
if (Array.isArray(obj)) {
|
|
@@ -238,627 +235,736 @@ var BaseTextGenerationTask = class extends TaskProviderHelper {
|
|
|
238
235
|
}
|
|
239
236
|
};
|
|
240
237
|
|
|
241
|
-
// src/providers/
|
|
242
|
-
var
|
|
243
|
-
var
|
|
238
|
+
// src/providers/hf-inference.ts
|
|
239
|
+
var EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = ["feature-extraction", "sentence-similarity"];
|
|
240
|
+
var HFInferenceTask = class extends TaskProviderHelper {
|
|
244
241
|
constructor() {
|
|
245
|
-
super("
|
|
242
|
+
super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
|
|
246
243
|
}
|
|
247
244
|
preparePayload(params) {
|
|
248
|
-
return
|
|
249
|
-
...omit(params.args, ["inputs", "parameters"]),
|
|
250
|
-
...params.args.parameters,
|
|
251
|
-
prompt: params.args.inputs
|
|
252
|
-
};
|
|
245
|
+
return params.args;
|
|
253
246
|
}
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
};
|
|
258
|
-
if (!binary) {
|
|
259
|
-
headers["Content-Type"] = "application/json";
|
|
247
|
+
makeUrl(params) {
|
|
248
|
+
if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
|
|
249
|
+
return params.model;
|
|
260
250
|
}
|
|
261
|
-
return
|
|
251
|
+
return super.makeUrl(params);
|
|
262
252
|
}
|
|
263
253
|
makeRoute(params) {
|
|
264
|
-
if (
|
|
265
|
-
|
|
254
|
+
if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
|
|
255
|
+
return `pipeline/${params.task}/${params.model}`;
|
|
266
256
|
}
|
|
267
|
-
return
|
|
257
|
+
return `models/${params.model}`;
|
|
258
|
+
}
|
|
259
|
+
async getResponse(response) {
|
|
260
|
+
return response;
|
|
268
261
|
}
|
|
262
|
+
};
|
|
263
|
+
var HFInferenceTextToImageTask = class extends HFInferenceTask {
|
|
269
264
|
async getResponse(response, url, headers, outputType) {
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
265
|
+
if (!response) {
|
|
266
|
+
throw new InferenceOutputError("response is undefined");
|
|
267
|
+
}
|
|
268
|
+
if (typeof response == "object") {
|
|
269
|
+
if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
|
|
270
|
+
const base64Data = response.data[0].b64_json;
|
|
271
|
+
if (outputType === "url") {
|
|
272
|
+
return `data:image/jpeg;base64,${base64Data}`;
|
|
273
|
+
}
|
|
274
|
+
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
|
|
275
|
+
return await base64Response.blob();
|
|
278
276
|
}
|
|
279
|
-
|
|
280
|
-
if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
|
|
277
|
+
if ("output" in response && Array.isArray(response.output)) {
|
|
281
278
|
if (outputType === "url") {
|
|
282
|
-
return
|
|
279
|
+
return response.output[0];
|
|
283
280
|
}
|
|
284
|
-
const
|
|
285
|
-
|
|
281
|
+
const urlResponse = await fetch(response.output[0]);
|
|
282
|
+
const blob = await urlResponse.blob();
|
|
283
|
+
return blob;
|
|
286
284
|
}
|
|
287
285
|
}
|
|
288
|
-
|
|
286
|
+
if (response instanceof Blob) {
|
|
287
|
+
if (outputType === "url") {
|
|
288
|
+
const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
|
|
289
|
+
return `data:image/jpeg;base64,${b64}`;
|
|
290
|
+
}
|
|
291
|
+
return response;
|
|
292
|
+
}
|
|
293
|
+
throw new InferenceOutputError("Expected a Blob ");
|
|
289
294
|
}
|
|
290
295
|
};
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
+
var HFInferenceConversationalTask = class extends HFInferenceTask {
|
|
297
|
+
makeUrl(params) {
|
|
298
|
+
let url;
|
|
299
|
+
if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
|
|
300
|
+
url = params.model.trim();
|
|
301
|
+
} else {
|
|
302
|
+
url = `${this.makeBaseUrl(params)}/models/${params.model}`;
|
|
303
|
+
}
|
|
304
|
+
url = url.replace(/\/+$/, "");
|
|
305
|
+
if (url.endsWith("/v1")) {
|
|
306
|
+
url += "/chat/completions";
|
|
307
|
+
} else if (!url.endsWith("/chat/completions")) {
|
|
308
|
+
url += "/v1/chat/completions";
|
|
309
|
+
}
|
|
310
|
+
return url;
|
|
311
|
+
}
|
|
312
|
+
preparePayload(params) {
|
|
313
|
+
return {
|
|
314
|
+
...params.args,
|
|
315
|
+
model: params.model
|
|
316
|
+
};
|
|
317
|
+
}
|
|
318
|
+
async getResponse(response) {
|
|
319
|
+
return response;
|
|
296
320
|
}
|
|
297
321
|
};
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
322
|
+
var HFInferenceTextGenerationTask = class extends HFInferenceTask {
|
|
323
|
+
async getResponse(response) {
|
|
324
|
+
const res = toArray(response);
|
|
325
|
+
if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) {
|
|
326
|
+
return res?.[0];
|
|
327
|
+
}
|
|
328
|
+
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
|
|
303
329
|
}
|
|
304
|
-
|
|
305
|
-
|
|
330
|
+
};
|
|
331
|
+
var HFInferenceAudioClassificationTask = class extends HFInferenceTask {
|
|
332
|
+
async getResponse(response) {
|
|
333
|
+
if (Array.isArray(response) && response.every(
|
|
334
|
+
(x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
|
|
335
|
+
)) {
|
|
336
|
+
return response;
|
|
337
|
+
}
|
|
338
|
+
throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
|
|
306
339
|
}
|
|
307
340
|
};
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
|
|
312
|
-
}
|
|
313
|
-
|
|
314
|
-
// src/providers/fal-ai.ts
|
|
315
|
-
var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
|
|
316
|
-
var FalAITask = class extends TaskProviderHelper {
|
|
317
|
-
constructor(url) {
|
|
318
|
-
super("fal-ai", url || "https://fal.run");
|
|
341
|
+
var HFInferenceAutomaticSpeechRecognitionTask = class extends HFInferenceTask {
|
|
342
|
+
async getResponse(response) {
|
|
343
|
+
return response;
|
|
319
344
|
}
|
|
320
|
-
|
|
321
|
-
|
|
345
|
+
};
|
|
346
|
+
var HFInferenceAudioToAudioTask = class extends HFInferenceTask {
|
|
347
|
+
async getResponse(response) {
|
|
348
|
+
if (!Array.isArray(response)) {
|
|
349
|
+
throw new InferenceOutputError("Expected Array");
|
|
350
|
+
}
|
|
351
|
+
if (!response.every((elem) => {
|
|
352
|
+
return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
|
|
353
|
+
})) {
|
|
354
|
+
throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
|
|
355
|
+
}
|
|
356
|
+
return response;
|
|
322
357
|
}
|
|
323
|
-
|
|
324
|
-
|
|
358
|
+
};
|
|
359
|
+
var HFInferenceDocumentQuestionAnsweringTask = class extends HFInferenceTask {
|
|
360
|
+
async getResponse(response) {
|
|
361
|
+
if (Array.isArray(response) && response.every(
|
|
362
|
+
(elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
|
|
363
|
+
)) {
|
|
364
|
+
return response[0];
|
|
365
|
+
}
|
|
366
|
+
throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
|
|
325
367
|
}
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
368
|
+
};
|
|
369
|
+
var HFInferenceFeatureExtractionTask = class extends HFInferenceTask {
|
|
370
|
+
async getResponse(response) {
|
|
371
|
+
const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
|
|
372
|
+
if (curDepth > maxDepth)
|
|
373
|
+
return false;
|
|
374
|
+
if (arr.every((x) => Array.isArray(x))) {
|
|
375
|
+
return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
|
|
376
|
+
} else {
|
|
377
|
+
return arr.every((x) => typeof x === "number");
|
|
378
|
+
}
|
|
329
379
|
};
|
|
330
|
-
if (
|
|
331
|
-
|
|
380
|
+
if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) {
|
|
381
|
+
return response;
|
|
332
382
|
}
|
|
333
|
-
|
|
383
|
+
throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
|
|
334
384
|
}
|
|
335
385
|
};
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
preparePayload(params) {
|
|
341
|
-
const payload = {
|
|
342
|
-
...omit(params.args, ["inputs", "parameters"]),
|
|
343
|
-
...params.args.parameters,
|
|
344
|
-
sync_mode: true,
|
|
345
|
-
prompt: params.args.inputs
|
|
346
|
-
};
|
|
347
|
-
if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
|
|
348
|
-
payload.loras = [
|
|
349
|
-
{
|
|
350
|
-
path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
|
|
351
|
-
scale: 1
|
|
352
|
-
}
|
|
353
|
-
];
|
|
354
|
-
if (params.mapping.providerId === "fal-ai/lora") {
|
|
355
|
-
payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0";
|
|
356
|
-
}
|
|
386
|
+
var HFInferenceImageClassificationTask = class extends HFInferenceTask {
|
|
387
|
+
async getResponse(response) {
|
|
388
|
+
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
|
|
389
|
+
return response;
|
|
357
390
|
}
|
|
358
|
-
|
|
391
|
+
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
359
392
|
}
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
const urlResponse = await fetch(response.images[0].url);
|
|
366
|
-
return await urlResponse.blob();
|
|
393
|
+
};
|
|
394
|
+
var HFInferenceImageSegmentationTask = class extends HFInferenceTask {
|
|
395
|
+
async getResponse(response) {
|
|
396
|
+
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number")) {
|
|
397
|
+
return response;
|
|
367
398
|
}
|
|
368
|
-
throw new InferenceOutputError("Expected
|
|
399
|
+
throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
|
|
369
400
|
}
|
|
370
401
|
};
|
|
371
|
-
var
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
makeRoute(params) {
|
|
376
|
-
if (params.authMethod !== "provider-key") {
|
|
377
|
-
return `/${params.model}?_subdomain=queue`;
|
|
402
|
+
var HFInferenceImageToTextTask = class extends HFInferenceTask {
|
|
403
|
+
async getResponse(response) {
|
|
404
|
+
if (typeof response?.generated_text !== "string") {
|
|
405
|
+
throw new InferenceOutputError("Expected {generated_text: string}");
|
|
378
406
|
}
|
|
379
|
-
return
|
|
407
|
+
return response;
|
|
380
408
|
}
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
}
|
|
409
|
+
};
|
|
410
|
+
var HFInferenceImageToImageTask = class extends HFInferenceTask {
|
|
411
|
+
async getResponse(response) {
|
|
412
|
+
if (response instanceof Blob) {
|
|
413
|
+
return response;
|
|
414
|
+
}
|
|
415
|
+
throw new InferenceOutputError("Expected Blob");
|
|
387
416
|
}
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
417
|
+
};
|
|
418
|
+
var HFInferenceObjectDetectionTask = class extends HFInferenceTask {
|
|
419
|
+
async getResponse(response) {
|
|
420
|
+
if (Array.isArray(response) && response.every(
|
|
421
|
+
(x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
|
|
422
|
+
)) {
|
|
423
|
+
return response;
|
|
391
424
|
}
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
425
|
+
throw new InferenceOutputError(
|
|
426
|
+
"Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
|
|
427
|
+
);
|
|
428
|
+
}
|
|
429
|
+
};
|
|
430
|
+
var HFInferenceZeroShotImageClassificationTask = class extends HFInferenceTask {
|
|
431
|
+
async getResponse(response) {
|
|
432
|
+
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
|
|
433
|
+
return response;
|
|
395
434
|
}
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
const
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
await delay(500);
|
|
405
|
-
const statusResponse = await fetch(statusUrl, { headers });
|
|
406
|
-
if (!statusResponse.ok) {
|
|
407
|
-
throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
|
|
408
|
-
}
|
|
409
|
-
try {
|
|
410
|
-
status = (await statusResponse.json()).status;
|
|
411
|
-
} catch (error) {
|
|
412
|
-
throw new InferenceOutputError("Failed to parse status response from fal-ai API");
|
|
413
|
-
}
|
|
435
|
+
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
436
|
+
}
|
|
437
|
+
};
|
|
438
|
+
var HFInferenceTextClassificationTask = class extends HFInferenceTask {
|
|
439
|
+
async getResponse(response) {
|
|
440
|
+
const output = response?.[0];
|
|
441
|
+
if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
|
|
442
|
+
return output;
|
|
414
443
|
}
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
444
|
+
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
445
|
+
}
|
|
446
|
+
};
|
|
447
|
+
var HFInferenceQuestionAnsweringTask = class extends HFInferenceTask {
|
|
448
|
+
async getResponse(response) {
|
|
449
|
+
if (Array.isArray(response) ? response.every(
|
|
450
|
+
(elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
|
|
451
|
+
) : typeof response === "object" && !!response && typeof response.answer === "string" && typeof response.end === "number" && typeof response.score === "number" && typeof response.start === "number") {
|
|
452
|
+
return Array.isArray(response) ? response[0] : response;
|
|
421
453
|
}
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
)
|
|
454
|
+
throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
|
|
455
|
+
}
|
|
456
|
+
};
|
|
457
|
+
var HFInferenceFillMaskTask = class extends HFInferenceTask {
|
|
458
|
+
async getResponse(response) {
|
|
459
|
+
if (Array.isArray(response) && response.every(
|
|
460
|
+
(x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
|
|
461
|
+
)) {
|
|
462
|
+
return response;
|
|
429
463
|
}
|
|
464
|
+
throw new InferenceOutputError(
|
|
465
|
+
"Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
|
|
466
|
+
);
|
|
430
467
|
}
|
|
431
468
|
};
|
|
432
|
-
var
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
469
|
+
var HFInferenceZeroShotClassificationTask = class extends HFInferenceTask {
|
|
470
|
+
async getResponse(response) {
|
|
471
|
+
if (Array.isArray(response) && response.every(
|
|
472
|
+
(x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
|
|
473
|
+
)) {
|
|
474
|
+
return response;
|
|
475
|
+
}
|
|
476
|
+
throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
|
|
437
477
|
}
|
|
478
|
+
};
|
|
479
|
+
var HFInferenceSentenceSimilarityTask = class extends HFInferenceTask {
|
|
438
480
|
async getResponse(response) {
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
throw new InferenceOutputError(
|
|
442
|
-
`Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
|
|
443
|
-
);
|
|
481
|
+
if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
|
|
482
|
+
return response;
|
|
444
483
|
}
|
|
445
|
-
|
|
484
|
+
throw new InferenceOutputError("Expected Array<number>");
|
|
446
485
|
}
|
|
447
486
|
};
|
|
448
|
-
var
|
|
449
|
-
|
|
450
|
-
return
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
text: params.args.inputs
|
|
454
|
-
};
|
|
487
|
+
var HFInferenceTableQuestionAnsweringTask = class extends HFInferenceTask {
|
|
488
|
+
static validate(elem) {
|
|
489
|
+
return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
|
|
490
|
+
(coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
|
|
491
|
+
);
|
|
455
492
|
}
|
|
456
493
|
async getResponse(response) {
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
throw new InferenceOutputError(
|
|
460
|
-
`Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
|
|
461
|
-
);
|
|
494
|
+
if (Array.isArray(response) && Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response)) {
|
|
495
|
+
return Array.isArray(response) ? response[0] : response;
|
|
462
496
|
}
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
497
|
+
throw new InferenceOutputError(
|
|
498
|
+
"Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
|
|
499
|
+
);
|
|
500
|
+
}
|
|
501
|
+
};
|
|
502
|
+
var HFInferenceTokenClassificationTask = class extends HFInferenceTask {
|
|
503
|
+
async getResponse(response) {
|
|
504
|
+
if (Array.isArray(response) && response.every(
|
|
505
|
+
(x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
|
|
506
|
+
)) {
|
|
507
|
+
return response;
|
|
473
508
|
}
|
|
509
|
+
throw new InferenceOutputError(
|
|
510
|
+
"Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
|
|
511
|
+
);
|
|
474
512
|
}
|
|
475
513
|
};
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
514
|
+
var HFInferenceTranslationTask = class extends HFInferenceTask {
|
|
515
|
+
async getResponse(response) {
|
|
516
|
+
if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) {
|
|
517
|
+
return response?.length === 1 ? response?.[0] : response;
|
|
518
|
+
}
|
|
519
|
+
throw new InferenceOutputError("Expected Array<{translation_text: string}>");
|
|
482
520
|
}
|
|
483
521
|
};
|
|
484
|
-
var
|
|
485
|
-
|
|
486
|
-
|
|
522
|
+
var HFInferenceSummarizationTask = class extends HFInferenceTask {
|
|
523
|
+
async getResponse(response) {
|
|
524
|
+
if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) {
|
|
525
|
+
return response?.[0];
|
|
526
|
+
}
|
|
527
|
+
throw new InferenceOutputError("Expected Array<{summary_text: string}>");
|
|
487
528
|
}
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
model: params.model,
|
|
493
|
-
prompt: params.args.inputs
|
|
494
|
-
};
|
|
529
|
+
};
|
|
530
|
+
var HFInferenceTextToSpeechTask = class extends HFInferenceTask {
|
|
531
|
+
async getResponse(response) {
|
|
532
|
+
return response;
|
|
495
533
|
}
|
|
534
|
+
};
|
|
535
|
+
var HFInferenceTabularClassificationTask = class extends HFInferenceTask {
|
|
496
536
|
async getResponse(response) {
|
|
497
|
-
if (
|
|
498
|
-
|
|
499
|
-
return {
|
|
500
|
-
generated_text: completion.text
|
|
501
|
-
};
|
|
537
|
+
if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
|
|
538
|
+
return response;
|
|
502
539
|
}
|
|
503
|
-
throw new InferenceOutputError("Expected
|
|
540
|
+
throw new InferenceOutputError("Expected Array<number>");
|
|
504
541
|
}
|
|
505
542
|
};
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
543
|
+
var HFInferenceVisualQuestionAnsweringTask = class extends HFInferenceTask {
|
|
544
|
+
async getResponse(response) {
|
|
545
|
+
if (Array.isArray(response) && response.every(
|
|
546
|
+
(elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
|
|
547
|
+
)) {
|
|
548
|
+
return response[0];
|
|
549
|
+
}
|
|
550
|
+
throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
|
|
511
551
|
}
|
|
512
|
-
|
|
513
|
-
|
|
552
|
+
};
|
|
553
|
+
var HFInferenceTabularRegressionTask = class extends HFInferenceTask {
|
|
554
|
+
async getResponse(response) {
|
|
555
|
+
if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
|
|
556
|
+
return response;
|
|
557
|
+
}
|
|
558
|
+
throw new InferenceOutputError("Expected Array<number>");
|
|
559
|
+
}
|
|
560
|
+
};
|
|
561
|
+
var HFInferenceTextToAudioTask = class extends HFInferenceTask {
|
|
562
|
+
async getResponse(response) {
|
|
563
|
+
return response;
|
|
514
564
|
}
|
|
515
565
|
};
|
|
516
566
|
|
|
517
|
-
// src/
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
567
|
+
// src/utils/typedInclude.ts
|
|
568
|
+
function typedInclude(arr, v) {
|
|
569
|
+
return arr.includes(v);
|
|
570
|
+
}
|
|
571
|
+
|
|
572
|
+
// src/lib/getInferenceProviderMapping.ts
|
|
573
|
+
var inferenceProviderMappingCache = /* @__PURE__ */ new Map();
|
|
574
|
+
async function fetchInferenceProviderMappingForModel(modelId, accessToken, options) {
|
|
575
|
+
let inferenceProviderMapping;
|
|
576
|
+
if (inferenceProviderMappingCache.has(modelId)) {
|
|
577
|
+
inferenceProviderMapping = inferenceProviderMappingCache.get(modelId);
|
|
578
|
+
} else {
|
|
579
|
+
const resp = await (options?.fetch ?? fetch)(
|
|
580
|
+
`${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
|
|
581
|
+
{
|
|
582
|
+
headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {}
|
|
583
|
+
}
|
|
584
|
+
);
|
|
585
|
+
if (resp.status === 404) {
|
|
586
|
+
throw new Error(`Model ${modelId} does not exist`);
|
|
587
|
+
}
|
|
588
|
+
inferenceProviderMapping = await resp.json().then((json) => json.inferenceProviderMapping).catch(() => null);
|
|
589
|
+
if (inferenceProviderMapping) {
|
|
590
|
+
inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
|
|
591
|
+
}
|
|
522
592
|
}
|
|
523
|
-
|
|
524
|
-
|
|
593
|
+
if (!inferenceProviderMapping) {
|
|
594
|
+
throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
|
|
525
595
|
}
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
596
|
+
return inferenceProviderMapping;
|
|
597
|
+
}
|
|
598
|
+
async function getInferenceProviderMapping(params, options) {
|
|
599
|
+
if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
|
|
600
|
+
return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
|
|
530
601
|
}
|
|
531
|
-
|
|
532
|
-
|
|
602
|
+
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
|
|
603
|
+
params.modelId,
|
|
604
|
+
params.accessToken,
|
|
605
|
+
options
|
|
606
|
+
);
|
|
607
|
+
const providerMapping = inferenceProviderMapping[params.provider];
|
|
608
|
+
if (providerMapping) {
|
|
609
|
+
const equivalentTasks = params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task) ? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS : [params.task];
|
|
610
|
+
if (!typedInclude(equivalentTasks, providerMapping.task)) {
|
|
611
|
+
throw new Error(
|
|
612
|
+
`Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
|
|
613
|
+
);
|
|
614
|
+
}
|
|
615
|
+
if (providerMapping.status === "staging") {
|
|
616
|
+
console.warn(
|
|
617
|
+
`Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
|
|
618
|
+
);
|
|
619
|
+
}
|
|
620
|
+
return { ...providerMapping, hfModelId: params.modelId };
|
|
533
621
|
}
|
|
534
|
-
|
|
622
|
+
return null;
|
|
623
|
+
}
|
|
624
|
+
async function resolveProvider(provider, modelId, endpointUrl) {
|
|
625
|
+
if (endpointUrl) {
|
|
626
|
+
if (provider) {
|
|
627
|
+
throw new Error("Specifying both endpointUrl and provider is not supported.");
|
|
628
|
+
}
|
|
629
|
+
return "hf-inference";
|
|
630
|
+
}
|
|
631
|
+
if (!provider) {
|
|
632
|
+
console.log(
|
|
633
|
+
"Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
|
|
634
|
+
);
|
|
635
|
+
provider = "auto";
|
|
636
|
+
}
|
|
637
|
+
if (provider === "auto") {
|
|
638
|
+
if (!modelId) {
|
|
639
|
+
throw new Error("Specifying a model is required when provider is 'auto'");
|
|
640
|
+
}
|
|
641
|
+
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
|
|
642
|
+
provider = Object.keys(inferenceProviderMapping)[0];
|
|
643
|
+
}
|
|
644
|
+
if (!provider) {
|
|
645
|
+
throw new Error(`No Inference Provider available for model ${modelId}.`);
|
|
646
|
+
}
|
|
647
|
+
return provider;
|
|
648
|
+
}
|
|
535
649
|
|
|
536
|
-
// src/
|
|
537
|
-
|
|
538
|
-
|
|
650
|
+
// src/utils/delay.ts
|
|
651
|
+
function delay(ms) {
|
|
652
|
+
return new Promise((resolve) => {
|
|
653
|
+
setTimeout(() => resolve(), ms);
|
|
654
|
+
});
|
|
655
|
+
}
|
|
656
|
+
|
|
657
|
+
// src/utils/pick.ts
|
|
658
|
+
function pick(o, props) {
|
|
659
|
+
return Object.assign(
|
|
660
|
+
{},
|
|
661
|
+
...props.map((prop) => {
|
|
662
|
+
if (o[prop] !== void 0) {
|
|
663
|
+
return { [prop]: o[prop] };
|
|
664
|
+
}
|
|
665
|
+
})
|
|
666
|
+
);
|
|
667
|
+
}
|
|
668
|
+
|
|
669
|
+
// src/utils/omit.ts
|
|
670
|
+
function omit(o, props) {
|
|
671
|
+
const propsArr = Array.isArray(props) ? props : [props];
|
|
672
|
+
const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
|
|
673
|
+
return pick(o, letsKeep);
|
|
674
|
+
}
|
|
675
|
+
|
|
676
|
+
// src/providers/black-forest-labs.ts
|
|
677
|
+
var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
|
|
678
|
+
var BlackForestLabsTextToImageTask = class extends TaskProviderHelper {
|
|
539
679
|
constructor() {
|
|
540
|
-
super("
|
|
680
|
+
super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
|
|
541
681
|
}
|
|
542
682
|
preparePayload(params) {
|
|
543
|
-
return
|
|
683
|
+
return {
|
|
684
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
685
|
+
...params.args.parameters,
|
|
686
|
+
prompt: params.args.inputs
|
|
687
|
+
};
|
|
544
688
|
}
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
689
|
+
prepareHeaders(params, binary) {
|
|
690
|
+
const headers = {
|
|
691
|
+
Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`
|
|
692
|
+
};
|
|
693
|
+
if (!binary) {
|
|
694
|
+
headers["Content-Type"] = "application/json";
|
|
548
695
|
}
|
|
549
|
-
return
|
|
696
|
+
return headers;
|
|
550
697
|
}
|
|
551
698
|
makeRoute(params) {
|
|
552
|
-
if (params
|
|
553
|
-
|
|
699
|
+
if (!params) {
|
|
700
|
+
throw new Error("Params are required");
|
|
554
701
|
}
|
|
555
|
-
return
|
|
556
|
-
}
|
|
557
|
-
async getResponse(response) {
|
|
558
|
-
return response;
|
|
702
|
+
return `/v1/${params.model}`;
|
|
559
703
|
}
|
|
560
|
-
};
|
|
561
|
-
var HFInferenceTextToImageTask = class extends HFInferenceTask {
|
|
562
704
|
async getResponse(response, url, headers, outputType) {
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
}
|
|
572
|
-
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
|
|
573
|
-
return await base64Response.blob();
|
|
705
|
+
const urlObj = new URL(response.polling_url);
|
|
706
|
+
for (let step = 0; step < 5; step++) {
|
|
707
|
+
await delay(1e3);
|
|
708
|
+
console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
|
|
709
|
+
urlObj.searchParams.set("attempt", step.toString(10));
|
|
710
|
+
const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
|
|
711
|
+
if (!resp.ok) {
|
|
712
|
+
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
|
|
574
713
|
}
|
|
575
|
-
|
|
714
|
+
const payload = await resp.json();
|
|
715
|
+
if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
|
|
576
716
|
if (outputType === "url") {
|
|
577
|
-
return
|
|
717
|
+
return payload.result.sample;
|
|
578
718
|
}
|
|
579
|
-
const
|
|
580
|
-
|
|
581
|
-
return blob;
|
|
582
|
-
}
|
|
583
|
-
}
|
|
584
|
-
if (response instanceof Blob) {
|
|
585
|
-
if (outputType === "url") {
|
|
586
|
-
const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
|
|
587
|
-
return `data:image/jpeg;base64,${b64}`;
|
|
719
|
+
const image = await fetch(payload.result.sample);
|
|
720
|
+
return await image.blob();
|
|
588
721
|
}
|
|
589
|
-
return response;
|
|
590
722
|
}
|
|
591
|
-
throw new InferenceOutputError("
|
|
723
|
+
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
|
|
592
724
|
}
|
|
593
725
|
};
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
} else {
|
|
600
|
-
url = `${this.makeBaseUrl(params)}/models/${params.model}`;
|
|
601
|
-
}
|
|
602
|
-
url = url.replace(/\/+$/, "");
|
|
603
|
-
if (url.endsWith("/v1")) {
|
|
604
|
-
url += "/chat/completions";
|
|
605
|
-
} else if (!url.endsWith("/chat/completions")) {
|
|
606
|
-
url += "/v1/chat/completions";
|
|
607
|
-
}
|
|
608
|
-
return url;
|
|
609
|
-
}
|
|
610
|
-
preparePayload(params) {
|
|
611
|
-
return {
|
|
612
|
-
...params.args,
|
|
613
|
-
model: params.model
|
|
614
|
-
};
|
|
615
|
-
}
|
|
616
|
-
async getResponse(response) {
|
|
617
|
-
return response;
|
|
726
|
+
|
|
727
|
+
// src/providers/cerebras.ts
|
|
728
|
+
var CerebrasConversationalTask = class extends BaseConversationalTask {
|
|
729
|
+
constructor() {
|
|
730
|
+
super("cerebras", "https://api.cerebras.ai");
|
|
618
731
|
}
|
|
619
732
|
};
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
}
|
|
626
|
-
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
|
|
733
|
+
|
|
734
|
+
// src/providers/cohere.ts
|
|
735
|
+
var CohereConversationalTask = class extends BaseConversationalTask {
|
|
736
|
+
constructor() {
|
|
737
|
+
super("cohere", "https://api.cohere.com");
|
|
627
738
|
}
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
async getResponse(response) {
|
|
631
|
-
if (Array.isArray(response) && response.every(
|
|
632
|
-
(x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
|
|
633
|
-
)) {
|
|
634
|
-
return response;
|
|
635
|
-
}
|
|
636
|
-
throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
|
|
739
|
+
makeRoute() {
|
|
740
|
+
return "/compatibility/v1/chat/completions";
|
|
637
741
|
}
|
|
638
742
|
};
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
743
|
+
|
|
744
|
+
// src/lib/isUrl.ts
|
|
745
|
+
function isUrl(modelOrUrl) {
|
|
746
|
+
return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
|
|
747
|
+
}
|
|
748
|
+
|
|
749
|
+
// src/providers/fal-ai.ts
|
|
750
|
+
var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
|
|
751
|
+
var FalAITask = class extends TaskProviderHelper {
|
|
752
|
+
constructor(url) {
|
|
753
|
+
super("fal-ai", url || "https://fal.run");
|
|
642
754
|
}
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
async getResponse(response) {
|
|
646
|
-
if (!Array.isArray(response)) {
|
|
647
|
-
throw new InferenceOutputError("Expected Array");
|
|
648
|
-
}
|
|
649
|
-
if (!response.every((elem) => {
|
|
650
|
-
return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
|
|
651
|
-
})) {
|
|
652
|
-
throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
|
|
653
|
-
}
|
|
654
|
-
return response;
|
|
755
|
+
preparePayload(params) {
|
|
756
|
+
return params.args;
|
|
655
757
|
}
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
async getResponse(response) {
|
|
659
|
-
if (Array.isArray(response) && response.every(
|
|
660
|
-
(elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
|
|
661
|
-
)) {
|
|
662
|
-
return response[0];
|
|
663
|
-
}
|
|
664
|
-
throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
|
|
758
|
+
makeRoute(params) {
|
|
759
|
+
return `/${params.model}`;
|
|
665
760
|
}
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
|
|
670
|
-
if (curDepth > maxDepth)
|
|
671
|
-
return false;
|
|
672
|
-
if (arr.every((x) => Array.isArray(x))) {
|
|
673
|
-
return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
|
|
674
|
-
} else {
|
|
675
|
-
return arr.every((x) => typeof x === "number");
|
|
676
|
-
}
|
|
761
|
+
prepareHeaders(params, binary) {
|
|
762
|
+
const headers = {
|
|
763
|
+
Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}`
|
|
677
764
|
};
|
|
678
|
-
if (
|
|
679
|
-
|
|
765
|
+
if (!binary) {
|
|
766
|
+
headers["Content-Type"] = "application/json";
|
|
680
767
|
}
|
|
681
|
-
|
|
768
|
+
return headers;
|
|
682
769
|
}
|
|
683
770
|
};
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
771
|
+
function buildLoraPath(modelId, adapterWeightsPath) {
|
|
772
|
+
return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
|
|
773
|
+
}
|
|
774
|
+
var FalAITextToImageTask = class extends FalAITask {
|
|
775
|
+
preparePayload(params) {
|
|
776
|
+
const payload = {
|
|
777
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
778
|
+
...params.args.parameters,
|
|
779
|
+
sync_mode: true,
|
|
780
|
+
prompt: params.args.inputs
|
|
781
|
+
};
|
|
782
|
+
if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
|
|
783
|
+
payload.loras = [
|
|
784
|
+
{
|
|
785
|
+
path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
|
|
786
|
+
scale: 1
|
|
787
|
+
}
|
|
788
|
+
];
|
|
789
|
+
if (params.mapping.providerId === "fal-ai/lora") {
|
|
790
|
+
payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0";
|
|
791
|
+
}
|
|
688
792
|
}
|
|
689
|
-
|
|
793
|
+
return payload;
|
|
690
794
|
}
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
795
|
+
async getResponse(response, outputType) {
|
|
796
|
+
if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images.length > 0 && "url" in response.images[0] && typeof response.images[0].url === "string") {
|
|
797
|
+
if (outputType === "url") {
|
|
798
|
+
return response.images[0].url;
|
|
799
|
+
}
|
|
800
|
+
const urlResponse = await fetch(response.images[0].url);
|
|
801
|
+
return await urlResponse.blob();
|
|
696
802
|
}
|
|
697
|
-
throw new InferenceOutputError("Expected
|
|
803
|
+
throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
|
|
698
804
|
}
|
|
699
805
|
};
|
|
700
|
-
var
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
throw new InferenceOutputError("Expected {generated_text: string}");
|
|
704
|
-
}
|
|
705
|
-
return response;
|
|
806
|
+
var FalAITextToVideoTask = class extends FalAITask {
|
|
807
|
+
constructor() {
|
|
808
|
+
super("https://queue.fal.run");
|
|
706
809
|
}
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
if (response instanceof Blob) {
|
|
711
|
-
return response;
|
|
810
|
+
makeRoute(params) {
|
|
811
|
+
if (params.authMethod !== "provider-key") {
|
|
812
|
+
return `/${params.model}?_subdomain=queue`;
|
|
712
813
|
}
|
|
713
|
-
|
|
814
|
+
return `/${params.model}`;
|
|
714
815
|
}
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
return response;
|
|
722
|
-
}
|
|
723
|
-
throw new InferenceOutputError(
|
|
724
|
-
"Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
|
|
725
|
-
);
|
|
816
|
+
preparePayload(params) {
|
|
817
|
+
return {
|
|
818
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
819
|
+
...params.args.parameters,
|
|
820
|
+
prompt: params.args.inputs
|
|
821
|
+
};
|
|
726
822
|
}
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
|
|
731
|
-
return response;
|
|
823
|
+
async getResponse(response, url, headers) {
|
|
824
|
+
if (!url || !headers) {
|
|
825
|
+
throw new InferenceOutputError("URL and headers are required for text-to-video task");
|
|
732
826
|
}
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
var HFInferenceTextClassificationTask = class extends HFInferenceTask {
|
|
737
|
-
async getResponse(response) {
|
|
738
|
-
const output = response?.[0];
|
|
739
|
-
if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
|
|
740
|
-
return output;
|
|
827
|
+
const requestId = response.request_id;
|
|
828
|
+
if (!requestId) {
|
|
829
|
+
throw new InferenceOutputError("No request ID found in the response");
|
|
741
830
|
}
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
}
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
831
|
+
let status = response.status;
|
|
832
|
+
const parsedUrl = new URL(url);
|
|
833
|
+
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
|
|
834
|
+
const modelId = new URL(response.response_url).pathname;
|
|
835
|
+
const queryParams = parsedUrl.search;
|
|
836
|
+
const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
|
|
837
|
+
const resultUrl = `${baseUrl}${modelId}${queryParams}`;
|
|
838
|
+
while (status !== "COMPLETED") {
|
|
839
|
+
await delay(500);
|
|
840
|
+
const statusResponse = await fetch(statusUrl, { headers });
|
|
841
|
+
if (!statusResponse.ok) {
|
|
842
|
+
throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
|
|
843
|
+
}
|
|
844
|
+
try {
|
|
845
|
+
status = (await statusResponse.json()).status;
|
|
846
|
+
} catch (error) {
|
|
847
|
+
throw new InferenceOutputError("Failed to parse status response from fal-ai API");
|
|
848
|
+
}
|
|
751
849
|
}
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
(x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
|
|
759
|
-
)) {
|
|
760
|
-
return response;
|
|
850
|
+
const resultResponse = await fetch(resultUrl, { headers });
|
|
851
|
+
let result;
|
|
852
|
+
try {
|
|
853
|
+
result = await resultResponse.json();
|
|
854
|
+
} catch (error) {
|
|
855
|
+
throw new InferenceOutputError("Failed to parse result response from fal-ai API");
|
|
761
856
|
}
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
if (Array.isArray(response) && response.every(
|
|
770
|
-
(x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
|
|
771
|
-
)) {
|
|
772
|
-
return response;
|
|
857
|
+
if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
|
|
858
|
+
const urlResponse = await fetch(result.video.url);
|
|
859
|
+
return await urlResponse.blob();
|
|
860
|
+
} else {
|
|
861
|
+
throw new InferenceOutputError(
|
|
862
|
+
"Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
|
|
863
|
+
);
|
|
773
864
|
}
|
|
774
|
-
throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
|
|
775
865
|
}
|
|
776
866
|
};
|
|
777
|
-
var
|
|
867
|
+
var FalAIAutomaticSpeechRecognitionTask = class extends FalAITask {
|
|
868
|
+
prepareHeaders(params, binary) {
|
|
869
|
+
const headers = super.prepareHeaders(params, binary);
|
|
870
|
+
headers["Content-Type"] = "application/json";
|
|
871
|
+
return headers;
|
|
872
|
+
}
|
|
778
873
|
async getResponse(response) {
|
|
779
|
-
|
|
780
|
-
|
|
874
|
+
const res = response;
|
|
875
|
+
if (typeof res?.text !== "string") {
|
|
876
|
+
throw new InferenceOutputError(
|
|
877
|
+
`Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
|
|
878
|
+
);
|
|
781
879
|
}
|
|
782
|
-
|
|
880
|
+
return { text: res.text };
|
|
783
881
|
}
|
|
784
882
|
};
|
|
785
|
-
var
|
|
786
|
-
|
|
787
|
-
return
|
|
788
|
-
(
|
|
789
|
-
|
|
883
|
+
var FalAITextToSpeechTask = class extends FalAITask {
|
|
884
|
+
preparePayload(params) {
|
|
885
|
+
return {
|
|
886
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
887
|
+
...params.args.parameters,
|
|
888
|
+
text: params.args.inputs
|
|
889
|
+
};
|
|
790
890
|
}
|
|
791
891
|
async getResponse(response) {
|
|
792
|
-
|
|
793
|
-
|
|
892
|
+
const res = response;
|
|
893
|
+
if (typeof res?.audio?.url !== "string") {
|
|
894
|
+
throw new InferenceOutputError(
|
|
895
|
+
`Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
|
|
896
|
+
);
|
|
794
897
|
}
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
}
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
return response;
|
|
898
|
+
try {
|
|
899
|
+
const urlResponse = await fetch(res.audio.url);
|
|
900
|
+
if (!urlResponse.ok) {
|
|
901
|
+
throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`);
|
|
902
|
+
}
|
|
903
|
+
return await urlResponse.blob();
|
|
904
|
+
} catch (error) {
|
|
905
|
+
throw new InferenceOutputError(
|
|
906
|
+
`Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${error instanceof Error ? error.message : String(error)}`
|
|
907
|
+
);
|
|
806
908
|
}
|
|
807
|
-
throw new InferenceOutputError(
|
|
808
|
-
"Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
|
|
809
|
-
);
|
|
810
909
|
}
|
|
811
910
|
};
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
911
|
+
|
|
912
|
+
// src/providers/featherless-ai.ts
|
|
913
|
+
var FEATHERLESS_API_BASE_URL = "https://api.featherless.ai";
|
|
914
|
+
var FeatherlessAIConversationalTask = class extends BaseConversationalTask {
|
|
915
|
+
constructor() {
|
|
916
|
+
super("featherless-ai", FEATHERLESS_API_BASE_URL);
|
|
818
917
|
}
|
|
819
918
|
};
|
|
820
|
-
var
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
return response?.[0];
|
|
824
|
-
}
|
|
825
|
-
throw new InferenceOutputError("Expected Array<{summary_text: string}>");
|
|
919
|
+
var FeatherlessAITextGenerationTask = class extends BaseTextGenerationTask {
|
|
920
|
+
constructor() {
|
|
921
|
+
super("featherless-ai", FEATHERLESS_API_BASE_URL);
|
|
826
922
|
}
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
923
|
+
preparePayload(params) {
|
|
924
|
+
return {
|
|
925
|
+
...params.args,
|
|
926
|
+
...params.args.parameters,
|
|
927
|
+
model: params.model,
|
|
928
|
+
prompt: params.args.inputs
|
|
929
|
+
};
|
|
831
930
|
}
|
|
832
|
-
};
|
|
833
|
-
var HFInferenceTabularClassificationTask = class extends HFInferenceTask {
|
|
834
931
|
async getResponse(response) {
|
|
835
|
-
if (
|
|
836
|
-
|
|
932
|
+
if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
|
|
933
|
+
const completion = response.choices[0];
|
|
934
|
+
return {
|
|
935
|
+
generated_text: completion.text
|
|
936
|
+
};
|
|
837
937
|
}
|
|
838
|
-
throw new InferenceOutputError("Expected
|
|
938
|
+
throw new InferenceOutputError("Expected Featherless AI text generation response format");
|
|
839
939
|
}
|
|
840
940
|
};
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
)
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
941
|
+
|
|
942
|
+
// src/providers/fireworks-ai.ts
|
|
943
|
+
var FireworksConversationalTask = class extends BaseConversationalTask {
|
|
944
|
+
constructor() {
|
|
945
|
+
super("fireworks-ai", "https://api.fireworks.ai");
|
|
946
|
+
}
|
|
947
|
+
makeRoute() {
|
|
948
|
+
return "/inference/v1/chat/completions";
|
|
849
949
|
}
|
|
850
950
|
};
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
951
|
+
|
|
952
|
+
// src/providers/groq.ts
|
|
953
|
+
var GROQ_API_BASE_URL = "https://api.groq.com";
|
|
954
|
+
var GroqTextGenerationTask = class extends BaseTextGenerationTask {
|
|
955
|
+
constructor() {
|
|
956
|
+
super("groq", GROQ_API_BASE_URL);
|
|
957
|
+
}
|
|
958
|
+
makeRoute() {
|
|
959
|
+
return "/openai/v1/chat/completions";
|
|
857
960
|
}
|
|
858
961
|
};
|
|
859
|
-
var
|
|
860
|
-
|
|
861
|
-
|
|
962
|
+
var GroqConversationalTask = class extends BaseConversationalTask {
|
|
963
|
+
constructor() {
|
|
964
|
+
super("groq", GROQ_API_BASE_URL);
|
|
965
|
+
}
|
|
966
|
+
makeRoute() {
|
|
967
|
+
return "/openai/v1/chat/completions";
|
|
862
968
|
}
|
|
863
969
|
};
|
|
864
970
|
|
|
@@ -1025,6 +1131,39 @@ var OpenAIConversationalTask = class extends BaseConversationalTask {
|
|
|
1025
1131
|
}
|
|
1026
1132
|
};
|
|
1027
1133
|
|
|
1134
|
+
// src/providers/ovhcloud.ts
|
|
1135
|
+
var OVHCLOUD_API_BASE_URL = "https://oai.endpoints.kepler.ai.cloud.ovh.net";
|
|
1136
|
+
var OvhCloudConversationalTask = class extends BaseConversationalTask {
|
|
1137
|
+
constructor() {
|
|
1138
|
+
super("ovhcloud", OVHCLOUD_API_BASE_URL);
|
|
1139
|
+
}
|
|
1140
|
+
};
|
|
1141
|
+
var OvhCloudTextGenerationTask = class extends BaseTextGenerationTask {
|
|
1142
|
+
constructor() {
|
|
1143
|
+
super("ovhcloud", OVHCLOUD_API_BASE_URL);
|
|
1144
|
+
}
|
|
1145
|
+
preparePayload(params) {
|
|
1146
|
+
return {
|
|
1147
|
+
model: params.model,
|
|
1148
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
1149
|
+
...params.args.parameters ? {
|
|
1150
|
+
max_tokens: params.args.parameters.max_new_tokens,
|
|
1151
|
+
...omit(params.args.parameters, "max_new_tokens")
|
|
1152
|
+
} : void 0,
|
|
1153
|
+
prompt: params.args.inputs
|
|
1154
|
+
};
|
|
1155
|
+
}
|
|
1156
|
+
async getResponse(response) {
|
|
1157
|
+
if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
|
|
1158
|
+
const completion = response.choices[0];
|
|
1159
|
+
return {
|
|
1160
|
+
generated_text: completion.text
|
|
1161
|
+
};
|
|
1162
|
+
}
|
|
1163
|
+
throw new InferenceOutputError("Expected OVHcloud text generation response format");
|
|
1164
|
+
}
|
|
1165
|
+
};
|
|
1166
|
+
|
|
1028
1167
|
// src/providers/replicate.ts
|
|
1029
1168
|
var ReplicateTask = class extends TaskProviderHelper {
|
|
1030
1169
|
constructor(url) {
|
|
@@ -1277,6 +1416,10 @@ var PROVIDERS = {
|
|
|
1277
1416
|
openai: {
|
|
1278
1417
|
conversational: new OpenAIConversationalTask()
|
|
1279
1418
|
},
|
|
1419
|
+
ovhcloud: {
|
|
1420
|
+
conversational: new OvhCloudConversationalTask(),
|
|
1421
|
+
"text-generation": new OvhCloudTextGenerationTask()
|
|
1422
|
+
},
|
|
1280
1423
|
replicate: {
|
|
1281
1424
|
"text-to-image": new ReplicateTextToImageTask(),
|
|
1282
1425
|
"text-to-speech": new ReplicateTextToSpeechTask(),
|
|
@@ -1315,81 +1458,13 @@ function getProviderHelper(provider, task) {
|
|
|
1315
1458
|
|
|
1316
1459
|
// package.json
|
|
1317
1460
|
var name = "@huggingface/inference";
|
|
1318
|
-
var version = "3.
|
|
1319
|
-
|
|
1320
|
-
// src/providers/consts.ts
|
|
1321
|
-
var HARDCODED_MODEL_INFERENCE_MAPPING = {
|
|
1322
|
-
/**
|
|
1323
|
-
* "HF model ID" => "Model ID on Inference Provider's side"
|
|
1324
|
-
*
|
|
1325
|
-
* Example:
|
|
1326
|
-
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
|
|
1327
|
-
*/
|
|
1328
|
-
"black-forest-labs": {},
|
|
1329
|
-
cerebras: {},
|
|
1330
|
-
cohere: {},
|
|
1331
|
-
"fal-ai": {},
|
|
1332
|
-
"featherless-ai": {},
|
|
1333
|
-
"fireworks-ai": {},
|
|
1334
|
-
groq: {},
|
|
1335
|
-
"hf-inference": {},
|
|
1336
|
-
hyperbolic: {},
|
|
1337
|
-
nebius: {},
|
|
1338
|
-
novita: {},
|
|
1339
|
-
nscale: {},
|
|
1340
|
-
openai: {},
|
|
1341
|
-
replicate: {},
|
|
1342
|
-
sambanova: {},
|
|
1343
|
-
together: {}
|
|
1344
|
-
};
|
|
1345
|
-
|
|
1346
|
-
// src/lib/getInferenceProviderMapping.ts
|
|
1347
|
-
var inferenceProviderMappingCache = /* @__PURE__ */ new Map();
|
|
1348
|
-
async function getInferenceProviderMapping(params, options) {
|
|
1349
|
-
if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
|
|
1350
|
-
return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
|
|
1351
|
-
}
|
|
1352
|
-
let inferenceProviderMapping;
|
|
1353
|
-
if (inferenceProviderMappingCache.has(params.modelId)) {
|
|
1354
|
-
inferenceProviderMapping = inferenceProviderMappingCache.get(params.modelId);
|
|
1355
|
-
} else {
|
|
1356
|
-
const resp = await (options?.fetch ?? fetch)(
|
|
1357
|
-
`${HF_HUB_URL}/api/models/${params.modelId}?expand[]=inferenceProviderMapping`,
|
|
1358
|
-
{
|
|
1359
|
-
headers: params.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${params.accessToken}` } : {}
|
|
1360
|
-
}
|
|
1361
|
-
);
|
|
1362
|
-
if (resp.status === 404) {
|
|
1363
|
-
throw new Error(`Model ${params.modelId} does not exist`);
|
|
1364
|
-
}
|
|
1365
|
-
inferenceProviderMapping = await resp.json().then((json) => json.inferenceProviderMapping).catch(() => null);
|
|
1366
|
-
}
|
|
1367
|
-
if (!inferenceProviderMapping) {
|
|
1368
|
-
throw new Error(`We have not been able to find inference provider information for model ${params.modelId}.`);
|
|
1369
|
-
}
|
|
1370
|
-
const providerMapping = inferenceProviderMapping[params.provider];
|
|
1371
|
-
if (providerMapping) {
|
|
1372
|
-
const equivalentTasks = params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task) ? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS : [params.task];
|
|
1373
|
-
if (!typedInclude(equivalentTasks, providerMapping.task)) {
|
|
1374
|
-
throw new Error(
|
|
1375
|
-
`Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
|
|
1376
|
-
);
|
|
1377
|
-
}
|
|
1378
|
-
if (providerMapping.status === "staging") {
|
|
1379
|
-
console.warn(
|
|
1380
|
-
`Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
|
|
1381
|
-
);
|
|
1382
|
-
}
|
|
1383
|
-
return { ...providerMapping, hfModelId: params.modelId };
|
|
1384
|
-
}
|
|
1385
|
-
return null;
|
|
1386
|
-
}
|
|
1461
|
+
var version = "3.11.0";
|
|
1387
1462
|
|
|
1388
1463
|
// src/lib/makeRequestOptions.ts
|
|
1389
1464
|
var tasks = null;
|
|
1390
1465
|
async function makeRequestOptions(args, providerHelper, options) {
|
|
1391
|
-
const {
|
|
1392
|
-
const provider =
|
|
1466
|
+
const { model: maybeModel } = args;
|
|
1467
|
+
const provider = providerHelper.provider;
|
|
1393
1468
|
const { task } = options ?? {};
|
|
1394
1469
|
if (args.endpointUrl && provider !== "hf-inference") {
|
|
1395
1470
|
throw new Error(`Cannot use endpointUrl with a third-party provider.`);
|
|
@@ -1444,7 +1519,7 @@ async function makeRequestOptions(args, providerHelper, options) {
|
|
|
1444
1519
|
}
|
|
1445
1520
|
function makeRequestOptionsFromResolvedModel(resolvedModel, providerHelper, args, mapping, options) {
|
|
1446
1521
|
const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
|
|
1447
|
-
const provider =
|
|
1522
|
+
const provider = providerHelper.provider;
|
|
1448
1523
|
const { includeCredentials, task, signal, billTo } = options ?? {};
|
|
1449
1524
|
const authMethod = (() => {
|
|
1450
1525
|
if (providerHelper.clientSideRoutingOnly) {
|
|
@@ -1735,7 +1810,8 @@ async function request(args, options) {
|
|
|
1735
1810
|
console.warn(
|
|
1736
1811
|
"The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
|
|
1737
1812
|
);
|
|
1738
|
-
const
|
|
1813
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1814
|
+
const providerHelper = getProviderHelper(provider, options?.task);
|
|
1739
1815
|
const result = await innerRequest(args, providerHelper, options);
|
|
1740
1816
|
return result.data;
|
|
1741
1817
|
}
|
|
@@ -1745,7 +1821,8 @@ async function* streamingRequest(args, options) {
|
|
|
1745
1821
|
console.warn(
|
|
1746
1822
|
"The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
|
|
1747
1823
|
);
|
|
1748
|
-
const
|
|
1824
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1825
|
+
const providerHelper = getProviderHelper(provider, options?.task);
|
|
1749
1826
|
yield* innerStreamingRequest(args, providerHelper, options);
|
|
1750
1827
|
}
|
|
1751
1828
|
|
|
@@ -1759,7 +1836,8 @@ function preparePayload(args) {
|
|
|
1759
1836
|
|
|
1760
1837
|
// src/tasks/audio/audioClassification.ts
|
|
1761
1838
|
async function audioClassification(args, options) {
|
|
1762
|
-
const
|
|
1839
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1840
|
+
const providerHelper = getProviderHelper(provider, "audio-classification");
|
|
1763
1841
|
const payload = preparePayload(args);
|
|
1764
1842
|
const { data: res } = await innerRequest(payload, providerHelper, {
|
|
1765
1843
|
...options,
|
|
@@ -1770,7 +1848,9 @@ async function audioClassification(args, options) {
|
|
|
1770
1848
|
|
|
1771
1849
|
// src/tasks/audio/audioToAudio.ts
|
|
1772
1850
|
async function audioToAudio(args, options) {
|
|
1773
|
-
const
|
|
1851
|
+
const model = "inputs" in args ? args.model : void 0;
|
|
1852
|
+
const provider = await resolveProvider(args.provider, model);
|
|
1853
|
+
const providerHelper = getProviderHelper(provider, "audio-to-audio");
|
|
1774
1854
|
const payload = preparePayload(args);
|
|
1775
1855
|
const { data: res } = await innerRequest(payload, providerHelper, {
|
|
1776
1856
|
...options,
|
|
@@ -1794,7 +1874,8 @@ function base64FromBytes(arr) {
|
|
|
1794
1874
|
|
|
1795
1875
|
// src/tasks/audio/automaticSpeechRecognition.ts
|
|
1796
1876
|
async function automaticSpeechRecognition(args, options) {
|
|
1797
|
-
const
|
|
1877
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1878
|
+
const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
|
|
1798
1879
|
const payload = await buildPayload(args);
|
|
1799
1880
|
const { data: res } = await innerRequest(payload, providerHelper, {
|
|
1800
1881
|
...options,
|
|
@@ -1834,7 +1915,7 @@ async function buildPayload(args) {
|
|
|
1834
1915
|
|
|
1835
1916
|
// src/tasks/audio/textToSpeech.ts
|
|
1836
1917
|
async function textToSpeech(args, options) {
|
|
1837
|
-
const provider = args.provider
|
|
1918
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1838
1919
|
const providerHelper = getProviderHelper(provider, "text-to-speech");
|
|
1839
1920
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
1840
1921
|
...options,
|
|
@@ -1850,7 +1931,8 @@ function preparePayload2(args) {
|
|
|
1850
1931
|
|
|
1851
1932
|
// src/tasks/cv/imageClassification.ts
|
|
1852
1933
|
async function imageClassification(args, options) {
|
|
1853
|
-
const
|
|
1934
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1935
|
+
const providerHelper = getProviderHelper(provider, "image-classification");
|
|
1854
1936
|
const payload = preparePayload2(args);
|
|
1855
1937
|
const { data: res } = await innerRequest(payload, providerHelper, {
|
|
1856
1938
|
...options,
|
|
@@ -1861,7 +1943,8 @@ async function imageClassification(args, options) {
|
|
|
1861
1943
|
|
|
1862
1944
|
// src/tasks/cv/imageSegmentation.ts
|
|
1863
1945
|
async function imageSegmentation(args, options) {
|
|
1864
|
-
const
|
|
1946
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1947
|
+
const providerHelper = getProviderHelper(provider, "image-segmentation");
|
|
1865
1948
|
const payload = preparePayload2(args);
|
|
1866
1949
|
const { data: res } = await innerRequest(payload, providerHelper, {
|
|
1867
1950
|
...options,
|
|
@@ -1872,7 +1955,8 @@ async function imageSegmentation(args, options) {
|
|
|
1872
1955
|
|
|
1873
1956
|
// src/tasks/cv/imageToImage.ts
|
|
1874
1957
|
async function imageToImage(args, options) {
|
|
1875
|
-
const
|
|
1958
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1959
|
+
const providerHelper = getProviderHelper(provider, "image-to-image");
|
|
1876
1960
|
let reqArgs;
|
|
1877
1961
|
if (!args.parameters) {
|
|
1878
1962
|
reqArgs = {
|
|
@@ -1897,7 +1981,8 @@ async function imageToImage(args, options) {
|
|
|
1897
1981
|
|
|
1898
1982
|
// src/tasks/cv/imageToText.ts
|
|
1899
1983
|
async function imageToText(args, options) {
|
|
1900
|
-
const
|
|
1984
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1985
|
+
const providerHelper = getProviderHelper(provider, "image-to-text");
|
|
1901
1986
|
const payload = preparePayload2(args);
|
|
1902
1987
|
const { data: res } = await innerRequest(payload, providerHelper, {
|
|
1903
1988
|
...options,
|
|
@@ -1908,7 +1993,8 @@ async function imageToText(args, options) {
|
|
|
1908
1993
|
|
|
1909
1994
|
// src/tasks/cv/objectDetection.ts
|
|
1910
1995
|
async function objectDetection(args, options) {
|
|
1911
|
-
const
|
|
1996
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1997
|
+
const providerHelper = getProviderHelper(provider, "object-detection");
|
|
1912
1998
|
const payload = preparePayload2(args);
|
|
1913
1999
|
const { data: res } = await innerRequest(payload, providerHelper, {
|
|
1914
2000
|
...options,
|
|
@@ -1919,7 +2005,7 @@ async function objectDetection(args, options) {
|
|
|
1919
2005
|
|
|
1920
2006
|
// src/tasks/cv/textToImage.ts
|
|
1921
2007
|
async function textToImage(args, options) {
|
|
1922
|
-
const provider = args.provider
|
|
2008
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1923
2009
|
const providerHelper = getProviderHelper(provider, "text-to-image");
|
|
1924
2010
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
1925
2011
|
...options,
|
|
@@ -1931,7 +2017,7 @@ async function textToImage(args, options) {
|
|
|
1931
2017
|
|
|
1932
2018
|
// src/tasks/cv/textToVideo.ts
|
|
1933
2019
|
async function textToVideo(args, options) {
|
|
1934
|
-
const provider = args.provider
|
|
2020
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1935
2021
|
const providerHelper = getProviderHelper(provider, "text-to-video");
|
|
1936
2022
|
const { data: response } = await innerRequest(
|
|
1937
2023
|
args,
|
|
@@ -1968,7 +2054,8 @@ async function preparePayload3(args) {
|
|
|
1968
2054
|
}
|
|
1969
2055
|
}
|
|
1970
2056
|
async function zeroShotImageClassification(args, options) {
|
|
1971
|
-
const
|
|
2057
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2058
|
+
const providerHelper = getProviderHelper(provider, "zero-shot-image-classification");
|
|
1972
2059
|
const payload = await preparePayload3(args);
|
|
1973
2060
|
const { data: res } = await innerRequest(payload, providerHelper, {
|
|
1974
2061
|
...options,
|
|
@@ -1979,7 +2066,8 @@ async function zeroShotImageClassification(args, options) {
|
|
|
1979
2066
|
|
|
1980
2067
|
// src/tasks/nlp/chatCompletion.ts
|
|
1981
2068
|
async function chatCompletion(args, options) {
|
|
1982
|
-
const
|
|
2069
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2070
|
+
const providerHelper = getProviderHelper(provider, "conversational");
|
|
1983
2071
|
const { data: response } = await innerRequest(args, providerHelper, {
|
|
1984
2072
|
...options,
|
|
1985
2073
|
task: "conversational"
|
|
@@ -1989,7 +2077,8 @@ async function chatCompletion(args, options) {
|
|
|
1989
2077
|
|
|
1990
2078
|
// src/tasks/nlp/chatCompletionStream.ts
|
|
1991
2079
|
async function* chatCompletionStream(args, options) {
|
|
1992
|
-
const
|
|
2080
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2081
|
+
const providerHelper = getProviderHelper(provider, "conversational");
|
|
1993
2082
|
yield* innerStreamingRequest(args, providerHelper, {
|
|
1994
2083
|
...options,
|
|
1995
2084
|
task: "conversational"
|
|
@@ -1998,7 +2087,8 @@ async function* chatCompletionStream(args, options) {
|
|
|
1998
2087
|
|
|
1999
2088
|
// src/tasks/nlp/featureExtraction.ts
|
|
2000
2089
|
async function featureExtraction(args, options) {
|
|
2001
|
-
const
|
|
2090
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2091
|
+
const providerHelper = getProviderHelper(provider, "feature-extraction");
|
|
2002
2092
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
2003
2093
|
...options,
|
|
2004
2094
|
task: "feature-extraction"
|
|
@@ -2008,7 +2098,8 @@ async function featureExtraction(args, options) {
|
|
|
2008
2098
|
|
|
2009
2099
|
// src/tasks/nlp/fillMask.ts
|
|
2010
2100
|
async function fillMask(args, options) {
|
|
2011
|
-
const
|
|
2101
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2102
|
+
const providerHelper = getProviderHelper(provider, "fill-mask");
|
|
2012
2103
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
2013
2104
|
...options,
|
|
2014
2105
|
task: "fill-mask"
|
|
@@ -2018,7 +2109,8 @@ async function fillMask(args, options) {
|
|
|
2018
2109
|
|
|
2019
2110
|
// src/tasks/nlp/questionAnswering.ts
|
|
2020
2111
|
async function questionAnswering(args, options) {
|
|
2021
|
-
const
|
|
2112
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2113
|
+
const providerHelper = getProviderHelper(provider, "question-answering");
|
|
2022
2114
|
const { data: res } = await innerRequest(
|
|
2023
2115
|
args,
|
|
2024
2116
|
providerHelper,
|
|
@@ -2032,7 +2124,8 @@ async function questionAnswering(args, options) {
|
|
|
2032
2124
|
|
|
2033
2125
|
// src/tasks/nlp/sentenceSimilarity.ts
|
|
2034
2126
|
async function sentenceSimilarity(args, options) {
|
|
2035
|
-
const
|
|
2127
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2128
|
+
const providerHelper = getProviderHelper(provider, "sentence-similarity");
|
|
2036
2129
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
2037
2130
|
...options,
|
|
2038
2131
|
task: "sentence-similarity"
|
|
@@ -2042,7 +2135,8 @@ async function sentenceSimilarity(args, options) {
|
|
|
2042
2135
|
|
|
2043
2136
|
// src/tasks/nlp/summarization.ts
|
|
2044
2137
|
async function summarization(args, options) {
|
|
2045
|
-
const
|
|
2138
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2139
|
+
const providerHelper = getProviderHelper(provider, "summarization");
|
|
2046
2140
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
2047
2141
|
...options,
|
|
2048
2142
|
task: "summarization"
|
|
@@ -2052,7 +2146,8 @@ async function summarization(args, options) {
|
|
|
2052
2146
|
|
|
2053
2147
|
// src/tasks/nlp/tableQuestionAnswering.ts
|
|
2054
2148
|
async function tableQuestionAnswering(args, options) {
|
|
2055
|
-
const
|
|
2149
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2150
|
+
const providerHelper = getProviderHelper(provider, "table-question-answering");
|
|
2056
2151
|
const { data: res } = await innerRequest(
|
|
2057
2152
|
args,
|
|
2058
2153
|
providerHelper,
|
|
@@ -2066,7 +2161,8 @@ async function tableQuestionAnswering(args, options) {
|
|
|
2066
2161
|
|
|
2067
2162
|
// src/tasks/nlp/textClassification.ts
|
|
2068
2163
|
async function textClassification(args, options) {
|
|
2069
|
-
const
|
|
2164
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2165
|
+
const providerHelper = getProviderHelper(provider, "text-classification");
|
|
2070
2166
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
2071
2167
|
...options,
|
|
2072
2168
|
task: "text-classification"
|
|
@@ -2076,7 +2172,8 @@ async function textClassification(args, options) {
|
|
|
2076
2172
|
|
|
2077
2173
|
// src/tasks/nlp/textGeneration.ts
|
|
2078
2174
|
async function textGeneration(args, options) {
|
|
2079
|
-
const
|
|
2175
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2176
|
+
const providerHelper = getProviderHelper(provider, "text-generation");
|
|
2080
2177
|
const { data: response } = await innerRequest(args, providerHelper, {
|
|
2081
2178
|
...options,
|
|
2082
2179
|
task: "text-generation"
|
|
@@ -2086,7 +2183,8 @@ async function textGeneration(args, options) {
|
|
|
2086
2183
|
|
|
2087
2184
|
// src/tasks/nlp/textGenerationStream.ts
|
|
2088
2185
|
async function* textGenerationStream(args, options) {
|
|
2089
|
-
const
|
|
2186
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2187
|
+
const providerHelper = getProviderHelper(provider, "text-generation");
|
|
2090
2188
|
yield* innerStreamingRequest(args, providerHelper, {
|
|
2091
2189
|
...options,
|
|
2092
2190
|
task: "text-generation"
|
|
@@ -2095,7 +2193,8 @@ async function* textGenerationStream(args, options) {
|
|
|
2095
2193
|
|
|
2096
2194
|
// src/tasks/nlp/tokenClassification.ts
|
|
2097
2195
|
async function tokenClassification(args, options) {
|
|
2098
|
-
const
|
|
2196
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2197
|
+
const providerHelper = getProviderHelper(provider, "token-classification");
|
|
2099
2198
|
const { data: res } = await innerRequest(
|
|
2100
2199
|
args,
|
|
2101
2200
|
providerHelper,
|
|
@@ -2109,7 +2208,8 @@ async function tokenClassification(args, options) {
|
|
|
2109
2208
|
|
|
2110
2209
|
// src/tasks/nlp/translation.ts
|
|
2111
2210
|
async function translation(args, options) {
|
|
2112
|
-
const
|
|
2211
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2212
|
+
const providerHelper = getProviderHelper(provider, "translation");
|
|
2113
2213
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
2114
2214
|
...options,
|
|
2115
2215
|
task: "translation"
|
|
@@ -2119,7 +2219,8 @@ async function translation(args, options) {
|
|
|
2119
2219
|
|
|
2120
2220
|
// src/tasks/nlp/zeroShotClassification.ts
|
|
2121
2221
|
async function zeroShotClassification(args, options) {
|
|
2122
|
-
const
|
|
2222
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2223
|
+
const providerHelper = getProviderHelper(provider, "zero-shot-classification");
|
|
2123
2224
|
const { data: res } = await innerRequest(
|
|
2124
2225
|
args,
|
|
2125
2226
|
providerHelper,
|
|
@@ -2133,7 +2234,8 @@ async function zeroShotClassification(args, options) {
|
|
|
2133
2234
|
|
|
2134
2235
|
// src/tasks/multimodal/documentQuestionAnswering.ts
|
|
2135
2236
|
async function documentQuestionAnswering(args, options) {
|
|
2136
|
-
const
|
|
2237
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2238
|
+
const providerHelper = getProviderHelper(provider, "document-question-answering");
|
|
2137
2239
|
const reqArgs = {
|
|
2138
2240
|
...args,
|
|
2139
2241
|
inputs: {
|
|
@@ -2155,7 +2257,8 @@ async function documentQuestionAnswering(args, options) {
|
|
|
2155
2257
|
|
|
2156
2258
|
// src/tasks/multimodal/visualQuestionAnswering.ts
|
|
2157
2259
|
async function visualQuestionAnswering(args, options) {
|
|
2158
|
-
const
|
|
2260
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2261
|
+
const providerHelper = getProviderHelper(provider, "visual-question-answering");
|
|
2159
2262
|
const reqArgs = {
|
|
2160
2263
|
...args,
|
|
2161
2264
|
inputs: {
|
|
@@ -2173,7 +2276,8 @@ async function visualQuestionAnswering(args, options) {
|
|
|
2173
2276
|
|
|
2174
2277
|
// src/tasks/tabular/tabularClassification.ts
|
|
2175
2278
|
async function tabularClassification(args, options) {
|
|
2176
|
-
const
|
|
2279
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2280
|
+
const providerHelper = getProviderHelper(provider, "tabular-classification");
|
|
2177
2281
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
2178
2282
|
...options,
|
|
2179
2283
|
task: "tabular-classification"
|
|
@@ -2183,7 +2287,8 @@ async function tabularClassification(args, options) {
|
|
|
2183
2287
|
|
|
2184
2288
|
// src/tasks/tabular/tabularRegression.ts
|
|
2185
2289
|
async function tabularRegression(args, options) {
|
|
2186
|
-
const
|
|
2290
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2291
|
+
const providerHelper = getProviderHelper(provider, "tabular-regression");
|
|
2187
2292
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
2188
2293
|
...options,
|
|
2189
2294
|
task: "tabular-regression"
|
|
@@ -2191,6 +2296,11 @@ async function tabularRegression(args, options) {
|
|
|
2191
2296
|
return providerHelper.getResponse(res);
|
|
2192
2297
|
}
|
|
2193
2298
|
|
|
2299
|
+
// src/utils/typedEntries.ts
|
|
2300
|
+
function typedEntries(obj) {
|
|
2301
|
+
return Object.entries(obj);
|
|
2302
|
+
}
|
|
2303
|
+
|
|
2194
2304
|
// src/InferenceClient.ts
|
|
2195
2305
|
var InferenceClient = class {
|
|
2196
2306
|
accessToken;
|
|
@@ -2198,40 +2308,36 @@ var InferenceClient = class {
|
|
|
2198
2308
|
constructor(accessToken = "", defaultOptions = {}) {
|
|
2199
2309
|
this.accessToken = accessToken;
|
|
2200
2310
|
this.defaultOptions = defaultOptions;
|
|
2201
|
-
for (const [name2, fn] of
|
|
2311
|
+
for (const [name2, fn] of typedEntries(tasks_exports)) {
|
|
2202
2312
|
Object.defineProperty(this, name2, {
|
|
2203
2313
|
enumerable: false,
|
|
2204
2314
|
value: (params, options) => (
|
|
2205
2315
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
2206
|
-
fn(
|
|
2316
|
+
fn(
|
|
2317
|
+
/// ^ The cast of fn to any is necessary, otherwise TS can't compile because the generated union type is too complex
|
|
2318
|
+
{ endpointUrl: defaultOptions.endpointUrl, accessToken, ...params },
|
|
2319
|
+
{
|
|
2320
|
+
...omit(defaultOptions, ["endpointUrl"]),
|
|
2321
|
+
...options
|
|
2322
|
+
}
|
|
2323
|
+
)
|
|
2207
2324
|
)
|
|
2208
2325
|
});
|
|
2209
2326
|
}
|
|
2210
2327
|
}
|
|
2211
2328
|
/**
|
|
2212
|
-
* Returns
|
|
2329
|
+
* Returns a new instance of InferenceClient tied to a specified endpoint.
|
|
2330
|
+
*
|
|
2331
|
+
* For backward compatibility mostly.
|
|
2213
2332
|
*/
|
|
2214
2333
|
endpoint(endpointUrl) {
|
|
2215
|
-
return new
|
|
2216
|
-
}
|
|
2217
|
-
};
|
|
2218
|
-
var InferenceClientEndpoint = class {
|
|
2219
|
-
constructor(endpointUrl, accessToken = "", defaultOptions = {}) {
|
|
2220
|
-
accessToken;
|
|
2221
|
-
defaultOptions;
|
|
2222
|
-
for (const [name2, fn] of Object.entries(tasks_exports)) {
|
|
2223
|
-
Object.defineProperty(this, name2, {
|
|
2224
|
-
enumerable: false,
|
|
2225
|
-
value: (params, options) => (
|
|
2226
|
-
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
2227
|
-
fn({ ...params, accessToken, endpointUrl }, { ...defaultOptions, ...options })
|
|
2228
|
-
)
|
|
2229
|
-
});
|
|
2230
|
-
}
|
|
2334
|
+
return new InferenceClient(this.accessToken, { ...this.defaultOptions, endpointUrl });
|
|
2231
2335
|
}
|
|
2232
2336
|
};
|
|
2233
2337
|
var HfInference = class extends InferenceClient {
|
|
2234
2338
|
};
|
|
2339
|
+
var InferenceClientEndpoint = class extends InferenceClient {
|
|
2340
|
+
};
|
|
2235
2341
|
|
|
2236
2342
|
// src/types.ts
|
|
2237
2343
|
var INFERENCE_PROVIDERS = [
|
|
@@ -2248,10 +2354,12 @@ var INFERENCE_PROVIDERS = [
|
|
|
2248
2354
|
"novita",
|
|
2249
2355
|
"nscale",
|
|
2250
2356
|
"openai",
|
|
2357
|
+
"ovhcloud",
|
|
2251
2358
|
"replicate",
|
|
2252
2359
|
"sambanova",
|
|
2253
2360
|
"together"
|
|
2254
2361
|
];
|
|
2362
|
+
var PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"];
|
|
2255
2363
|
|
|
2256
2364
|
// src/snippets/index.ts
|
|
2257
2365
|
var snippets_exports = {};
|
|
@@ -2272,6 +2380,7 @@ var templates = {
|
|
|
2272
2380
|
"basicImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "image/jpeg",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
|
|
2273
2381
|
"textToAudio": '{% if model.library_name == "transformers" %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n // Returns a byte object of the Audio wavform. Use it directly!\n});\n{% else %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});\n{% endif %} ',
|
|
2274
2382
|
"textToImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\n\nquery({ {{ providerInputs.asTsString }} }).then((response) => {\n // Use image\n});',
|
|
2383
|
+
"textToSpeech": '{% if model.library_name == "transformers" %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\nquery({ text: {{ inputs.asObj.inputs }} }).then((response) => {\n // Returns a byte object of the Audio wavform. Use it directly!\n});\n{% else %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ text: {{ inputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});\n{% endif %} ',
|
|
2275
2384
|
"zeroShotClassification": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({\n inputs: {{ providerInputs.asObj.inputs }},\n parameters: { candidate_labels: ["refund", "legal", "faq"] }\n}).then((response) => {\n console.log(JSON.stringify(response));\n});'
|
|
2276
2385
|
},
|
|
2277
2386
|
"huggingface.js": {
|
|
@@ -2293,11 +2402,23 @@ const image = await client.textToImage({
|
|
|
2293
2402
|
billTo: "{{ billTo }}",
|
|
2294
2403
|
}{% endif %});
|
|
2295
2404
|
/// Use the generated image (it's a Blob)`,
|
|
2405
|
+
"textToSpeech": `import { InferenceClient } from "@huggingface/inference";
|
|
2406
|
+
|
|
2407
|
+
const client = new InferenceClient("{{ accessToken }}");
|
|
2408
|
+
|
|
2409
|
+
const audio = await client.textToSpeech({
|
|
2410
|
+
provider: "{{ provider }}",
|
|
2411
|
+
model: "{{ model.id }}",
|
|
2412
|
+
inputs: {{ inputs.asObj.inputs }},
|
|
2413
|
+
}{% if billTo %}, {
|
|
2414
|
+
billTo: "{{ billTo }}",
|
|
2415
|
+
}{% endif %});
|
|
2416
|
+
// Use the generated audio (it's a Blob)`,
|
|
2296
2417
|
"textToVideo": `import { InferenceClient } from "@huggingface/inference";
|
|
2297
2418
|
|
|
2298
2419
|
const client = new InferenceClient("{{ accessToken }}");
|
|
2299
2420
|
|
|
2300
|
-
const
|
|
2421
|
+
const video = await client.textToVideo({
|
|
2301
2422
|
provider: "{{ provider }}",
|
|
2302
2423
|
model: "{{ model.id }}",
|
|
2303
2424
|
inputs: {{ inputs.asObj.inputs }},
|
|
@@ -2313,7 +2434,7 @@ const image = await client.textToVideo({
|
|
|
2313
2434
|
},
|
|
2314
2435
|
"python": {
|
|
2315
2436
|
"fal_client": {
|
|
2316
|
-
"textToImage": '{% if provider == "fal-ai" %}\nimport fal_client\n\nresult = fal_client.subscribe(\n "{{ providerModelId }}",\n arguments={\n "prompt": {{ inputs.asObj.inputs }},\n },\n)\nprint(result)\n{% endif %} '
|
|
2437
|
+
"textToImage": '{% if provider == "fal-ai" %}\nimport fal_client\n\n{% if providerInputs.asObj.loras is defined and providerInputs.asObj.loras != none %}\nresult = fal_client.subscribe(\n "{{ providerModelId }}",\n arguments={\n "prompt": {{ inputs.asObj.inputs }},\n "loras":{{ providerInputs.asObj.loras | tojson }},\n },\n)\n{% else %}\nresult = fal_client.subscribe(\n "{{ providerModelId }}",\n arguments={\n "prompt": {{ inputs.asObj.inputs }},\n },\n)\n{% endif %} \nprint(result)\n{% endif %} '
|
|
2317
2438
|
},
|
|
2318
2439
|
"huggingface_hub": {
|
|
2319
2440
|
"basic": 'result = client.{{ methodName }}(\n inputs={{ inputs.asObj.inputs }},\n model="{{ model.id }}",\n)',
|
|
@@ -2325,6 +2446,7 @@ const image = await client.textToVideo({
|
|
|
2325
2446
|
"imageToImage": '# output is a PIL.Image object\nimage = client.image_to_image(\n "{{ inputs.asObj.inputs }}",\n prompt="{{ inputs.asObj.parameters.prompt }}",\n model="{{ model.id }}",\n) ',
|
|
2326
2447
|
"importInferenceClient": 'from huggingface_hub import InferenceClient\n\nclient = InferenceClient(\n provider="{{ provider }}",\n api_key="{{ accessToken }}",\n{% if billTo %}\n bill_to="{{ billTo }}",\n{% endif %}\n)',
|
|
2327
2448
|
"textToImage": '# output is a PIL.Image object\nimage = client.text_to_image(\n {{ inputs.asObj.inputs }},\n model="{{ model.id }}",\n) ',
|
|
2449
|
+
"textToSpeech": '# audio is returned as bytes\naudio = client.text_to_speech(\n {{ inputs.asObj.inputs }},\n model="{{ model.id }}",\n) \n',
|
|
2328
2450
|
"textToVideo": 'video = client.text_to_video(\n {{ inputs.asObj.inputs }},\n model="{{ model.id }}",\n) '
|
|
2329
2451
|
},
|
|
2330
2452
|
"openai": {
|
|
@@ -2341,8 +2463,9 @@ const image = await client.textToVideo({
|
|
|
2341
2463
|
"imageToImage": 'def query(payload):\n with open(payload["inputs"], "rb") as f:\n img = f.read()\n payload["inputs"] = base64.b64encode(img).decode("utf-8")\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nimage_bytes = query({\n{{ providerInputs.asJsonString }}\n})\n\n# You can access the image with PIL.Image for example\nimport io\nfrom PIL import Image\nimage = Image.open(io.BytesIO(image_bytes)) ',
|
|
2342
2464
|
"importRequests": '{% if importBase64 %}\nimport base64\n{% endif %}\n{% if importJson %}\nimport json\n{% endif %}\nimport requests\n\nAPI_URL = "{{ fullUrl }}"\nheaders = {\n "Authorization": "{{ authorizationHeader }}",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}"\n{% endif %}\n}',
|
|
2343
2465
|
"tabular": 'def query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nresponse = query({\n "inputs": {\n "data": {{ providerInputs.asObj.inputs }}\n },\n}) ',
|
|
2344
|
-
"textToAudio": '{% if model.library_name == "transformers" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\naudio_bytes = query({\n "inputs": {{
|
|
2466
|
+
"textToAudio": '{% if model.library_name == "transformers" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\naudio_bytes = query({\n "inputs": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio_bytes)\n{% else %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\naudio, sampling_rate = query({\n "inputs": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio, rate=sampling_rate)\n{% endif %} ',
|
|
2345
2467
|
"textToImage": '{% if provider == "hf-inference" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nimage_bytes = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n})\n\n# You can access the image with PIL.Image for example\nimport io\nfrom PIL import Image\nimage = Image.open(io.BytesIO(image_bytes))\n{% endif %}',
|
|
2468
|
+
"textToSpeech": '{% if model.library_name == "transformers" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\naudio_bytes = query({\n "text": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio_bytes)\n{% else %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\naudio, sampling_rate = query({\n "text": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio, rate=sampling_rate)\n{% endif %} ',
|
|
2346
2469
|
"zeroShotClassification": 'def query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\noutput = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n "parameters": {"candidate_labels": ["refund", "legal", "faq"]},\n}) ',
|
|
2347
2470
|
"zeroShotImageClassification": 'def query(data):\n with open(data["image_path"], "rb") as f:\n img = f.read()\n payload={\n "parameters": data["parameters"],\n "inputs": base64.b64encode(img).decode("utf-8")\n }\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\noutput = query({\n "image_path": {{ providerInputs.asObj.inputs }},\n "parameters": {"candidate_labels": ["cat", "dog", "llama"]},\n}) '
|
|
2348
2471
|
}
|
|
@@ -2445,6 +2568,7 @@ var HF_JS_METHODS = {
|
|
|
2445
2568
|
"text-generation": "textGeneration",
|
|
2446
2569
|
"text2text-generation": "textGeneration",
|
|
2447
2570
|
"token-classification": "tokenClassification",
|
|
2571
|
+
"text-to-speech": "textToSpeech",
|
|
2448
2572
|
translation: "translation"
|
|
2449
2573
|
};
|
|
2450
2574
|
var snippetGenerator = (templateName, inputPreparationFn) => {
|
|
@@ -2564,7 +2688,7 @@ var prepareConversationalInput = (model, opts) => {
|
|
|
2564
2688
|
return {
|
|
2565
2689
|
messages: opts?.messages ?? (0, import_tasks.getModelInputSnippet)(model),
|
|
2566
2690
|
...opts?.temperature ? { temperature: opts?.temperature } : void 0,
|
|
2567
|
-
max_tokens: opts?.max_tokens
|
|
2691
|
+
...opts?.max_tokens ? { max_tokens: opts?.max_tokens } : void 0,
|
|
2568
2692
|
...opts?.top_p ? { top_p: opts?.top_p } : void 0
|
|
2569
2693
|
};
|
|
2570
2694
|
};
|
|
@@ -2591,7 +2715,7 @@ var snippets = {
|
|
|
2591
2715
|
"text-generation": snippetGenerator("basic"),
|
|
2592
2716
|
"text-to-audio": snippetGenerator("textToAudio"),
|
|
2593
2717
|
"text-to-image": snippetGenerator("textToImage"),
|
|
2594
|
-
"text-to-speech": snippetGenerator("
|
|
2718
|
+
"text-to-speech": snippetGenerator("textToSpeech"),
|
|
2595
2719
|
"text-to-video": snippetGenerator("textToVideo"),
|
|
2596
2720
|
"text2text-generation": snippetGenerator("basic"),
|
|
2597
2721
|
"token-classification": snippetGenerator("basic"),
|
|
@@ -2658,6 +2782,7 @@ function removeSuffix(str, suffix) {
|
|
|
2658
2782
|
InferenceClient,
|
|
2659
2783
|
InferenceClientEndpoint,
|
|
2660
2784
|
InferenceOutputError,
|
|
2785
|
+
PROVIDERS_OR_POLICIES,
|
|
2661
2786
|
audioClassification,
|
|
2662
2787
|
audioToAudio,
|
|
2663
2788
|
automaticSpeechRecognition,
|