@huggingface/inference 3.9.2 → 3.11.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +9 -7
- package/dist/index.cjs +771 -646
- package/dist/index.js +770 -646
- package/dist/src/InferenceClient.d.ts +16 -17
- package/dist/src/InferenceClient.d.ts.map +1 -1
- package/dist/src/lib/getInferenceProviderMapping.d.ts +6 -2
- package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -1
- package/dist/src/lib/getProviderHelper.d.ts.map +1 -1
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/consts.d.ts.map +1 -1
- package/dist/src/providers/ovhcloud.d.ts +38 -0
- package/dist/src/providers/ovhcloud.d.ts.map +1 -0
- package/dist/src/providers/providerHelper.d.ts +1 -1
- package/dist/src/providers/providerHelper.d.ts.map +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
- package/dist/src/snippets/templates.exported.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioToAudio.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
- package/dist/src/types.d.ts +7 -5
- package/dist/src/types.d.ts.map +1 -1
- package/dist/src/utils/typedEntries.d.ts +4 -0
- package/dist/src/utils/typedEntries.d.ts.map +1 -0
- package/package.json +3 -3
- package/src/InferenceClient.ts +32 -43
- package/src/lib/getInferenceProviderMapping.ts +68 -19
- package/src/lib/getProviderHelper.ts +5 -0
- package/src/lib/makeRequestOptions.ts +4 -3
- package/src/providers/consts.ts +1 -0
- package/src/providers/ovhcloud.ts +75 -0
- package/src/providers/providerHelper.ts +1 -1
- package/src/snippets/getInferenceSnippets.ts +5 -4
- package/src/snippets/templates.exported.ts +7 -3
- package/src/tasks/audio/audioClassification.ts +3 -1
- package/src/tasks/audio/audioToAudio.ts +4 -1
- package/src/tasks/audio/automaticSpeechRecognition.ts +3 -1
- package/src/tasks/audio/textToSpeech.ts +2 -1
- package/src/tasks/custom/request.ts +3 -1
- package/src/tasks/custom/streamingRequest.ts +3 -1
- package/src/tasks/cv/imageClassification.ts +3 -1
- package/src/tasks/cv/imageSegmentation.ts +3 -1
- package/src/tasks/cv/imageToImage.ts +3 -1
- package/src/tasks/cv/imageToText.ts +3 -1
- package/src/tasks/cv/objectDetection.ts +3 -1
- package/src/tasks/cv/textToImage.ts +2 -1
- package/src/tasks/cv/textToVideo.ts +2 -1
- package/src/tasks/cv/zeroShotImageClassification.ts +3 -1
- package/src/tasks/multimodal/documentQuestionAnswering.ts +3 -1
- package/src/tasks/multimodal/visualQuestionAnswering.ts +3 -1
- package/src/tasks/nlp/chatCompletion.ts +3 -1
- package/src/tasks/nlp/chatCompletionStream.ts +3 -1
- package/src/tasks/nlp/featureExtraction.ts +3 -1
- package/src/tasks/nlp/fillMask.ts +3 -1
- package/src/tasks/nlp/questionAnswering.ts +4 -1
- package/src/tasks/nlp/sentenceSimilarity.ts +3 -1
- package/src/tasks/nlp/summarization.ts +3 -1
- package/src/tasks/nlp/tableQuestionAnswering.ts +3 -1
- package/src/tasks/nlp/textClassification.ts +3 -1
- package/src/tasks/nlp/textGeneration.ts +3 -1
- package/src/tasks/nlp/textGenerationStream.ts +3 -1
- package/src/tasks/nlp/tokenClassification.ts +3 -1
- package/src/tasks/nlp/translation.ts +3 -1
- package/src/tasks/nlp/zeroShotClassification.ts +3 -1
- package/src/tasks/tabular/tabularClassification.ts +3 -1
- package/src/tasks/tabular/tabularRegression.ts +3 -1
- package/src/types.ts +9 -4
- package/src/utils/typedEntries.ts +5 -0
package/dist/index.js
CHANGED
|
@@ -41,6 +41,38 @@ __export(tasks_exports, {
|
|
|
41
41
|
zeroShotImageClassification: () => zeroShotImageClassification
|
|
42
42
|
});
|
|
43
43
|
|
|
44
|
+
// src/config.ts
|
|
45
|
+
var HF_HUB_URL = "https://huggingface.co";
|
|
46
|
+
var HF_ROUTER_URL = "https://router.huggingface.co";
|
|
47
|
+
var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
|
|
48
|
+
|
|
49
|
+
// src/providers/consts.ts
|
|
50
|
+
var HARDCODED_MODEL_INFERENCE_MAPPING = {
|
|
51
|
+
/**
|
|
52
|
+
* "HF model ID" => "Model ID on Inference Provider's side"
|
|
53
|
+
*
|
|
54
|
+
* Example:
|
|
55
|
+
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
|
|
56
|
+
*/
|
|
57
|
+
"black-forest-labs": {},
|
|
58
|
+
cerebras: {},
|
|
59
|
+
cohere: {},
|
|
60
|
+
"fal-ai": {},
|
|
61
|
+
"featherless-ai": {},
|
|
62
|
+
"fireworks-ai": {},
|
|
63
|
+
groq: {},
|
|
64
|
+
"hf-inference": {},
|
|
65
|
+
hyperbolic: {},
|
|
66
|
+
nebius: {},
|
|
67
|
+
novita: {},
|
|
68
|
+
nscale: {},
|
|
69
|
+
openai: {},
|
|
70
|
+
ovhcloud: {},
|
|
71
|
+
replicate: {},
|
|
72
|
+
sambanova: {},
|
|
73
|
+
together: {}
|
|
74
|
+
};
|
|
75
|
+
|
|
44
76
|
// src/lib/InferenceOutputError.ts
|
|
45
77
|
var InferenceOutputError = class extends TypeError {
|
|
46
78
|
constructor(message) {
|
|
@@ -51,42 +83,6 @@ var InferenceOutputError = class extends TypeError {
|
|
|
51
83
|
}
|
|
52
84
|
};
|
|
53
85
|
|
|
54
|
-
// src/utils/delay.ts
|
|
55
|
-
function delay(ms) {
|
|
56
|
-
return new Promise((resolve) => {
|
|
57
|
-
setTimeout(() => resolve(), ms);
|
|
58
|
-
});
|
|
59
|
-
}
|
|
60
|
-
|
|
61
|
-
// src/utils/pick.ts
|
|
62
|
-
function pick(o, props) {
|
|
63
|
-
return Object.assign(
|
|
64
|
-
{},
|
|
65
|
-
...props.map((prop) => {
|
|
66
|
-
if (o[prop] !== void 0) {
|
|
67
|
-
return { [prop]: o[prop] };
|
|
68
|
-
}
|
|
69
|
-
})
|
|
70
|
-
);
|
|
71
|
-
}
|
|
72
|
-
|
|
73
|
-
// src/utils/typedInclude.ts
|
|
74
|
-
function typedInclude(arr, v) {
|
|
75
|
-
return arr.includes(v);
|
|
76
|
-
}
|
|
77
|
-
|
|
78
|
-
// src/utils/omit.ts
|
|
79
|
-
function omit(o, props) {
|
|
80
|
-
const propsArr = Array.isArray(props) ? props : [props];
|
|
81
|
-
const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
|
|
82
|
-
return pick(o, letsKeep);
|
|
83
|
-
}
|
|
84
|
-
|
|
85
|
-
// src/config.ts
|
|
86
|
-
var HF_HUB_URL = "https://huggingface.co";
|
|
87
|
-
var HF_ROUTER_URL = "https://router.huggingface.co";
|
|
88
|
-
var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
|
|
89
|
-
|
|
90
86
|
// src/utils/toArray.ts
|
|
91
87
|
function toArray(obj) {
|
|
92
88
|
if (Array.isArray(obj)) {
|
|
@@ -181,627 +177,736 @@ var BaseTextGenerationTask = class extends TaskProviderHelper {
|
|
|
181
177
|
}
|
|
182
178
|
};
|
|
183
179
|
|
|
184
|
-
// src/providers/
|
|
185
|
-
var
|
|
186
|
-
var
|
|
180
|
+
// src/providers/hf-inference.ts
|
|
181
|
+
var EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = ["feature-extraction", "sentence-similarity"];
|
|
182
|
+
var HFInferenceTask = class extends TaskProviderHelper {
|
|
187
183
|
constructor() {
|
|
188
|
-
super("
|
|
184
|
+
super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
|
|
189
185
|
}
|
|
190
186
|
preparePayload(params) {
|
|
191
|
-
return
|
|
192
|
-
...omit(params.args, ["inputs", "parameters"]),
|
|
193
|
-
...params.args.parameters,
|
|
194
|
-
prompt: params.args.inputs
|
|
195
|
-
};
|
|
187
|
+
return params.args;
|
|
196
188
|
}
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
};
|
|
201
|
-
if (!binary) {
|
|
202
|
-
headers["Content-Type"] = "application/json";
|
|
189
|
+
makeUrl(params) {
|
|
190
|
+
if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
|
|
191
|
+
return params.model;
|
|
203
192
|
}
|
|
204
|
-
return
|
|
193
|
+
return super.makeUrl(params);
|
|
205
194
|
}
|
|
206
195
|
makeRoute(params) {
|
|
207
|
-
if (
|
|
208
|
-
|
|
196
|
+
if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
|
|
197
|
+
return `pipeline/${params.task}/${params.model}`;
|
|
209
198
|
}
|
|
210
|
-
return
|
|
199
|
+
return `models/${params.model}`;
|
|
200
|
+
}
|
|
201
|
+
async getResponse(response) {
|
|
202
|
+
return response;
|
|
211
203
|
}
|
|
204
|
+
};
|
|
205
|
+
var HFInferenceTextToImageTask = class extends HFInferenceTask {
|
|
212
206
|
async getResponse(response, url, headers, outputType) {
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
207
|
+
if (!response) {
|
|
208
|
+
throw new InferenceOutputError("response is undefined");
|
|
209
|
+
}
|
|
210
|
+
if (typeof response == "object") {
|
|
211
|
+
if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
|
|
212
|
+
const base64Data = response.data[0].b64_json;
|
|
213
|
+
if (outputType === "url") {
|
|
214
|
+
return `data:image/jpeg;base64,${base64Data}`;
|
|
215
|
+
}
|
|
216
|
+
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
|
|
217
|
+
return await base64Response.blob();
|
|
221
218
|
}
|
|
222
|
-
|
|
223
|
-
if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
|
|
219
|
+
if ("output" in response && Array.isArray(response.output)) {
|
|
224
220
|
if (outputType === "url") {
|
|
225
|
-
return
|
|
221
|
+
return response.output[0];
|
|
226
222
|
}
|
|
227
|
-
const
|
|
228
|
-
|
|
223
|
+
const urlResponse = await fetch(response.output[0]);
|
|
224
|
+
const blob = await urlResponse.blob();
|
|
225
|
+
return blob;
|
|
229
226
|
}
|
|
230
227
|
}
|
|
231
|
-
|
|
228
|
+
if (response instanceof Blob) {
|
|
229
|
+
if (outputType === "url") {
|
|
230
|
+
const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
|
|
231
|
+
return `data:image/jpeg;base64,${b64}`;
|
|
232
|
+
}
|
|
233
|
+
return response;
|
|
234
|
+
}
|
|
235
|
+
throw new InferenceOutputError("Expected a Blob ");
|
|
232
236
|
}
|
|
233
237
|
};
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
238
|
+
var HFInferenceConversationalTask = class extends HFInferenceTask {
|
|
239
|
+
makeUrl(params) {
|
|
240
|
+
let url;
|
|
241
|
+
if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
|
|
242
|
+
url = params.model.trim();
|
|
243
|
+
} else {
|
|
244
|
+
url = `${this.makeBaseUrl(params)}/models/${params.model}`;
|
|
245
|
+
}
|
|
246
|
+
url = url.replace(/\/+$/, "");
|
|
247
|
+
if (url.endsWith("/v1")) {
|
|
248
|
+
url += "/chat/completions";
|
|
249
|
+
} else if (!url.endsWith("/chat/completions")) {
|
|
250
|
+
url += "/v1/chat/completions";
|
|
251
|
+
}
|
|
252
|
+
return url;
|
|
253
|
+
}
|
|
254
|
+
preparePayload(params) {
|
|
255
|
+
return {
|
|
256
|
+
...params.args,
|
|
257
|
+
model: params.model
|
|
258
|
+
};
|
|
259
|
+
}
|
|
260
|
+
async getResponse(response) {
|
|
261
|
+
return response;
|
|
239
262
|
}
|
|
240
263
|
};
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
264
|
+
var HFInferenceTextGenerationTask = class extends HFInferenceTask {
|
|
265
|
+
async getResponse(response) {
|
|
266
|
+
const res = toArray(response);
|
|
267
|
+
if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) {
|
|
268
|
+
return res?.[0];
|
|
269
|
+
}
|
|
270
|
+
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
|
|
246
271
|
}
|
|
247
|
-
|
|
248
|
-
|
|
272
|
+
};
|
|
273
|
+
var HFInferenceAudioClassificationTask = class extends HFInferenceTask {
|
|
274
|
+
async getResponse(response) {
|
|
275
|
+
if (Array.isArray(response) && response.every(
|
|
276
|
+
(x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
|
|
277
|
+
)) {
|
|
278
|
+
return response;
|
|
279
|
+
}
|
|
280
|
+
throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
|
|
249
281
|
}
|
|
250
282
|
};
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
|
|
255
|
-
}
|
|
256
|
-
|
|
257
|
-
// src/providers/fal-ai.ts
|
|
258
|
-
var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
|
|
259
|
-
var FalAITask = class extends TaskProviderHelper {
|
|
260
|
-
constructor(url) {
|
|
261
|
-
super("fal-ai", url || "https://fal.run");
|
|
283
|
+
var HFInferenceAutomaticSpeechRecognitionTask = class extends HFInferenceTask {
|
|
284
|
+
async getResponse(response) {
|
|
285
|
+
return response;
|
|
262
286
|
}
|
|
263
|
-
|
|
264
|
-
|
|
287
|
+
};
|
|
288
|
+
var HFInferenceAudioToAudioTask = class extends HFInferenceTask {
|
|
289
|
+
async getResponse(response) {
|
|
290
|
+
if (!Array.isArray(response)) {
|
|
291
|
+
throw new InferenceOutputError("Expected Array");
|
|
292
|
+
}
|
|
293
|
+
if (!response.every((elem) => {
|
|
294
|
+
return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
|
|
295
|
+
})) {
|
|
296
|
+
throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
|
|
297
|
+
}
|
|
298
|
+
return response;
|
|
265
299
|
}
|
|
266
|
-
|
|
267
|
-
|
|
300
|
+
};
|
|
301
|
+
var HFInferenceDocumentQuestionAnsweringTask = class extends HFInferenceTask {
|
|
302
|
+
async getResponse(response) {
|
|
303
|
+
if (Array.isArray(response) && response.every(
|
|
304
|
+
(elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
|
|
305
|
+
)) {
|
|
306
|
+
return response[0];
|
|
307
|
+
}
|
|
308
|
+
throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
|
|
268
309
|
}
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
310
|
+
};
|
|
311
|
+
var HFInferenceFeatureExtractionTask = class extends HFInferenceTask {
|
|
312
|
+
async getResponse(response) {
|
|
313
|
+
const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
|
|
314
|
+
if (curDepth > maxDepth)
|
|
315
|
+
return false;
|
|
316
|
+
if (arr.every((x) => Array.isArray(x))) {
|
|
317
|
+
return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
|
|
318
|
+
} else {
|
|
319
|
+
return arr.every((x) => typeof x === "number");
|
|
320
|
+
}
|
|
272
321
|
};
|
|
273
|
-
if (
|
|
274
|
-
|
|
322
|
+
if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) {
|
|
323
|
+
return response;
|
|
275
324
|
}
|
|
276
|
-
|
|
325
|
+
throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
|
|
277
326
|
}
|
|
278
327
|
};
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
preparePayload(params) {
|
|
284
|
-
const payload = {
|
|
285
|
-
...omit(params.args, ["inputs", "parameters"]),
|
|
286
|
-
...params.args.parameters,
|
|
287
|
-
sync_mode: true,
|
|
288
|
-
prompt: params.args.inputs
|
|
289
|
-
};
|
|
290
|
-
if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
|
|
291
|
-
payload.loras = [
|
|
292
|
-
{
|
|
293
|
-
path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
|
|
294
|
-
scale: 1
|
|
295
|
-
}
|
|
296
|
-
];
|
|
297
|
-
if (params.mapping.providerId === "fal-ai/lora") {
|
|
298
|
-
payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0";
|
|
299
|
-
}
|
|
328
|
+
var HFInferenceImageClassificationTask = class extends HFInferenceTask {
|
|
329
|
+
async getResponse(response) {
|
|
330
|
+
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
|
|
331
|
+
return response;
|
|
300
332
|
}
|
|
301
|
-
|
|
333
|
+
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
302
334
|
}
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
const urlResponse = await fetch(response.images[0].url);
|
|
309
|
-
return await urlResponse.blob();
|
|
335
|
+
};
|
|
336
|
+
var HFInferenceImageSegmentationTask = class extends HFInferenceTask {
|
|
337
|
+
async getResponse(response) {
|
|
338
|
+
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number")) {
|
|
339
|
+
return response;
|
|
310
340
|
}
|
|
311
|
-
throw new InferenceOutputError("Expected
|
|
341
|
+
throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
|
|
312
342
|
}
|
|
313
343
|
};
|
|
314
|
-
var
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
makeRoute(params) {
|
|
319
|
-
if (params.authMethod !== "provider-key") {
|
|
320
|
-
return `/${params.model}?_subdomain=queue`;
|
|
344
|
+
var HFInferenceImageToTextTask = class extends HFInferenceTask {
|
|
345
|
+
async getResponse(response) {
|
|
346
|
+
if (typeof response?.generated_text !== "string") {
|
|
347
|
+
throw new InferenceOutputError("Expected {generated_text: string}");
|
|
321
348
|
}
|
|
322
|
-
return
|
|
349
|
+
return response;
|
|
323
350
|
}
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
}
|
|
351
|
+
};
|
|
352
|
+
var HFInferenceImageToImageTask = class extends HFInferenceTask {
|
|
353
|
+
async getResponse(response) {
|
|
354
|
+
if (response instanceof Blob) {
|
|
355
|
+
return response;
|
|
356
|
+
}
|
|
357
|
+
throw new InferenceOutputError("Expected Blob");
|
|
330
358
|
}
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
359
|
+
};
|
|
360
|
+
var HFInferenceObjectDetectionTask = class extends HFInferenceTask {
|
|
361
|
+
async getResponse(response) {
|
|
362
|
+
if (Array.isArray(response) && response.every(
|
|
363
|
+
(x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
|
|
364
|
+
)) {
|
|
365
|
+
return response;
|
|
334
366
|
}
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
367
|
+
throw new InferenceOutputError(
|
|
368
|
+
"Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
|
|
369
|
+
);
|
|
370
|
+
}
|
|
371
|
+
};
|
|
372
|
+
var HFInferenceZeroShotImageClassificationTask = class extends HFInferenceTask {
|
|
373
|
+
async getResponse(response) {
|
|
374
|
+
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
|
|
375
|
+
return response;
|
|
338
376
|
}
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
const
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
await delay(500);
|
|
348
|
-
const statusResponse = await fetch(statusUrl, { headers });
|
|
349
|
-
if (!statusResponse.ok) {
|
|
350
|
-
throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
|
|
351
|
-
}
|
|
352
|
-
try {
|
|
353
|
-
status = (await statusResponse.json()).status;
|
|
354
|
-
} catch (error) {
|
|
355
|
-
throw new InferenceOutputError("Failed to parse status response from fal-ai API");
|
|
356
|
-
}
|
|
377
|
+
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
378
|
+
}
|
|
379
|
+
};
|
|
380
|
+
var HFInferenceTextClassificationTask = class extends HFInferenceTask {
|
|
381
|
+
async getResponse(response) {
|
|
382
|
+
const output = response?.[0];
|
|
383
|
+
if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
|
|
384
|
+
return output;
|
|
357
385
|
}
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
386
|
+
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
387
|
+
}
|
|
388
|
+
};
|
|
389
|
+
var HFInferenceQuestionAnsweringTask = class extends HFInferenceTask {
|
|
390
|
+
async getResponse(response) {
|
|
391
|
+
if (Array.isArray(response) ? response.every(
|
|
392
|
+
(elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
|
|
393
|
+
) : typeof response === "object" && !!response && typeof response.answer === "string" && typeof response.end === "number" && typeof response.score === "number" && typeof response.start === "number") {
|
|
394
|
+
return Array.isArray(response) ? response[0] : response;
|
|
364
395
|
}
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
)
|
|
396
|
+
throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
|
|
397
|
+
}
|
|
398
|
+
};
|
|
399
|
+
var HFInferenceFillMaskTask = class extends HFInferenceTask {
|
|
400
|
+
async getResponse(response) {
|
|
401
|
+
if (Array.isArray(response) && response.every(
|
|
402
|
+
(x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
|
|
403
|
+
)) {
|
|
404
|
+
return response;
|
|
372
405
|
}
|
|
406
|
+
throw new InferenceOutputError(
|
|
407
|
+
"Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
|
|
408
|
+
);
|
|
373
409
|
}
|
|
374
410
|
};
|
|
375
|
-
var
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
411
|
+
var HFInferenceZeroShotClassificationTask = class extends HFInferenceTask {
|
|
412
|
+
async getResponse(response) {
|
|
413
|
+
if (Array.isArray(response) && response.every(
|
|
414
|
+
(x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
|
|
415
|
+
)) {
|
|
416
|
+
return response;
|
|
417
|
+
}
|
|
418
|
+
throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
|
|
380
419
|
}
|
|
420
|
+
};
|
|
421
|
+
var HFInferenceSentenceSimilarityTask = class extends HFInferenceTask {
|
|
381
422
|
async getResponse(response) {
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
throw new InferenceOutputError(
|
|
385
|
-
`Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
|
|
386
|
-
);
|
|
423
|
+
if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
|
|
424
|
+
return response;
|
|
387
425
|
}
|
|
388
|
-
|
|
426
|
+
throw new InferenceOutputError("Expected Array<number>");
|
|
389
427
|
}
|
|
390
428
|
};
|
|
391
|
-
var
|
|
392
|
-
|
|
393
|
-
return
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
text: params.args.inputs
|
|
397
|
-
};
|
|
429
|
+
var HFInferenceTableQuestionAnsweringTask = class extends HFInferenceTask {
|
|
430
|
+
static validate(elem) {
|
|
431
|
+
return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
|
|
432
|
+
(coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
|
|
433
|
+
);
|
|
398
434
|
}
|
|
399
435
|
async getResponse(response) {
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
throw new InferenceOutputError(
|
|
403
|
-
`Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
|
|
404
|
-
);
|
|
436
|
+
if (Array.isArray(response) && Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response)) {
|
|
437
|
+
return Array.isArray(response) ? response[0] : response;
|
|
405
438
|
}
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
439
|
+
throw new InferenceOutputError(
|
|
440
|
+
"Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
|
|
441
|
+
);
|
|
442
|
+
}
|
|
443
|
+
};
|
|
444
|
+
var HFInferenceTokenClassificationTask = class extends HFInferenceTask {
|
|
445
|
+
async getResponse(response) {
|
|
446
|
+
if (Array.isArray(response) && response.every(
|
|
447
|
+
(x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
|
|
448
|
+
)) {
|
|
449
|
+
return response;
|
|
416
450
|
}
|
|
451
|
+
throw new InferenceOutputError(
|
|
452
|
+
"Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
|
|
453
|
+
);
|
|
417
454
|
}
|
|
418
455
|
};
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
456
|
+
var HFInferenceTranslationTask = class extends HFInferenceTask {
|
|
457
|
+
async getResponse(response) {
|
|
458
|
+
if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) {
|
|
459
|
+
return response?.length === 1 ? response?.[0] : response;
|
|
460
|
+
}
|
|
461
|
+
throw new InferenceOutputError("Expected Array<{translation_text: string}>");
|
|
425
462
|
}
|
|
426
463
|
};
|
|
427
|
-
var
|
|
428
|
-
|
|
429
|
-
|
|
464
|
+
var HFInferenceSummarizationTask = class extends HFInferenceTask {
|
|
465
|
+
async getResponse(response) {
|
|
466
|
+
if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) {
|
|
467
|
+
return response?.[0];
|
|
468
|
+
}
|
|
469
|
+
throw new InferenceOutputError("Expected Array<{summary_text: string}>");
|
|
430
470
|
}
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
model: params.model,
|
|
436
|
-
prompt: params.args.inputs
|
|
437
|
-
};
|
|
471
|
+
};
|
|
472
|
+
var HFInferenceTextToSpeechTask = class extends HFInferenceTask {
|
|
473
|
+
async getResponse(response) {
|
|
474
|
+
return response;
|
|
438
475
|
}
|
|
476
|
+
};
|
|
477
|
+
var HFInferenceTabularClassificationTask = class extends HFInferenceTask {
|
|
439
478
|
async getResponse(response) {
|
|
440
|
-
if (
|
|
441
|
-
|
|
442
|
-
return {
|
|
443
|
-
generated_text: completion.text
|
|
444
|
-
};
|
|
479
|
+
if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
|
|
480
|
+
return response;
|
|
445
481
|
}
|
|
446
|
-
throw new InferenceOutputError("Expected
|
|
482
|
+
throw new InferenceOutputError("Expected Array<number>");
|
|
447
483
|
}
|
|
448
484
|
};
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
485
|
+
var HFInferenceVisualQuestionAnsweringTask = class extends HFInferenceTask {
|
|
486
|
+
async getResponse(response) {
|
|
487
|
+
if (Array.isArray(response) && response.every(
|
|
488
|
+
(elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
|
|
489
|
+
)) {
|
|
490
|
+
return response[0];
|
|
491
|
+
}
|
|
492
|
+
throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
|
|
454
493
|
}
|
|
455
|
-
|
|
456
|
-
|
|
494
|
+
};
|
|
495
|
+
var HFInferenceTabularRegressionTask = class extends HFInferenceTask {
|
|
496
|
+
async getResponse(response) {
|
|
497
|
+
if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
|
|
498
|
+
return response;
|
|
499
|
+
}
|
|
500
|
+
throw new InferenceOutputError("Expected Array<number>");
|
|
501
|
+
}
|
|
502
|
+
};
|
|
503
|
+
var HFInferenceTextToAudioTask = class extends HFInferenceTask {
|
|
504
|
+
async getResponse(response) {
|
|
505
|
+
return response;
|
|
457
506
|
}
|
|
458
507
|
};
|
|
459
508
|
|
|
460
|
-
// src/
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
509
|
+
// src/utils/typedInclude.ts
|
|
510
|
+
function typedInclude(arr, v) {
|
|
511
|
+
return arr.includes(v);
|
|
512
|
+
}
|
|
513
|
+
|
|
514
|
+
// src/lib/getInferenceProviderMapping.ts
|
|
515
|
+
var inferenceProviderMappingCache = /* @__PURE__ */ new Map();
|
|
516
|
+
async function fetchInferenceProviderMappingForModel(modelId, accessToken, options) {
|
|
517
|
+
let inferenceProviderMapping;
|
|
518
|
+
if (inferenceProviderMappingCache.has(modelId)) {
|
|
519
|
+
inferenceProviderMapping = inferenceProviderMappingCache.get(modelId);
|
|
520
|
+
} else {
|
|
521
|
+
const resp = await (options?.fetch ?? fetch)(
|
|
522
|
+
`${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
|
|
523
|
+
{
|
|
524
|
+
headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {}
|
|
525
|
+
}
|
|
526
|
+
);
|
|
527
|
+
if (resp.status === 404) {
|
|
528
|
+
throw new Error(`Model ${modelId} does not exist`);
|
|
529
|
+
}
|
|
530
|
+
inferenceProviderMapping = await resp.json().then((json) => json.inferenceProviderMapping).catch(() => null);
|
|
531
|
+
if (inferenceProviderMapping) {
|
|
532
|
+
inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
|
|
533
|
+
}
|
|
465
534
|
}
|
|
466
|
-
|
|
467
|
-
|
|
535
|
+
if (!inferenceProviderMapping) {
|
|
536
|
+
throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
|
|
468
537
|
}
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
538
|
+
return inferenceProviderMapping;
|
|
539
|
+
}
|
|
540
|
+
async function getInferenceProviderMapping(params, options) {
|
|
541
|
+
if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
|
|
542
|
+
return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
|
|
473
543
|
}
|
|
474
|
-
|
|
475
|
-
|
|
544
|
+
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
|
|
545
|
+
params.modelId,
|
|
546
|
+
params.accessToken,
|
|
547
|
+
options
|
|
548
|
+
);
|
|
549
|
+
const providerMapping = inferenceProviderMapping[params.provider];
|
|
550
|
+
if (providerMapping) {
|
|
551
|
+
const equivalentTasks = params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task) ? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS : [params.task];
|
|
552
|
+
if (!typedInclude(equivalentTasks, providerMapping.task)) {
|
|
553
|
+
throw new Error(
|
|
554
|
+
`Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
|
|
555
|
+
);
|
|
556
|
+
}
|
|
557
|
+
if (providerMapping.status === "staging") {
|
|
558
|
+
console.warn(
|
|
559
|
+
`Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
|
|
560
|
+
);
|
|
561
|
+
}
|
|
562
|
+
return { ...providerMapping, hfModelId: params.modelId };
|
|
476
563
|
}
|
|
477
|
-
|
|
564
|
+
return null;
|
|
565
|
+
}
|
|
566
|
+
async function resolveProvider(provider, modelId, endpointUrl) {
|
|
567
|
+
if (endpointUrl) {
|
|
568
|
+
if (provider) {
|
|
569
|
+
throw new Error("Specifying both endpointUrl and provider is not supported.");
|
|
570
|
+
}
|
|
571
|
+
return "hf-inference";
|
|
572
|
+
}
|
|
573
|
+
if (!provider) {
|
|
574
|
+
console.log(
|
|
575
|
+
"Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
|
|
576
|
+
);
|
|
577
|
+
provider = "auto";
|
|
578
|
+
}
|
|
579
|
+
if (provider === "auto") {
|
|
580
|
+
if (!modelId) {
|
|
581
|
+
throw new Error("Specifying a model is required when provider is 'auto'");
|
|
582
|
+
}
|
|
583
|
+
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
|
|
584
|
+
provider = Object.keys(inferenceProviderMapping)[0];
|
|
585
|
+
}
|
|
586
|
+
if (!provider) {
|
|
587
|
+
throw new Error(`No Inference Provider available for model ${modelId}.`);
|
|
588
|
+
}
|
|
589
|
+
return provider;
|
|
590
|
+
}
|
|
478
591
|
|
|
479
|
-
// src/
|
|
480
|
-
|
|
481
|
-
|
|
592
|
+
// src/utils/delay.ts
|
|
593
|
+
function delay(ms) {
|
|
594
|
+
return new Promise((resolve) => {
|
|
595
|
+
setTimeout(() => resolve(), ms);
|
|
596
|
+
});
|
|
597
|
+
}
|
|
598
|
+
|
|
599
|
+
// src/utils/pick.ts
|
|
600
|
+
function pick(o, props) {
|
|
601
|
+
return Object.assign(
|
|
602
|
+
{},
|
|
603
|
+
...props.map((prop) => {
|
|
604
|
+
if (o[prop] !== void 0) {
|
|
605
|
+
return { [prop]: o[prop] };
|
|
606
|
+
}
|
|
607
|
+
})
|
|
608
|
+
);
|
|
609
|
+
}
|
|
610
|
+
|
|
611
|
+
// src/utils/omit.ts
|
|
612
|
+
function omit(o, props) {
|
|
613
|
+
const propsArr = Array.isArray(props) ? props : [props];
|
|
614
|
+
const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
|
|
615
|
+
return pick(o, letsKeep);
|
|
616
|
+
}
|
|
617
|
+
|
|
618
|
+
// src/providers/black-forest-labs.ts
|
|
619
|
+
var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
|
|
620
|
+
var BlackForestLabsTextToImageTask = class extends TaskProviderHelper {
|
|
482
621
|
constructor() {
|
|
483
|
-
super("
|
|
622
|
+
super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
|
|
484
623
|
}
|
|
485
624
|
preparePayload(params) {
|
|
486
|
-
return
|
|
625
|
+
return {
|
|
626
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
627
|
+
...params.args.parameters,
|
|
628
|
+
prompt: params.args.inputs
|
|
629
|
+
};
|
|
487
630
|
}
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
631
|
+
prepareHeaders(params, binary) {
|
|
632
|
+
const headers = {
|
|
633
|
+
Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`
|
|
634
|
+
};
|
|
635
|
+
if (!binary) {
|
|
636
|
+
headers["Content-Type"] = "application/json";
|
|
491
637
|
}
|
|
492
|
-
return
|
|
638
|
+
return headers;
|
|
493
639
|
}
|
|
494
640
|
makeRoute(params) {
|
|
495
|
-
if (params
|
|
496
|
-
|
|
641
|
+
if (!params) {
|
|
642
|
+
throw new Error("Params are required");
|
|
497
643
|
}
|
|
498
|
-
return
|
|
499
|
-
}
|
|
500
|
-
async getResponse(response) {
|
|
501
|
-
return response;
|
|
644
|
+
return `/v1/${params.model}`;
|
|
502
645
|
}
|
|
503
|
-
};
|
|
504
|
-
var HFInferenceTextToImageTask = class extends HFInferenceTask {
|
|
505
646
|
async getResponse(response, url, headers, outputType) {
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
}
|
|
515
|
-
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
|
|
516
|
-
return await base64Response.blob();
|
|
647
|
+
const urlObj = new URL(response.polling_url);
|
|
648
|
+
for (let step = 0; step < 5; step++) {
|
|
649
|
+
await delay(1e3);
|
|
650
|
+
console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
|
|
651
|
+
urlObj.searchParams.set("attempt", step.toString(10));
|
|
652
|
+
const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
|
|
653
|
+
if (!resp.ok) {
|
|
654
|
+
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
|
|
517
655
|
}
|
|
518
|
-
|
|
656
|
+
const payload = await resp.json();
|
|
657
|
+
if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
|
|
519
658
|
if (outputType === "url") {
|
|
520
|
-
return
|
|
659
|
+
return payload.result.sample;
|
|
521
660
|
}
|
|
522
|
-
const
|
|
523
|
-
|
|
524
|
-
return blob;
|
|
525
|
-
}
|
|
526
|
-
}
|
|
527
|
-
if (response instanceof Blob) {
|
|
528
|
-
if (outputType === "url") {
|
|
529
|
-
const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
|
|
530
|
-
return `data:image/jpeg;base64,${b64}`;
|
|
661
|
+
const image = await fetch(payload.result.sample);
|
|
662
|
+
return await image.blob();
|
|
531
663
|
}
|
|
532
|
-
return response;
|
|
533
664
|
}
|
|
534
|
-
throw new InferenceOutputError("
|
|
665
|
+
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
|
|
535
666
|
}
|
|
536
667
|
};
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
} else {
|
|
543
|
-
url = `${this.makeBaseUrl(params)}/models/${params.model}`;
|
|
544
|
-
}
|
|
545
|
-
url = url.replace(/\/+$/, "");
|
|
546
|
-
if (url.endsWith("/v1")) {
|
|
547
|
-
url += "/chat/completions";
|
|
548
|
-
} else if (!url.endsWith("/chat/completions")) {
|
|
549
|
-
url += "/v1/chat/completions";
|
|
550
|
-
}
|
|
551
|
-
return url;
|
|
552
|
-
}
|
|
553
|
-
preparePayload(params) {
|
|
554
|
-
return {
|
|
555
|
-
...params.args,
|
|
556
|
-
model: params.model
|
|
557
|
-
};
|
|
558
|
-
}
|
|
559
|
-
async getResponse(response) {
|
|
560
|
-
return response;
|
|
668
|
+
|
|
669
|
+
// src/providers/cerebras.ts
|
|
670
|
+
var CerebrasConversationalTask = class extends BaseConversationalTask {
|
|
671
|
+
constructor() {
|
|
672
|
+
super("cerebras", "https://api.cerebras.ai");
|
|
561
673
|
}
|
|
562
674
|
};
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
}
|
|
569
|
-
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
|
|
675
|
+
|
|
676
|
+
// src/providers/cohere.ts
|
|
677
|
+
var CohereConversationalTask = class extends BaseConversationalTask {
|
|
678
|
+
constructor() {
|
|
679
|
+
super("cohere", "https://api.cohere.com");
|
|
570
680
|
}
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
async getResponse(response) {
|
|
574
|
-
if (Array.isArray(response) && response.every(
|
|
575
|
-
(x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
|
|
576
|
-
)) {
|
|
577
|
-
return response;
|
|
578
|
-
}
|
|
579
|
-
throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
|
|
681
|
+
makeRoute() {
|
|
682
|
+
return "/compatibility/v1/chat/completions";
|
|
580
683
|
}
|
|
581
684
|
};
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
685
|
+
|
|
686
|
+
// src/lib/isUrl.ts
|
|
687
|
+
function isUrl(modelOrUrl) {
|
|
688
|
+
return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
|
|
689
|
+
}
|
|
690
|
+
|
|
691
|
+
// src/providers/fal-ai.ts
|
|
692
|
+
var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
|
|
693
|
+
var FalAITask = class extends TaskProviderHelper {
|
|
694
|
+
constructor(url) {
|
|
695
|
+
super("fal-ai", url || "https://fal.run");
|
|
585
696
|
}
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
async getResponse(response) {
|
|
589
|
-
if (!Array.isArray(response)) {
|
|
590
|
-
throw new InferenceOutputError("Expected Array");
|
|
591
|
-
}
|
|
592
|
-
if (!response.every((elem) => {
|
|
593
|
-
return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
|
|
594
|
-
})) {
|
|
595
|
-
throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
|
|
596
|
-
}
|
|
597
|
-
return response;
|
|
697
|
+
preparePayload(params) {
|
|
698
|
+
return params.args;
|
|
598
699
|
}
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
async getResponse(response) {
|
|
602
|
-
if (Array.isArray(response) && response.every(
|
|
603
|
-
(elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
|
|
604
|
-
)) {
|
|
605
|
-
return response[0];
|
|
606
|
-
}
|
|
607
|
-
throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
|
|
700
|
+
makeRoute(params) {
|
|
701
|
+
return `/${params.model}`;
|
|
608
702
|
}
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
|
|
613
|
-
if (curDepth > maxDepth)
|
|
614
|
-
return false;
|
|
615
|
-
if (arr.every((x) => Array.isArray(x))) {
|
|
616
|
-
return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
|
|
617
|
-
} else {
|
|
618
|
-
return arr.every((x) => typeof x === "number");
|
|
619
|
-
}
|
|
703
|
+
prepareHeaders(params, binary) {
|
|
704
|
+
const headers = {
|
|
705
|
+
Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}`
|
|
620
706
|
};
|
|
621
|
-
if (
|
|
622
|
-
|
|
707
|
+
if (!binary) {
|
|
708
|
+
headers["Content-Type"] = "application/json";
|
|
623
709
|
}
|
|
624
|
-
|
|
710
|
+
return headers;
|
|
625
711
|
}
|
|
626
712
|
};
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
713
|
+
function buildLoraPath(modelId, adapterWeightsPath) {
|
|
714
|
+
return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
|
|
715
|
+
}
|
|
716
|
+
var FalAITextToImageTask = class extends FalAITask {
|
|
717
|
+
preparePayload(params) {
|
|
718
|
+
const payload = {
|
|
719
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
720
|
+
...params.args.parameters,
|
|
721
|
+
sync_mode: true,
|
|
722
|
+
prompt: params.args.inputs
|
|
723
|
+
};
|
|
724
|
+
if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
|
|
725
|
+
payload.loras = [
|
|
726
|
+
{
|
|
727
|
+
path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
|
|
728
|
+
scale: 1
|
|
729
|
+
}
|
|
730
|
+
];
|
|
731
|
+
if (params.mapping.providerId === "fal-ai/lora") {
|
|
732
|
+
payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0";
|
|
733
|
+
}
|
|
631
734
|
}
|
|
632
|
-
|
|
735
|
+
return payload;
|
|
633
736
|
}
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
737
|
+
async getResponse(response, outputType) {
|
|
738
|
+
if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images.length > 0 && "url" in response.images[0] && typeof response.images[0].url === "string") {
|
|
739
|
+
if (outputType === "url") {
|
|
740
|
+
return response.images[0].url;
|
|
741
|
+
}
|
|
742
|
+
const urlResponse = await fetch(response.images[0].url);
|
|
743
|
+
return await urlResponse.blob();
|
|
639
744
|
}
|
|
640
|
-
throw new InferenceOutputError("Expected
|
|
745
|
+
throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
|
|
641
746
|
}
|
|
642
747
|
};
|
|
643
|
-
var
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
throw new InferenceOutputError("Expected {generated_text: string}");
|
|
647
|
-
}
|
|
648
|
-
return response;
|
|
748
|
+
var FalAITextToVideoTask = class extends FalAITask {
|
|
749
|
+
constructor() {
|
|
750
|
+
super("https://queue.fal.run");
|
|
649
751
|
}
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
if (response instanceof Blob) {
|
|
654
|
-
return response;
|
|
752
|
+
makeRoute(params) {
|
|
753
|
+
if (params.authMethod !== "provider-key") {
|
|
754
|
+
return `/${params.model}?_subdomain=queue`;
|
|
655
755
|
}
|
|
656
|
-
|
|
756
|
+
return `/${params.model}`;
|
|
657
757
|
}
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
return response;
|
|
665
|
-
}
|
|
666
|
-
throw new InferenceOutputError(
|
|
667
|
-
"Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
|
|
668
|
-
);
|
|
758
|
+
preparePayload(params) {
|
|
759
|
+
return {
|
|
760
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
761
|
+
...params.args.parameters,
|
|
762
|
+
prompt: params.args.inputs
|
|
763
|
+
};
|
|
669
764
|
}
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
|
|
674
|
-
return response;
|
|
765
|
+
async getResponse(response, url, headers) {
|
|
766
|
+
if (!url || !headers) {
|
|
767
|
+
throw new InferenceOutputError("URL and headers are required for text-to-video task");
|
|
675
768
|
}
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
var HFInferenceTextClassificationTask = class extends HFInferenceTask {
|
|
680
|
-
async getResponse(response) {
|
|
681
|
-
const output = response?.[0];
|
|
682
|
-
if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
|
|
683
|
-
return output;
|
|
769
|
+
const requestId = response.request_id;
|
|
770
|
+
if (!requestId) {
|
|
771
|
+
throw new InferenceOutputError("No request ID found in the response");
|
|
684
772
|
}
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
}
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
773
|
+
let status = response.status;
|
|
774
|
+
const parsedUrl = new URL(url);
|
|
775
|
+
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
|
|
776
|
+
const modelId = new URL(response.response_url).pathname;
|
|
777
|
+
const queryParams = parsedUrl.search;
|
|
778
|
+
const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
|
|
779
|
+
const resultUrl = `${baseUrl}${modelId}${queryParams}`;
|
|
780
|
+
while (status !== "COMPLETED") {
|
|
781
|
+
await delay(500);
|
|
782
|
+
const statusResponse = await fetch(statusUrl, { headers });
|
|
783
|
+
if (!statusResponse.ok) {
|
|
784
|
+
throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
|
|
785
|
+
}
|
|
786
|
+
try {
|
|
787
|
+
status = (await statusResponse.json()).status;
|
|
788
|
+
} catch (error) {
|
|
789
|
+
throw new InferenceOutputError("Failed to parse status response from fal-ai API");
|
|
790
|
+
}
|
|
694
791
|
}
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
(x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
|
|
702
|
-
)) {
|
|
703
|
-
return response;
|
|
792
|
+
const resultResponse = await fetch(resultUrl, { headers });
|
|
793
|
+
let result;
|
|
794
|
+
try {
|
|
795
|
+
result = await resultResponse.json();
|
|
796
|
+
} catch (error) {
|
|
797
|
+
throw new InferenceOutputError("Failed to parse result response from fal-ai API");
|
|
704
798
|
}
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
if (Array.isArray(response) && response.every(
|
|
713
|
-
(x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
|
|
714
|
-
)) {
|
|
715
|
-
return response;
|
|
799
|
+
if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
|
|
800
|
+
const urlResponse = await fetch(result.video.url);
|
|
801
|
+
return await urlResponse.blob();
|
|
802
|
+
} else {
|
|
803
|
+
throw new InferenceOutputError(
|
|
804
|
+
"Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
|
|
805
|
+
);
|
|
716
806
|
}
|
|
717
|
-
throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
|
|
718
807
|
}
|
|
719
808
|
};
|
|
720
|
-
var
|
|
809
|
+
var FalAIAutomaticSpeechRecognitionTask = class extends FalAITask {
|
|
810
|
+
prepareHeaders(params, binary) {
|
|
811
|
+
const headers = super.prepareHeaders(params, binary);
|
|
812
|
+
headers["Content-Type"] = "application/json";
|
|
813
|
+
return headers;
|
|
814
|
+
}
|
|
721
815
|
async getResponse(response) {
|
|
722
|
-
|
|
723
|
-
|
|
816
|
+
const res = response;
|
|
817
|
+
if (typeof res?.text !== "string") {
|
|
818
|
+
throw new InferenceOutputError(
|
|
819
|
+
`Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
|
|
820
|
+
);
|
|
724
821
|
}
|
|
725
|
-
|
|
822
|
+
return { text: res.text };
|
|
726
823
|
}
|
|
727
824
|
};
|
|
728
|
-
var
|
|
729
|
-
|
|
730
|
-
return
|
|
731
|
-
(
|
|
732
|
-
|
|
825
|
+
var FalAITextToSpeechTask = class extends FalAITask {
|
|
826
|
+
preparePayload(params) {
|
|
827
|
+
return {
|
|
828
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
829
|
+
...params.args.parameters,
|
|
830
|
+
text: params.args.inputs
|
|
831
|
+
};
|
|
733
832
|
}
|
|
734
833
|
async getResponse(response) {
|
|
735
|
-
|
|
736
|
-
|
|
834
|
+
const res = response;
|
|
835
|
+
if (typeof res?.audio?.url !== "string") {
|
|
836
|
+
throw new InferenceOutputError(
|
|
837
|
+
`Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
|
|
838
|
+
);
|
|
737
839
|
}
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
}
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
return response;
|
|
840
|
+
try {
|
|
841
|
+
const urlResponse = await fetch(res.audio.url);
|
|
842
|
+
if (!urlResponse.ok) {
|
|
843
|
+
throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`);
|
|
844
|
+
}
|
|
845
|
+
return await urlResponse.blob();
|
|
846
|
+
} catch (error) {
|
|
847
|
+
throw new InferenceOutputError(
|
|
848
|
+
`Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${error instanceof Error ? error.message : String(error)}`
|
|
849
|
+
);
|
|
749
850
|
}
|
|
750
|
-
throw new InferenceOutputError(
|
|
751
|
-
"Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
|
|
752
|
-
);
|
|
753
851
|
}
|
|
754
852
|
};
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
853
|
+
|
|
854
|
+
// src/providers/featherless-ai.ts
|
|
855
|
+
var FEATHERLESS_API_BASE_URL = "https://api.featherless.ai";
|
|
856
|
+
var FeatherlessAIConversationalTask = class extends BaseConversationalTask {
|
|
857
|
+
constructor() {
|
|
858
|
+
super("featherless-ai", FEATHERLESS_API_BASE_URL);
|
|
761
859
|
}
|
|
762
860
|
};
|
|
763
|
-
var
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
return response?.[0];
|
|
767
|
-
}
|
|
768
|
-
throw new InferenceOutputError("Expected Array<{summary_text: string}>");
|
|
861
|
+
var FeatherlessAITextGenerationTask = class extends BaseTextGenerationTask {
|
|
862
|
+
constructor() {
|
|
863
|
+
super("featherless-ai", FEATHERLESS_API_BASE_URL);
|
|
769
864
|
}
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
865
|
+
preparePayload(params) {
|
|
866
|
+
return {
|
|
867
|
+
...params.args,
|
|
868
|
+
...params.args.parameters,
|
|
869
|
+
model: params.model,
|
|
870
|
+
prompt: params.args.inputs
|
|
871
|
+
};
|
|
774
872
|
}
|
|
775
|
-
};
|
|
776
|
-
var HFInferenceTabularClassificationTask = class extends HFInferenceTask {
|
|
777
873
|
async getResponse(response) {
|
|
778
|
-
if (
|
|
779
|
-
|
|
874
|
+
if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
|
|
875
|
+
const completion = response.choices[0];
|
|
876
|
+
return {
|
|
877
|
+
generated_text: completion.text
|
|
878
|
+
};
|
|
780
879
|
}
|
|
781
|
-
throw new InferenceOutputError("Expected
|
|
880
|
+
throw new InferenceOutputError("Expected Featherless AI text generation response format");
|
|
782
881
|
}
|
|
783
882
|
};
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
)
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
883
|
+
|
|
884
|
+
// src/providers/fireworks-ai.ts
|
|
885
|
+
var FireworksConversationalTask = class extends BaseConversationalTask {
|
|
886
|
+
constructor() {
|
|
887
|
+
super("fireworks-ai", "https://api.fireworks.ai");
|
|
888
|
+
}
|
|
889
|
+
makeRoute() {
|
|
890
|
+
return "/inference/v1/chat/completions";
|
|
792
891
|
}
|
|
793
892
|
};
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
893
|
+
|
|
894
|
+
// src/providers/groq.ts
|
|
895
|
+
var GROQ_API_BASE_URL = "https://api.groq.com";
|
|
896
|
+
var GroqTextGenerationTask = class extends BaseTextGenerationTask {
|
|
897
|
+
constructor() {
|
|
898
|
+
super("groq", GROQ_API_BASE_URL);
|
|
899
|
+
}
|
|
900
|
+
makeRoute() {
|
|
901
|
+
return "/openai/v1/chat/completions";
|
|
800
902
|
}
|
|
801
903
|
};
|
|
802
|
-
var
|
|
803
|
-
|
|
804
|
-
|
|
904
|
+
var GroqConversationalTask = class extends BaseConversationalTask {
|
|
905
|
+
constructor() {
|
|
906
|
+
super("groq", GROQ_API_BASE_URL);
|
|
907
|
+
}
|
|
908
|
+
makeRoute() {
|
|
909
|
+
return "/openai/v1/chat/completions";
|
|
805
910
|
}
|
|
806
911
|
};
|
|
807
912
|
|
|
@@ -968,6 +1073,39 @@ var OpenAIConversationalTask = class extends BaseConversationalTask {
|
|
|
968
1073
|
}
|
|
969
1074
|
};
|
|
970
1075
|
|
|
1076
|
+
// src/providers/ovhcloud.ts
|
|
1077
|
+
var OVHCLOUD_API_BASE_URL = "https://oai.endpoints.kepler.ai.cloud.ovh.net";
|
|
1078
|
+
var OvhCloudConversationalTask = class extends BaseConversationalTask {
|
|
1079
|
+
constructor() {
|
|
1080
|
+
super("ovhcloud", OVHCLOUD_API_BASE_URL);
|
|
1081
|
+
}
|
|
1082
|
+
};
|
|
1083
|
+
var OvhCloudTextGenerationTask = class extends BaseTextGenerationTask {
|
|
1084
|
+
constructor() {
|
|
1085
|
+
super("ovhcloud", OVHCLOUD_API_BASE_URL);
|
|
1086
|
+
}
|
|
1087
|
+
preparePayload(params) {
|
|
1088
|
+
return {
|
|
1089
|
+
model: params.model,
|
|
1090
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
1091
|
+
...params.args.parameters ? {
|
|
1092
|
+
max_tokens: params.args.parameters.max_new_tokens,
|
|
1093
|
+
...omit(params.args.parameters, "max_new_tokens")
|
|
1094
|
+
} : void 0,
|
|
1095
|
+
prompt: params.args.inputs
|
|
1096
|
+
};
|
|
1097
|
+
}
|
|
1098
|
+
async getResponse(response) {
|
|
1099
|
+
if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
|
|
1100
|
+
const completion = response.choices[0];
|
|
1101
|
+
return {
|
|
1102
|
+
generated_text: completion.text
|
|
1103
|
+
};
|
|
1104
|
+
}
|
|
1105
|
+
throw new InferenceOutputError("Expected OVHcloud text generation response format");
|
|
1106
|
+
}
|
|
1107
|
+
};
|
|
1108
|
+
|
|
971
1109
|
// src/providers/replicate.ts
|
|
972
1110
|
var ReplicateTask = class extends TaskProviderHelper {
|
|
973
1111
|
constructor(url) {
|
|
@@ -1220,6 +1358,10 @@ var PROVIDERS = {
|
|
|
1220
1358
|
openai: {
|
|
1221
1359
|
conversational: new OpenAIConversationalTask()
|
|
1222
1360
|
},
|
|
1361
|
+
ovhcloud: {
|
|
1362
|
+
conversational: new OvhCloudConversationalTask(),
|
|
1363
|
+
"text-generation": new OvhCloudTextGenerationTask()
|
|
1364
|
+
},
|
|
1223
1365
|
replicate: {
|
|
1224
1366
|
"text-to-image": new ReplicateTextToImageTask(),
|
|
1225
1367
|
"text-to-speech": new ReplicateTextToSpeechTask(),
|
|
@@ -1258,81 +1400,13 @@ function getProviderHelper(provider, task) {
|
|
|
1258
1400
|
|
|
1259
1401
|
// package.json
|
|
1260
1402
|
var name = "@huggingface/inference";
|
|
1261
|
-
var version = "3.
|
|
1262
|
-
|
|
1263
|
-
// src/providers/consts.ts
|
|
1264
|
-
var HARDCODED_MODEL_INFERENCE_MAPPING = {
|
|
1265
|
-
/**
|
|
1266
|
-
* "HF model ID" => "Model ID on Inference Provider's side"
|
|
1267
|
-
*
|
|
1268
|
-
* Example:
|
|
1269
|
-
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
|
|
1270
|
-
*/
|
|
1271
|
-
"black-forest-labs": {},
|
|
1272
|
-
cerebras: {},
|
|
1273
|
-
cohere: {},
|
|
1274
|
-
"fal-ai": {},
|
|
1275
|
-
"featherless-ai": {},
|
|
1276
|
-
"fireworks-ai": {},
|
|
1277
|
-
groq: {},
|
|
1278
|
-
"hf-inference": {},
|
|
1279
|
-
hyperbolic: {},
|
|
1280
|
-
nebius: {},
|
|
1281
|
-
novita: {},
|
|
1282
|
-
nscale: {},
|
|
1283
|
-
openai: {},
|
|
1284
|
-
replicate: {},
|
|
1285
|
-
sambanova: {},
|
|
1286
|
-
together: {}
|
|
1287
|
-
};
|
|
1288
|
-
|
|
1289
|
-
// src/lib/getInferenceProviderMapping.ts
|
|
1290
|
-
var inferenceProviderMappingCache = /* @__PURE__ */ new Map();
|
|
1291
|
-
async function getInferenceProviderMapping(params, options) {
|
|
1292
|
-
if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
|
|
1293
|
-
return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
|
|
1294
|
-
}
|
|
1295
|
-
let inferenceProviderMapping;
|
|
1296
|
-
if (inferenceProviderMappingCache.has(params.modelId)) {
|
|
1297
|
-
inferenceProviderMapping = inferenceProviderMappingCache.get(params.modelId);
|
|
1298
|
-
} else {
|
|
1299
|
-
const resp = await (options?.fetch ?? fetch)(
|
|
1300
|
-
`${HF_HUB_URL}/api/models/${params.modelId}?expand[]=inferenceProviderMapping`,
|
|
1301
|
-
{
|
|
1302
|
-
headers: params.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${params.accessToken}` } : {}
|
|
1303
|
-
}
|
|
1304
|
-
);
|
|
1305
|
-
if (resp.status === 404) {
|
|
1306
|
-
throw new Error(`Model ${params.modelId} does not exist`);
|
|
1307
|
-
}
|
|
1308
|
-
inferenceProviderMapping = await resp.json().then((json) => json.inferenceProviderMapping).catch(() => null);
|
|
1309
|
-
}
|
|
1310
|
-
if (!inferenceProviderMapping) {
|
|
1311
|
-
throw new Error(`We have not been able to find inference provider information for model ${params.modelId}.`);
|
|
1312
|
-
}
|
|
1313
|
-
const providerMapping = inferenceProviderMapping[params.provider];
|
|
1314
|
-
if (providerMapping) {
|
|
1315
|
-
const equivalentTasks = params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task) ? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS : [params.task];
|
|
1316
|
-
if (!typedInclude(equivalentTasks, providerMapping.task)) {
|
|
1317
|
-
throw new Error(
|
|
1318
|
-
`Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
|
|
1319
|
-
);
|
|
1320
|
-
}
|
|
1321
|
-
if (providerMapping.status === "staging") {
|
|
1322
|
-
console.warn(
|
|
1323
|
-
`Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
|
|
1324
|
-
);
|
|
1325
|
-
}
|
|
1326
|
-
return { ...providerMapping, hfModelId: params.modelId };
|
|
1327
|
-
}
|
|
1328
|
-
return null;
|
|
1329
|
-
}
|
|
1403
|
+
var version = "3.11.0";
|
|
1330
1404
|
|
|
1331
1405
|
// src/lib/makeRequestOptions.ts
|
|
1332
1406
|
var tasks = null;
|
|
1333
1407
|
async function makeRequestOptions(args, providerHelper, options) {
|
|
1334
|
-
const {
|
|
1335
|
-
const provider =
|
|
1408
|
+
const { model: maybeModel } = args;
|
|
1409
|
+
const provider = providerHelper.provider;
|
|
1336
1410
|
const { task } = options ?? {};
|
|
1337
1411
|
if (args.endpointUrl && provider !== "hf-inference") {
|
|
1338
1412
|
throw new Error(`Cannot use endpointUrl with a third-party provider.`);
|
|
@@ -1387,7 +1461,7 @@ async function makeRequestOptions(args, providerHelper, options) {
|
|
|
1387
1461
|
}
|
|
1388
1462
|
function makeRequestOptionsFromResolvedModel(resolvedModel, providerHelper, args, mapping, options) {
|
|
1389
1463
|
const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
|
|
1390
|
-
const provider =
|
|
1464
|
+
const provider = providerHelper.provider;
|
|
1391
1465
|
const { includeCredentials, task, signal, billTo } = options ?? {};
|
|
1392
1466
|
const authMethod = (() => {
|
|
1393
1467
|
if (providerHelper.clientSideRoutingOnly) {
|
|
@@ -1678,7 +1752,8 @@ async function request(args, options) {
|
|
|
1678
1752
|
console.warn(
|
|
1679
1753
|
"The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
|
|
1680
1754
|
);
|
|
1681
|
-
const
|
|
1755
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1756
|
+
const providerHelper = getProviderHelper(provider, options?.task);
|
|
1682
1757
|
const result = await innerRequest(args, providerHelper, options);
|
|
1683
1758
|
return result.data;
|
|
1684
1759
|
}
|
|
@@ -1688,7 +1763,8 @@ async function* streamingRequest(args, options) {
|
|
|
1688
1763
|
console.warn(
|
|
1689
1764
|
"The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
|
|
1690
1765
|
);
|
|
1691
|
-
const
|
|
1766
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1767
|
+
const providerHelper = getProviderHelper(provider, options?.task);
|
|
1692
1768
|
yield* innerStreamingRequest(args, providerHelper, options);
|
|
1693
1769
|
}
|
|
1694
1770
|
|
|
@@ -1702,7 +1778,8 @@ function preparePayload(args) {
|
|
|
1702
1778
|
|
|
1703
1779
|
// src/tasks/audio/audioClassification.ts
|
|
1704
1780
|
async function audioClassification(args, options) {
|
|
1705
|
-
const
|
|
1781
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1782
|
+
const providerHelper = getProviderHelper(provider, "audio-classification");
|
|
1706
1783
|
const payload = preparePayload(args);
|
|
1707
1784
|
const { data: res } = await innerRequest(payload, providerHelper, {
|
|
1708
1785
|
...options,
|
|
@@ -1713,7 +1790,9 @@ async function audioClassification(args, options) {
|
|
|
1713
1790
|
|
|
1714
1791
|
// src/tasks/audio/audioToAudio.ts
|
|
1715
1792
|
async function audioToAudio(args, options) {
|
|
1716
|
-
const
|
|
1793
|
+
const model = "inputs" in args ? args.model : void 0;
|
|
1794
|
+
const provider = await resolveProvider(args.provider, model);
|
|
1795
|
+
const providerHelper = getProviderHelper(provider, "audio-to-audio");
|
|
1717
1796
|
const payload = preparePayload(args);
|
|
1718
1797
|
const { data: res } = await innerRequest(payload, providerHelper, {
|
|
1719
1798
|
...options,
|
|
@@ -1737,7 +1816,8 @@ function base64FromBytes(arr) {
|
|
|
1737
1816
|
|
|
1738
1817
|
// src/tasks/audio/automaticSpeechRecognition.ts
|
|
1739
1818
|
async function automaticSpeechRecognition(args, options) {
|
|
1740
|
-
const
|
|
1819
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1820
|
+
const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
|
|
1741
1821
|
const payload = await buildPayload(args);
|
|
1742
1822
|
const { data: res } = await innerRequest(payload, providerHelper, {
|
|
1743
1823
|
...options,
|
|
@@ -1777,7 +1857,7 @@ async function buildPayload(args) {
|
|
|
1777
1857
|
|
|
1778
1858
|
// src/tasks/audio/textToSpeech.ts
|
|
1779
1859
|
async function textToSpeech(args, options) {
|
|
1780
|
-
const provider = args.provider
|
|
1860
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1781
1861
|
const providerHelper = getProviderHelper(provider, "text-to-speech");
|
|
1782
1862
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
1783
1863
|
...options,
|
|
@@ -1793,7 +1873,8 @@ function preparePayload2(args) {
|
|
|
1793
1873
|
|
|
1794
1874
|
// src/tasks/cv/imageClassification.ts
|
|
1795
1875
|
async function imageClassification(args, options) {
|
|
1796
|
-
const
|
|
1876
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1877
|
+
const providerHelper = getProviderHelper(provider, "image-classification");
|
|
1797
1878
|
const payload = preparePayload2(args);
|
|
1798
1879
|
const { data: res } = await innerRequest(payload, providerHelper, {
|
|
1799
1880
|
...options,
|
|
@@ -1804,7 +1885,8 @@ async function imageClassification(args, options) {
|
|
|
1804
1885
|
|
|
1805
1886
|
// src/tasks/cv/imageSegmentation.ts
|
|
1806
1887
|
async function imageSegmentation(args, options) {
|
|
1807
|
-
const
|
|
1888
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1889
|
+
const providerHelper = getProviderHelper(provider, "image-segmentation");
|
|
1808
1890
|
const payload = preparePayload2(args);
|
|
1809
1891
|
const { data: res } = await innerRequest(payload, providerHelper, {
|
|
1810
1892
|
...options,
|
|
@@ -1815,7 +1897,8 @@ async function imageSegmentation(args, options) {
|
|
|
1815
1897
|
|
|
1816
1898
|
// src/tasks/cv/imageToImage.ts
|
|
1817
1899
|
async function imageToImage(args, options) {
|
|
1818
|
-
const
|
|
1900
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1901
|
+
const providerHelper = getProviderHelper(provider, "image-to-image");
|
|
1819
1902
|
let reqArgs;
|
|
1820
1903
|
if (!args.parameters) {
|
|
1821
1904
|
reqArgs = {
|
|
@@ -1840,7 +1923,8 @@ async function imageToImage(args, options) {
|
|
|
1840
1923
|
|
|
1841
1924
|
// src/tasks/cv/imageToText.ts
|
|
1842
1925
|
async function imageToText(args, options) {
|
|
1843
|
-
const
|
|
1926
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1927
|
+
const providerHelper = getProviderHelper(provider, "image-to-text");
|
|
1844
1928
|
const payload = preparePayload2(args);
|
|
1845
1929
|
const { data: res } = await innerRequest(payload, providerHelper, {
|
|
1846
1930
|
...options,
|
|
@@ -1851,7 +1935,8 @@ async function imageToText(args, options) {
|
|
|
1851
1935
|
|
|
1852
1936
|
// src/tasks/cv/objectDetection.ts
|
|
1853
1937
|
async function objectDetection(args, options) {
|
|
1854
|
-
const
|
|
1938
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1939
|
+
const providerHelper = getProviderHelper(provider, "object-detection");
|
|
1855
1940
|
const payload = preparePayload2(args);
|
|
1856
1941
|
const { data: res } = await innerRequest(payload, providerHelper, {
|
|
1857
1942
|
...options,
|
|
@@ -1862,7 +1947,7 @@ async function objectDetection(args, options) {
|
|
|
1862
1947
|
|
|
1863
1948
|
// src/tasks/cv/textToImage.ts
|
|
1864
1949
|
async function textToImage(args, options) {
|
|
1865
|
-
const provider = args.provider
|
|
1950
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1866
1951
|
const providerHelper = getProviderHelper(provider, "text-to-image");
|
|
1867
1952
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
1868
1953
|
...options,
|
|
@@ -1874,7 +1959,7 @@ async function textToImage(args, options) {
|
|
|
1874
1959
|
|
|
1875
1960
|
// src/tasks/cv/textToVideo.ts
|
|
1876
1961
|
async function textToVideo(args, options) {
|
|
1877
|
-
const provider = args.provider
|
|
1962
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
1878
1963
|
const providerHelper = getProviderHelper(provider, "text-to-video");
|
|
1879
1964
|
const { data: response } = await innerRequest(
|
|
1880
1965
|
args,
|
|
@@ -1911,7 +1996,8 @@ async function preparePayload3(args) {
|
|
|
1911
1996
|
}
|
|
1912
1997
|
}
|
|
1913
1998
|
async function zeroShotImageClassification(args, options) {
|
|
1914
|
-
const
|
|
1999
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2000
|
+
const providerHelper = getProviderHelper(provider, "zero-shot-image-classification");
|
|
1915
2001
|
const payload = await preparePayload3(args);
|
|
1916
2002
|
const { data: res } = await innerRequest(payload, providerHelper, {
|
|
1917
2003
|
...options,
|
|
@@ -1922,7 +2008,8 @@ async function zeroShotImageClassification(args, options) {
|
|
|
1922
2008
|
|
|
1923
2009
|
// src/tasks/nlp/chatCompletion.ts
|
|
1924
2010
|
async function chatCompletion(args, options) {
|
|
1925
|
-
const
|
|
2011
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2012
|
+
const providerHelper = getProviderHelper(provider, "conversational");
|
|
1926
2013
|
const { data: response } = await innerRequest(args, providerHelper, {
|
|
1927
2014
|
...options,
|
|
1928
2015
|
task: "conversational"
|
|
@@ -1932,7 +2019,8 @@ async function chatCompletion(args, options) {
|
|
|
1932
2019
|
|
|
1933
2020
|
// src/tasks/nlp/chatCompletionStream.ts
|
|
1934
2021
|
async function* chatCompletionStream(args, options) {
|
|
1935
|
-
const
|
|
2022
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2023
|
+
const providerHelper = getProviderHelper(provider, "conversational");
|
|
1936
2024
|
yield* innerStreamingRequest(args, providerHelper, {
|
|
1937
2025
|
...options,
|
|
1938
2026
|
task: "conversational"
|
|
@@ -1941,7 +2029,8 @@ async function* chatCompletionStream(args, options) {
|
|
|
1941
2029
|
|
|
1942
2030
|
// src/tasks/nlp/featureExtraction.ts
|
|
1943
2031
|
async function featureExtraction(args, options) {
|
|
1944
|
-
const
|
|
2032
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2033
|
+
const providerHelper = getProviderHelper(provider, "feature-extraction");
|
|
1945
2034
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
1946
2035
|
...options,
|
|
1947
2036
|
task: "feature-extraction"
|
|
@@ -1951,7 +2040,8 @@ async function featureExtraction(args, options) {
|
|
|
1951
2040
|
|
|
1952
2041
|
// src/tasks/nlp/fillMask.ts
|
|
1953
2042
|
async function fillMask(args, options) {
|
|
1954
|
-
const
|
|
2043
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2044
|
+
const providerHelper = getProviderHelper(provider, "fill-mask");
|
|
1955
2045
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
1956
2046
|
...options,
|
|
1957
2047
|
task: "fill-mask"
|
|
@@ -1961,7 +2051,8 @@ async function fillMask(args, options) {
|
|
|
1961
2051
|
|
|
1962
2052
|
// src/tasks/nlp/questionAnswering.ts
|
|
1963
2053
|
async function questionAnswering(args, options) {
|
|
1964
|
-
const
|
|
2054
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2055
|
+
const providerHelper = getProviderHelper(provider, "question-answering");
|
|
1965
2056
|
const { data: res } = await innerRequest(
|
|
1966
2057
|
args,
|
|
1967
2058
|
providerHelper,
|
|
@@ -1975,7 +2066,8 @@ async function questionAnswering(args, options) {
|
|
|
1975
2066
|
|
|
1976
2067
|
// src/tasks/nlp/sentenceSimilarity.ts
|
|
1977
2068
|
async function sentenceSimilarity(args, options) {
|
|
1978
|
-
const
|
|
2069
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2070
|
+
const providerHelper = getProviderHelper(provider, "sentence-similarity");
|
|
1979
2071
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
1980
2072
|
...options,
|
|
1981
2073
|
task: "sentence-similarity"
|
|
@@ -1985,7 +2077,8 @@ async function sentenceSimilarity(args, options) {
|
|
|
1985
2077
|
|
|
1986
2078
|
// src/tasks/nlp/summarization.ts
|
|
1987
2079
|
async function summarization(args, options) {
|
|
1988
|
-
const
|
|
2080
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2081
|
+
const providerHelper = getProviderHelper(provider, "summarization");
|
|
1989
2082
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
1990
2083
|
...options,
|
|
1991
2084
|
task: "summarization"
|
|
@@ -1995,7 +2088,8 @@ async function summarization(args, options) {
|
|
|
1995
2088
|
|
|
1996
2089
|
// src/tasks/nlp/tableQuestionAnswering.ts
|
|
1997
2090
|
async function tableQuestionAnswering(args, options) {
|
|
1998
|
-
const
|
|
2091
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2092
|
+
const providerHelper = getProviderHelper(provider, "table-question-answering");
|
|
1999
2093
|
const { data: res } = await innerRequest(
|
|
2000
2094
|
args,
|
|
2001
2095
|
providerHelper,
|
|
@@ -2009,7 +2103,8 @@ async function tableQuestionAnswering(args, options) {
|
|
|
2009
2103
|
|
|
2010
2104
|
// src/tasks/nlp/textClassification.ts
|
|
2011
2105
|
async function textClassification(args, options) {
|
|
2012
|
-
const
|
|
2106
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2107
|
+
const providerHelper = getProviderHelper(provider, "text-classification");
|
|
2013
2108
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
2014
2109
|
...options,
|
|
2015
2110
|
task: "text-classification"
|
|
@@ -2019,7 +2114,8 @@ async function textClassification(args, options) {
|
|
|
2019
2114
|
|
|
2020
2115
|
// src/tasks/nlp/textGeneration.ts
|
|
2021
2116
|
async function textGeneration(args, options) {
|
|
2022
|
-
const
|
|
2117
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2118
|
+
const providerHelper = getProviderHelper(provider, "text-generation");
|
|
2023
2119
|
const { data: response } = await innerRequest(args, providerHelper, {
|
|
2024
2120
|
...options,
|
|
2025
2121
|
task: "text-generation"
|
|
@@ -2029,7 +2125,8 @@ async function textGeneration(args, options) {
|
|
|
2029
2125
|
|
|
2030
2126
|
// src/tasks/nlp/textGenerationStream.ts
|
|
2031
2127
|
async function* textGenerationStream(args, options) {
|
|
2032
|
-
const
|
|
2128
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2129
|
+
const providerHelper = getProviderHelper(provider, "text-generation");
|
|
2033
2130
|
yield* innerStreamingRequest(args, providerHelper, {
|
|
2034
2131
|
...options,
|
|
2035
2132
|
task: "text-generation"
|
|
@@ -2038,7 +2135,8 @@ async function* textGenerationStream(args, options) {
|
|
|
2038
2135
|
|
|
2039
2136
|
// src/tasks/nlp/tokenClassification.ts
|
|
2040
2137
|
async function tokenClassification(args, options) {
|
|
2041
|
-
const
|
|
2138
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2139
|
+
const providerHelper = getProviderHelper(provider, "token-classification");
|
|
2042
2140
|
const { data: res } = await innerRequest(
|
|
2043
2141
|
args,
|
|
2044
2142
|
providerHelper,
|
|
@@ -2052,7 +2150,8 @@ async function tokenClassification(args, options) {
|
|
|
2052
2150
|
|
|
2053
2151
|
// src/tasks/nlp/translation.ts
|
|
2054
2152
|
async function translation(args, options) {
|
|
2055
|
-
const
|
|
2153
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2154
|
+
const providerHelper = getProviderHelper(provider, "translation");
|
|
2056
2155
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
2057
2156
|
...options,
|
|
2058
2157
|
task: "translation"
|
|
@@ -2062,7 +2161,8 @@ async function translation(args, options) {
|
|
|
2062
2161
|
|
|
2063
2162
|
// src/tasks/nlp/zeroShotClassification.ts
|
|
2064
2163
|
async function zeroShotClassification(args, options) {
|
|
2065
|
-
const
|
|
2164
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2165
|
+
const providerHelper = getProviderHelper(provider, "zero-shot-classification");
|
|
2066
2166
|
const { data: res } = await innerRequest(
|
|
2067
2167
|
args,
|
|
2068
2168
|
providerHelper,
|
|
@@ -2076,7 +2176,8 @@ async function zeroShotClassification(args, options) {
|
|
|
2076
2176
|
|
|
2077
2177
|
// src/tasks/multimodal/documentQuestionAnswering.ts
|
|
2078
2178
|
async function documentQuestionAnswering(args, options) {
|
|
2079
|
-
const
|
|
2179
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2180
|
+
const providerHelper = getProviderHelper(provider, "document-question-answering");
|
|
2080
2181
|
const reqArgs = {
|
|
2081
2182
|
...args,
|
|
2082
2183
|
inputs: {
|
|
@@ -2098,7 +2199,8 @@ async function documentQuestionAnswering(args, options) {
|
|
|
2098
2199
|
|
|
2099
2200
|
// src/tasks/multimodal/visualQuestionAnswering.ts
|
|
2100
2201
|
async function visualQuestionAnswering(args, options) {
|
|
2101
|
-
const
|
|
2202
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2203
|
+
const providerHelper = getProviderHelper(provider, "visual-question-answering");
|
|
2102
2204
|
const reqArgs = {
|
|
2103
2205
|
...args,
|
|
2104
2206
|
inputs: {
|
|
@@ -2116,7 +2218,8 @@ async function visualQuestionAnswering(args, options) {
|
|
|
2116
2218
|
|
|
2117
2219
|
// src/tasks/tabular/tabularClassification.ts
|
|
2118
2220
|
async function tabularClassification(args, options) {
|
|
2119
|
-
const
|
|
2221
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2222
|
+
const providerHelper = getProviderHelper(provider, "tabular-classification");
|
|
2120
2223
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
2121
2224
|
...options,
|
|
2122
2225
|
task: "tabular-classification"
|
|
@@ -2126,7 +2229,8 @@ async function tabularClassification(args, options) {
|
|
|
2126
2229
|
|
|
2127
2230
|
// src/tasks/tabular/tabularRegression.ts
|
|
2128
2231
|
async function tabularRegression(args, options) {
|
|
2129
|
-
const
|
|
2232
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
2233
|
+
const providerHelper = getProviderHelper(provider, "tabular-regression");
|
|
2130
2234
|
const { data: res } = await innerRequest(args, providerHelper, {
|
|
2131
2235
|
...options,
|
|
2132
2236
|
task: "tabular-regression"
|
|
@@ -2134,6 +2238,11 @@ async function tabularRegression(args, options) {
|
|
|
2134
2238
|
return providerHelper.getResponse(res);
|
|
2135
2239
|
}
|
|
2136
2240
|
|
|
2241
|
+
// src/utils/typedEntries.ts
|
|
2242
|
+
function typedEntries(obj) {
|
|
2243
|
+
return Object.entries(obj);
|
|
2244
|
+
}
|
|
2245
|
+
|
|
2137
2246
|
// src/InferenceClient.ts
|
|
2138
2247
|
var InferenceClient = class {
|
|
2139
2248
|
accessToken;
|
|
@@ -2141,40 +2250,36 @@ var InferenceClient = class {
|
|
|
2141
2250
|
constructor(accessToken = "", defaultOptions = {}) {
|
|
2142
2251
|
this.accessToken = accessToken;
|
|
2143
2252
|
this.defaultOptions = defaultOptions;
|
|
2144
|
-
for (const [name2, fn] of
|
|
2253
|
+
for (const [name2, fn] of typedEntries(tasks_exports)) {
|
|
2145
2254
|
Object.defineProperty(this, name2, {
|
|
2146
2255
|
enumerable: false,
|
|
2147
2256
|
value: (params, options) => (
|
|
2148
2257
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
2149
|
-
fn(
|
|
2258
|
+
fn(
|
|
2259
|
+
/// ^ The cast of fn to any is necessary, otherwise TS can't compile because the generated union type is too complex
|
|
2260
|
+
{ endpointUrl: defaultOptions.endpointUrl, accessToken, ...params },
|
|
2261
|
+
{
|
|
2262
|
+
...omit(defaultOptions, ["endpointUrl"]),
|
|
2263
|
+
...options
|
|
2264
|
+
}
|
|
2265
|
+
)
|
|
2150
2266
|
)
|
|
2151
2267
|
});
|
|
2152
2268
|
}
|
|
2153
2269
|
}
|
|
2154
2270
|
/**
|
|
2155
|
-
* Returns
|
|
2271
|
+
* Returns a new instance of InferenceClient tied to a specified endpoint.
|
|
2272
|
+
*
|
|
2273
|
+
* For backward compatibility mostly.
|
|
2156
2274
|
*/
|
|
2157
2275
|
endpoint(endpointUrl) {
|
|
2158
|
-
return new
|
|
2159
|
-
}
|
|
2160
|
-
};
|
|
2161
|
-
var InferenceClientEndpoint = class {
|
|
2162
|
-
constructor(endpointUrl, accessToken = "", defaultOptions = {}) {
|
|
2163
|
-
accessToken;
|
|
2164
|
-
defaultOptions;
|
|
2165
|
-
for (const [name2, fn] of Object.entries(tasks_exports)) {
|
|
2166
|
-
Object.defineProperty(this, name2, {
|
|
2167
|
-
enumerable: false,
|
|
2168
|
-
value: (params, options) => (
|
|
2169
|
-
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
2170
|
-
fn({ ...params, accessToken, endpointUrl }, { ...defaultOptions, ...options })
|
|
2171
|
-
)
|
|
2172
|
-
});
|
|
2173
|
-
}
|
|
2276
|
+
return new InferenceClient(this.accessToken, { ...this.defaultOptions, endpointUrl });
|
|
2174
2277
|
}
|
|
2175
2278
|
};
|
|
2176
2279
|
var HfInference = class extends InferenceClient {
|
|
2177
2280
|
};
|
|
2281
|
+
var InferenceClientEndpoint = class extends InferenceClient {
|
|
2282
|
+
};
|
|
2178
2283
|
|
|
2179
2284
|
// src/types.ts
|
|
2180
2285
|
var INFERENCE_PROVIDERS = [
|
|
@@ -2191,10 +2296,12 @@ var INFERENCE_PROVIDERS = [
|
|
|
2191
2296
|
"novita",
|
|
2192
2297
|
"nscale",
|
|
2193
2298
|
"openai",
|
|
2299
|
+
"ovhcloud",
|
|
2194
2300
|
"replicate",
|
|
2195
2301
|
"sambanova",
|
|
2196
2302
|
"together"
|
|
2197
2303
|
];
|
|
2304
|
+
var PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"];
|
|
2198
2305
|
|
|
2199
2306
|
// src/snippets/index.ts
|
|
2200
2307
|
var snippets_exports = {};
|
|
@@ -2218,6 +2325,7 @@ var templates = {
|
|
|
2218
2325
|
"basicImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "image/jpeg",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
|
|
2219
2326
|
"textToAudio": '{% if model.library_name == "transformers" %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n // Returns a byte object of the Audio wavform. Use it directly!\n});\n{% else %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});\n{% endif %} ',
|
|
2220
2327
|
"textToImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\n\nquery({ {{ providerInputs.asTsString }} }).then((response) => {\n // Use image\n});',
|
|
2328
|
+
"textToSpeech": '{% if model.library_name == "transformers" %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\nquery({ text: {{ inputs.asObj.inputs }} }).then((response) => {\n // Returns a byte object of the Audio wavform. Use it directly!\n});\n{% else %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ text: {{ inputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});\n{% endif %} ',
|
|
2221
2329
|
"zeroShotClassification": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({\n inputs: {{ providerInputs.asObj.inputs }},\n parameters: { candidate_labels: ["refund", "legal", "faq"] }\n}).then((response) => {\n console.log(JSON.stringify(response));\n});'
|
|
2222
2330
|
},
|
|
2223
2331
|
"huggingface.js": {
|
|
@@ -2239,11 +2347,23 @@ const image = await client.textToImage({
|
|
|
2239
2347
|
billTo: "{{ billTo }}",
|
|
2240
2348
|
}{% endif %});
|
|
2241
2349
|
/// Use the generated image (it's a Blob)`,
|
|
2350
|
+
"textToSpeech": `import { InferenceClient } from "@huggingface/inference";
|
|
2351
|
+
|
|
2352
|
+
const client = new InferenceClient("{{ accessToken }}");
|
|
2353
|
+
|
|
2354
|
+
const audio = await client.textToSpeech({
|
|
2355
|
+
provider: "{{ provider }}",
|
|
2356
|
+
model: "{{ model.id }}",
|
|
2357
|
+
inputs: {{ inputs.asObj.inputs }},
|
|
2358
|
+
}{% if billTo %}, {
|
|
2359
|
+
billTo: "{{ billTo }}",
|
|
2360
|
+
}{% endif %});
|
|
2361
|
+
// Use the generated audio (it's a Blob)`,
|
|
2242
2362
|
"textToVideo": `import { InferenceClient } from "@huggingface/inference";
|
|
2243
2363
|
|
|
2244
2364
|
const client = new InferenceClient("{{ accessToken }}");
|
|
2245
2365
|
|
|
2246
|
-
const
|
|
2366
|
+
const video = await client.textToVideo({
|
|
2247
2367
|
provider: "{{ provider }}",
|
|
2248
2368
|
model: "{{ model.id }}",
|
|
2249
2369
|
inputs: {{ inputs.asObj.inputs }},
|
|
@@ -2259,7 +2379,7 @@ const image = await client.textToVideo({
|
|
|
2259
2379
|
},
|
|
2260
2380
|
"python": {
|
|
2261
2381
|
"fal_client": {
|
|
2262
|
-
"textToImage": '{% if provider == "fal-ai" %}\nimport fal_client\n\nresult = fal_client.subscribe(\n "{{ providerModelId }}",\n arguments={\n "prompt": {{ inputs.asObj.inputs }},\n },\n)\nprint(result)\n{% endif %} '
|
|
2382
|
+
"textToImage": '{% if provider == "fal-ai" %}\nimport fal_client\n\n{% if providerInputs.asObj.loras is defined and providerInputs.asObj.loras != none %}\nresult = fal_client.subscribe(\n "{{ providerModelId }}",\n arguments={\n "prompt": {{ inputs.asObj.inputs }},\n "loras":{{ providerInputs.asObj.loras | tojson }},\n },\n)\n{% else %}\nresult = fal_client.subscribe(\n "{{ providerModelId }}",\n arguments={\n "prompt": {{ inputs.asObj.inputs }},\n },\n)\n{% endif %} \nprint(result)\n{% endif %} '
|
|
2263
2383
|
},
|
|
2264
2384
|
"huggingface_hub": {
|
|
2265
2385
|
"basic": 'result = client.{{ methodName }}(\n inputs={{ inputs.asObj.inputs }},\n model="{{ model.id }}",\n)',
|
|
@@ -2271,6 +2391,7 @@ const image = await client.textToVideo({
|
|
|
2271
2391
|
"imageToImage": '# output is a PIL.Image object\nimage = client.image_to_image(\n "{{ inputs.asObj.inputs }}",\n prompt="{{ inputs.asObj.parameters.prompt }}",\n model="{{ model.id }}",\n) ',
|
|
2272
2392
|
"importInferenceClient": 'from huggingface_hub import InferenceClient\n\nclient = InferenceClient(\n provider="{{ provider }}",\n api_key="{{ accessToken }}",\n{% if billTo %}\n bill_to="{{ billTo }}",\n{% endif %}\n)',
|
|
2273
2393
|
"textToImage": '# output is a PIL.Image object\nimage = client.text_to_image(\n {{ inputs.asObj.inputs }},\n model="{{ model.id }}",\n) ',
|
|
2394
|
+
"textToSpeech": '# audio is returned as bytes\naudio = client.text_to_speech(\n {{ inputs.asObj.inputs }},\n model="{{ model.id }}",\n) \n',
|
|
2274
2395
|
"textToVideo": 'video = client.text_to_video(\n {{ inputs.asObj.inputs }},\n model="{{ model.id }}",\n) '
|
|
2275
2396
|
},
|
|
2276
2397
|
"openai": {
|
|
@@ -2287,8 +2408,9 @@ const image = await client.textToVideo({
|
|
|
2287
2408
|
"imageToImage": 'def query(payload):\n with open(payload["inputs"], "rb") as f:\n img = f.read()\n payload["inputs"] = base64.b64encode(img).decode("utf-8")\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nimage_bytes = query({\n{{ providerInputs.asJsonString }}\n})\n\n# You can access the image with PIL.Image for example\nimport io\nfrom PIL import Image\nimage = Image.open(io.BytesIO(image_bytes)) ',
|
|
2288
2409
|
"importRequests": '{% if importBase64 %}\nimport base64\n{% endif %}\n{% if importJson %}\nimport json\n{% endif %}\nimport requests\n\nAPI_URL = "{{ fullUrl }}"\nheaders = {\n "Authorization": "{{ authorizationHeader }}",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}"\n{% endif %}\n}',
|
|
2289
2410
|
"tabular": 'def query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nresponse = query({\n "inputs": {\n "data": {{ providerInputs.asObj.inputs }}\n },\n}) ',
|
|
2290
|
-
"textToAudio": '{% if model.library_name == "transformers" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\naudio_bytes = query({\n "inputs": {{
|
|
2411
|
+
"textToAudio": '{% if model.library_name == "transformers" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\naudio_bytes = query({\n "inputs": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio_bytes)\n{% else %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\naudio, sampling_rate = query({\n "inputs": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio, rate=sampling_rate)\n{% endif %} ',
|
|
2291
2412
|
"textToImage": '{% if provider == "hf-inference" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nimage_bytes = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n})\n\n# You can access the image with PIL.Image for example\nimport io\nfrom PIL import Image\nimage = Image.open(io.BytesIO(image_bytes))\n{% endif %}',
|
|
2413
|
+
"textToSpeech": '{% if model.library_name == "transformers" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\naudio_bytes = query({\n "text": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio_bytes)\n{% else %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\naudio, sampling_rate = query({\n "text": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio, rate=sampling_rate)\n{% endif %} ',
|
|
2292
2414
|
"zeroShotClassification": 'def query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\noutput = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n "parameters": {"candidate_labels": ["refund", "legal", "faq"]},\n}) ',
|
|
2293
2415
|
"zeroShotImageClassification": 'def query(data):\n with open(data["image_path"], "rb") as f:\n img = f.read()\n payload={\n "parameters": data["parameters"],\n "inputs": base64.b64encode(img).decode("utf-8")\n }\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\noutput = query({\n "image_path": {{ providerInputs.asObj.inputs }},\n "parameters": {"candidate_labels": ["cat", "dog", "llama"]},\n}) '
|
|
2294
2416
|
}
|
|
@@ -2391,6 +2513,7 @@ var HF_JS_METHODS = {
|
|
|
2391
2513
|
"text-generation": "textGeneration",
|
|
2392
2514
|
"text2text-generation": "textGeneration",
|
|
2393
2515
|
"token-classification": "tokenClassification",
|
|
2516
|
+
"text-to-speech": "textToSpeech",
|
|
2394
2517
|
translation: "translation"
|
|
2395
2518
|
};
|
|
2396
2519
|
var snippetGenerator = (templateName, inputPreparationFn) => {
|
|
@@ -2510,7 +2633,7 @@ var prepareConversationalInput = (model, opts) => {
|
|
|
2510
2633
|
return {
|
|
2511
2634
|
messages: opts?.messages ?? getModelInputSnippet(model),
|
|
2512
2635
|
...opts?.temperature ? { temperature: opts?.temperature } : void 0,
|
|
2513
|
-
max_tokens: opts?.max_tokens
|
|
2636
|
+
...opts?.max_tokens ? { max_tokens: opts?.max_tokens } : void 0,
|
|
2514
2637
|
...opts?.top_p ? { top_p: opts?.top_p } : void 0
|
|
2515
2638
|
};
|
|
2516
2639
|
};
|
|
@@ -2537,7 +2660,7 @@ var snippets = {
|
|
|
2537
2660
|
"text-generation": snippetGenerator("basic"),
|
|
2538
2661
|
"text-to-audio": snippetGenerator("textToAudio"),
|
|
2539
2662
|
"text-to-image": snippetGenerator("textToImage"),
|
|
2540
|
-
"text-to-speech": snippetGenerator("
|
|
2663
|
+
"text-to-speech": snippetGenerator("textToSpeech"),
|
|
2541
2664
|
"text-to-video": snippetGenerator("textToVideo"),
|
|
2542
2665
|
"text2text-generation": snippetGenerator("basic"),
|
|
2543
2666
|
"token-classification": snippetGenerator("basic"),
|
|
@@ -2603,6 +2726,7 @@ export {
|
|
|
2603
2726
|
InferenceClient,
|
|
2604
2727
|
InferenceClientEndpoint,
|
|
2605
2728
|
InferenceOutputError,
|
|
2729
|
+
PROVIDERS_OR_POLICIES,
|
|
2606
2730
|
audioClassification,
|
|
2607
2731
|
audioToAudio,
|
|
2608
2732
|
automaticSpeechRecognition,
|