@huggingface/inference 3.10.0 → 3.12.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.cjs +713 -643
- package/dist/index.js +712 -643
- package/dist/src/InferenceClient.d.ts +16 -17
- package/dist/src/InferenceClient.d.ts.map +1 -1
- package/dist/src/lib/getInferenceProviderMapping.d.ts +5 -1
- package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -1
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/providerHelper.d.ts +1 -1
- package/dist/src/providers/providerHelper.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioToAudio.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
- package/dist/src/types.d.ts +6 -4
- package/dist/src/types.d.ts.map +1 -1
- package/dist/src/utils/typedEntries.d.ts +4 -0
- package/dist/src/utils/typedEntries.d.ts.map +1 -0
- package/package.json +3 -3
- package/src/InferenceClient.ts +32 -43
- package/src/lib/getInferenceProviderMapping.ts +68 -19
- package/src/lib/makeRequestOptions.ts +4 -3
- package/src/providers/hf-inference.ts +1 -1
- package/src/providers/providerHelper.ts +1 -1
- package/src/snippets/getInferenceSnippets.ts +1 -1
- package/src/tasks/audio/audioClassification.ts +3 -1
- package/src/tasks/audio/audioToAudio.ts +4 -1
- package/src/tasks/audio/automaticSpeechRecognition.ts +3 -1
- package/src/tasks/audio/textToSpeech.ts +2 -1
- package/src/tasks/custom/request.ts +3 -1
- package/src/tasks/custom/streamingRequest.ts +3 -1
- package/src/tasks/cv/imageClassification.ts +3 -1
- package/src/tasks/cv/imageSegmentation.ts +3 -1
- package/src/tasks/cv/imageToImage.ts +3 -1
- package/src/tasks/cv/imageToText.ts +3 -1
- package/src/tasks/cv/objectDetection.ts +3 -1
- package/src/tasks/cv/textToImage.ts +2 -1
- package/src/tasks/cv/textToVideo.ts +2 -1
- package/src/tasks/cv/zeroShotImageClassification.ts +3 -1
- package/src/tasks/multimodal/documentQuestionAnswering.ts +3 -1
- package/src/tasks/multimodal/visualQuestionAnswering.ts +3 -1
- package/src/tasks/nlp/chatCompletion.ts +3 -1
- package/src/tasks/nlp/chatCompletionStream.ts +3 -1
- package/src/tasks/nlp/featureExtraction.ts +3 -1
- package/src/tasks/nlp/fillMask.ts +3 -1
- package/src/tasks/nlp/questionAnswering.ts +4 -1
- package/src/tasks/nlp/sentenceSimilarity.ts +3 -1
- package/src/tasks/nlp/summarization.ts +3 -1
- package/src/tasks/nlp/tableQuestionAnswering.ts +3 -1
- package/src/tasks/nlp/textClassification.ts +3 -1
- package/src/tasks/nlp/textGeneration.ts +3 -1
- package/src/tasks/nlp/textGenerationStream.ts +3 -1
- package/src/tasks/nlp/tokenClassification.ts +3 -1
- package/src/tasks/nlp/translation.ts +3 -1
- package/src/tasks/nlp/zeroShotClassification.ts +3 -1
- package/src/tasks/tabular/tabularClassification.ts +3 -1
- package/src/tasks/tabular/tabularRegression.ts +3 -1
- package/src/types.ts +8 -4
- package/src/utils/typedEntries.ts +5 -0
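Functionally, the biggest change visible in the diff below is provider resolution: src/lib/getInferenceProviderMapping.ts gains fetchInferenceProviderMappingForModel and resolveProvider, and omitting provider now defaults to the "auto" policy (the first provider available for the model, in the order configured at https://hf.co/settings/inference-providers). A minimal sketch of what that looks like from the caller's side, assuming the client's standard chatCompletion task; the model name is an example and not taken from this diff:

import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient(process.env.HF_TOKEN);

// No `provider` given: as of 3.12.0 the library resolves it via resolveProvider(),
// logging that it defaults to "auto" and picking the first available provider.
const out = await client.chatCompletion({
  model: "Qwen/Qwen2.5-Coder-32B-Instruct", // example model only
  messages: [{ role: "user", content: "Hello" }],
});
console.log(out.choices[0].message);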
package/dist/index.cjs
CHANGED
@@ -25,6 +25,7 @@ __export(src_exports, {
 InferenceClient: () => InferenceClient,
 InferenceClientEndpoint: () => InferenceClientEndpoint,
 InferenceOutputError: () => InferenceOutputError,
+ PROVIDERS_OR_POLICIES: () => PROVIDERS_OR_POLICIES,
 audioClassification: () => audioClassification,
 audioToAudio: () => audioToAudio,
 automaticSpeechRecognition: () => automaticSpeechRecognition,
@@ -98,6 +99,38 @@ __export(tasks_exports, {
 zeroShotImageClassification: () => zeroShotImageClassification
 });

+ // src/config.ts
+ var HF_HUB_URL = "https://huggingface.co";
+ var HF_ROUTER_URL = "https://router.huggingface.co";
+ var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
+
+ // src/providers/consts.ts
+ var HARDCODED_MODEL_INFERENCE_MAPPING = {
+ /**
+ * "HF model ID" => "Model ID on Inference Provider's side"
+ *
+ * Example:
+ * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
+ */
+ "black-forest-labs": {},
+ cerebras: {},
+ cohere: {},
+ "fal-ai": {},
+ "featherless-ai": {},
+ "fireworks-ai": {},
+ groq: {},
+ "hf-inference": {},
+ hyperbolic: {},
+ nebius: {},
+ novita: {},
+ nscale: {},
+ openai: {},
+ ovhcloud: {},
+ replicate: {},
+ sambanova: {},
+ together: {}
+ };
+
 // src/lib/InferenceOutputError.ts
 var InferenceOutputError = class extends TypeError {
 constructor(message) {
@@ -108,42 +141,6 @@ var InferenceOutputError = class extends TypeError {
 }
 };

- // src/utils/delay.ts
- function delay(ms) {
- return new Promise((resolve) => {
- setTimeout(() => resolve(), ms);
- });
- }
-
- // src/utils/pick.ts
- function pick(o, props) {
- return Object.assign(
- {},
- ...props.map((prop) => {
- if (o[prop] !== void 0) {
- return { [prop]: o[prop] };
- }
- })
- );
- }
-
- // src/utils/typedInclude.ts
- function typedInclude(arr, v) {
- return arr.includes(v);
- }
-
- // src/utils/omit.ts
- function omit(o, props) {
- const propsArr = Array.isArray(props) ? props : [props];
- const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
- return pick(o, letsKeep);
- }
-
- // src/config.ts
- var HF_HUB_URL = "https://huggingface.co";
- var HF_ROUTER_URL = "https://router.huggingface.co";
- var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
-
 // src/utils/toArray.ts
 function toArray(obj) {
 if (Array.isArray(obj)) {
@@ -238,627 +235,736 @@ var BaseTextGenerationTask = class extends TaskProviderHelper {
 }
 };

- // src/providers/
- var
- var
+ // src/providers/hf-inference.ts
+ var EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = ["feature-extraction", "sentence-similarity"];
+ var HFInferenceTask = class extends TaskProviderHelper {
 constructor() {
- super("
+ super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
 }
 preparePayload(params) {
- return
- ...omit(params.args, ["inputs", "parameters"]),
- ...params.args.parameters,
- prompt: params.args.inputs
- };
+ return params.args;
 }
- };
- if (!binary) {
- headers["Content-Type"] = "application/json";
+ makeUrl(params) {
+ if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
+ return params.model;
 }
- return
+ return super.makeUrl(params);
 }
 makeRoute(params) {
- if (
+ if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
+ return `models/${params.model}/pipeline/${params.task}`;
 }
- return
+ return `models/${params.model}`;
+ }
+ async getResponse(response) {
+ return response;
 }
+ };
+ var HFInferenceTextToImageTask = class extends HFInferenceTask {
 async getResponse(response, url, headers, outputType) {
+ if (!response) {
+ throw new InferenceOutputError("response is undefined");
+ }
+ if (typeof response == "object") {
+ if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
+ const base64Data = response.data[0].b64_json;
+ if (outputType === "url") {
+ return `data:image/jpeg;base64,${base64Data}`;
+ }
+ const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
+ return await base64Response.blob();
 }
- if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
+ if ("output" in response && Array.isArray(response.output)) {
 if (outputType === "url") {
- return
+ return response.output[0];
 }
- const
+ const urlResponse = await fetch(response.output[0]);
+ const blob = await urlResponse.blob();
+ return blob;
 }
 }
+ if (response instanceof Blob) {
+ if (outputType === "url") {
+ const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
+ return `data:image/jpeg;base64,${b64}`;
+ }
+ return response;
+ }
+ throw new InferenceOutputError("Expected a Blob ");
 }
 };
+ var HFInferenceConversationalTask = class extends HFInferenceTask {
+ makeUrl(params) {
+ let url;
+ if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
+ url = params.model.trim();
+ } else {
+ url = `${this.makeBaseUrl(params)}/models/${params.model}`;
+ }
+ url = url.replace(/\/+$/, "");
+ if (url.endsWith("/v1")) {
+ url += "/chat/completions";
+ } else if (!url.endsWith("/chat/completions")) {
+ url += "/v1/chat/completions";
+ }
+ return url;
+ }
+ preparePayload(params) {
+ return {
+ ...params.args,
+ model: params.model
+ };
+ }
+ async getResponse(response) {
+ return response;
 }
 };
+ var HFInferenceTextGenerationTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ const res = toArray(response);
+ if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) {
+ return res?.[0];
+ }
+ throw new InferenceOutputError("Expected Array<{generated_text: string}>");
 }
+ };
+ var HFInferenceAudioClassificationTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ if (Array.isArray(response) && response.every(
+ (x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
+ )) {
+ return response;
+ }
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
 }
 };
- return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
- }
- // src/providers/fal-ai.ts
- var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
- var FalAITask = class extends TaskProviderHelper {
- constructor(url) {
- super("fal-ai", url || "https://fal.run");
+ var HFInferenceAutomaticSpeechRecognitionTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ return response;
 }
+ };
+ var HFInferenceAudioToAudioTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ if (!Array.isArray(response)) {
+ throw new InferenceOutputError("Expected Array");
+ }
+ if (!response.every((elem) => {
+ return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
+ })) {
+ throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
+ }
+ return response;
 }
+ };
+ var HFInferenceDocumentQuestionAnsweringTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ if (Array.isArray(response) && response.every(
+ (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
+ )) {
+ return response[0];
+ }
+ throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
 }
+ };
+ var HFInferenceFeatureExtractionTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
+ if (curDepth > maxDepth)
+ return false;
+ if (arr.every((x) => Array.isArray(x))) {
+ return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
+ } else {
+ return arr.every((x) => typeof x === "number");
+ }
 };
- if (
+ if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) {
+ return response;
 }
+ throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
 }
 };
- preparePayload(params) {
- const payload = {
- ...omit(params.args, ["inputs", "parameters"]),
- ...params.args.parameters,
- sync_mode: true,
- prompt: params.args.inputs
- };
- if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
- payload.loras = [
- {
- path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
- scale: 1
- }
- ];
- if (params.mapping.providerId === "fal-ai/lora") {
- payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0";
- }
+ var HFInferenceImageClassificationTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
+ return response;
 }
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
 }
- const urlResponse = await fetch(response.images[0].url);
- return await urlResponse.blob();
+ };
+ var HFInferenceImageSegmentationTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number")) {
+ return response;
 }
- throw new InferenceOutputError("Expected
+ throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
 }
 };
- var
- makeRoute(params) {
- if (params.authMethod !== "provider-key") {
- return `/${params.model}?_subdomain=queue`;
+ var HFInferenceImageToTextTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ if (typeof response?.generated_text !== "string") {
+ throw new InferenceOutputError("Expected {generated_text: string}");
 }
- return
+ return response;
 }
- }
+ };
+ var HFInferenceImageToImageTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ if (response instanceof Blob) {
+ return response;
+ }
+ throw new InferenceOutputError("Expected Blob");
 }
+ };
+ var HFInferenceObjectDetectionTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ if (Array.isArray(response) && response.every(
+ (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
+ )) {
+ return response;
 }
+ throw new InferenceOutputError(
+ "Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
+ );
+ }
+ };
+ var HFInferenceZeroShotImageClassificationTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
+ return response;
 }
- const
- await delay(500);
- const statusResponse = await fetch(statusUrl, { headers });
- if (!statusResponse.ok) {
- throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
- }
- try {
- status = (await statusResponse.json()).status;
- } catch (error) {
- throw new InferenceOutputError("Failed to parse status response from fal-ai API");
- }
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
+ }
+ };
+ var HFInferenceTextClassificationTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ const output = response?.[0];
+ if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
+ return output;
 }
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
+ }
+ };
+ var HFInferenceQuestionAnsweringTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ if (Array.isArray(response) ? response.every(
+ (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
+ ) : typeof response === "object" && !!response && typeof response.answer === "string" && typeof response.end === "number" && typeof response.score === "number" && typeof response.start === "number") {
+ return Array.isArray(response) ? response[0] : response;
 }
- )
+ throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
+ }
+ };
+ var HFInferenceFillMaskTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ if (Array.isArray(response) && response.every(
+ (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
+ )) {
+ return response;
 }
+ throw new InferenceOutputError(
+ "Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
+ );
 }
 };
- var
+ var HFInferenceZeroShotClassificationTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ if (Array.isArray(response) && response.every(
+ (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
+ )) {
+ return response;
+ }
+ throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
 }
+ };
+ var HFInferenceSentenceSimilarityTask = class extends HFInferenceTask {
 async getResponse(response) {
- throw new InferenceOutputError(
- `Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
- );
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
+ return response;
 }
+ throw new InferenceOutputError("Expected Array<number>");
 }
 };
- var
- return
- text: params.args.inputs
- };
+ var HFInferenceTableQuestionAnsweringTask = class extends HFInferenceTask {
+ static validate(elem) {
+ return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
+ (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
+ );
 }
 async getResponse(response) {
- throw new InferenceOutputError(
- `Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
- );
+ if (Array.isArray(response) && Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response)) {
+ return Array.isArray(response) ? response[0] : response;
 }
+ throw new InferenceOutputError(
+ "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
+ );
+ }
+ };
+ var HFInferenceTokenClassificationTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ if (Array.isArray(response) && response.every(
+ (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
+ )) {
+ return response;
 }
+ throw new InferenceOutputError(
+ "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
+ );
 }
 };
+ var HFInferenceTranslationTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) {
+ return response?.length === 1 ? response?.[0] : response;
+ }
+ throw new InferenceOutputError("Expected Array<{translation_text: string}>");
 }
 };
- var
+ var HFInferenceSummarizationTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) {
+ return response?.[0];
+ }
+ throw new InferenceOutputError("Expected Array<{summary_text: string}>");
 }
- model: params.model,
- prompt: params.args.inputs
- };
+ };
+ var HFInferenceTextToSpeechTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ return response;
 }
+ };
+ var HFInferenceTabularClassificationTask = class extends HFInferenceTask {
 async getResponse(response) {
- if (
- return {
- generated_text: completion.text
- };
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
+ return response;
 }
- throw new InferenceOutputError("Expected
+ throw new InferenceOutputError("Expected Array<number>");
 }
 };
+ var HFInferenceVisualQuestionAnsweringTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ if (Array.isArray(response) && response.every(
+ (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
+ )) {
+ return response[0];
+ }
+ throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
 }
+ };
+ var HFInferenceTabularRegressionTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
+ return response;
+ }
+ throw new InferenceOutputError("Expected Array<number>");
+ }
+ };
+ var HFInferenceTextToAudioTask = class extends HFInferenceTask {
+ async getResponse(response) {
+ return response;
 }
 };

- // src/
+ // src/utils/typedInclude.ts
+ function typedInclude(arr, v) {
+ return arr.includes(v);
+ }
+
+ // src/lib/getInferenceProviderMapping.ts
+ var inferenceProviderMappingCache = /* @__PURE__ */ new Map();
+ async function fetchInferenceProviderMappingForModel(modelId, accessToken, options) {
+ let inferenceProviderMapping;
+ if (inferenceProviderMappingCache.has(modelId)) {
+ inferenceProviderMapping = inferenceProviderMappingCache.get(modelId);
+ } else {
+ const resp = await (options?.fetch ?? fetch)(
+ `${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
+ {
+ headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {}
+ }
+ );
+ if (resp.status === 404) {
+ throw new Error(`Model ${modelId} does not exist`);
+ }
+ inferenceProviderMapping = await resp.json().then((json) => json.inferenceProviderMapping).catch(() => null);
+ if (inferenceProviderMapping) {
+ inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
+ }
 }
+ if (!inferenceProviderMapping) {
+ throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
 }
+ return inferenceProviderMapping;
+ }
+ async function getInferenceProviderMapping(params, options) {
+ if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
+ return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
 }
+ const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
+ params.modelId,
+ params.accessToken,
+ options
+ );
+ const providerMapping = inferenceProviderMapping[params.provider];
+ if (providerMapping) {
+ const equivalentTasks = params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task) ? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS : [params.task];
+ if (!typedInclude(equivalentTasks, providerMapping.task)) {
+ throw new Error(
+ `Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
+ );
+ }
+ if (providerMapping.status === "staging") {
+ console.warn(
+ `Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
+ );
+ }
+ return { ...providerMapping, hfModelId: params.modelId };
 }
+ return null;
+ }
+ async function resolveProvider(provider, modelId, endpointUrl) {
+ if (endpointUrl) {
+ if (provider) {
+ throw new Error("Specifying both endpointUrl and provider is not supported.");
+ }
+ return "hf-inference";
+ }
+ if (!provider) {
+ console.log(
+ "Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
+ );
+ provider = "auto";
+ }
+ if (provider === "auto") {
+ if (!modelId) {
+ throw new Error("Specifying a model is required when provider is 'auto'");
+ }
+ const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
+ provider = Object.keys(inferenceProviderMapping)[0];
+ }
+ if (!provider) {
+ throw new Error(`No Inference Provider available for model ${modelId}.`);
+ }
+ return provider;
+ }

- // src/
+ // src/utils/delay.ts
+ function delay(ms) {
+ return new Promise((resolve) => {
+ setTimeout(() => resolve(), ms);
+ });
+ }
+
+ // src/utils/pick.ts
+ function pick(o, props) {
+ return Object.assign(
+ {},
+ ...props.map((prop) => {
+ if (o[prop] !== void 0) {
+ return { [prop]: o[prop] };
+ }
+ })
+ );
+ }
+
+ // src/utils/omit.ts
+ function omit(o, props) {
+ const propsArr = Array.isArray(props) ? props : [props];
+ const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
+ return pick(o, letsKeep);
+ }
+
+ // src/providers/black-forest-labs.ts
+ var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
+ var BlackForestLabsTextToImageTask = class extends TaskProviderHelper {
 constructor() {
- super("
+ super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
 }
 preparePayload(params) {
- return
+ return {
+ ...omit(params.args, ["inputs", "parameters"]),
+ ...params.args.parameters,
+ prompt: params.args.inputs
+ };
 }
+ prepareHeaders(params, binary) {
+ const headers = {
+ Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`
+ };
+ if (!binary) {
+ headers["Content-Type"] = "application/json";
 }
- return
+ return headers;
 }
 makeRoute(params) {
- if (params
+ if (!params) {
+ throw new Error("Params are required");
 }
- return
- }
- async getResponse(response) {
- return response;
+ return `/v1/${params.model}`;
 }
- };
- var HFInferenceTextToImageTask = class extends HFInferenceTask {
 async getResponse(response, url, headers, outputType) {
- const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
- return await base64Response.blob();
+ const urlObj = new URL(response.polling_url);
+ for (let step = 0; step < 5; step++) {
+ await delay(1e3);
+ console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
+ urlObj.searchParams.set("attempt", step.toString(10));
+ const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
+ if (!resp.ok) {
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
 }
+ const payload = await resp.json();
+ if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
 if (outputType === "url") {
- return
+ return payload.result.sample;
 }
- const
- return blob;
- }
- }
- if (response instanceof Blob) {
- if (outputType === "url") {
- const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
- return `data:image/jpeg;base64,${b64}`;
+ const image = await fetch(payload.result.sample);
+ return await image.blob();
 }
- return response;
 }
- throw new InferenceOutputError("
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
 }
 };
- } else {
- url = `${this.makeBaseUrl(params)}/models/${params.model}`;
- }
- url = url.replace(/\/+$/, "");
- if (url.endsWith("/v1")) {
- url += "/chat/completions";
- } else if (!url.endsWith("/chat/completions")) {
- url += "/v1/chat/completions";
- }
- return url;
+
+ // src/providers/cerebras.ts
+ var CerebrasConversationalTask = class extends BaseConversationalTask {
+ constructor() {
+ super("cerebras", "https://api.cerebras.ai");
 }
+ };
+
+ // src/providers/cohere.ts
+ var CohereConversationalTask = class extends BaseConversationalTask {
+ constructor() {
+ super("cohere", "https://api.cohere.com");
 }
- return
+ makeRoute() {
+ return "/compatibility/v1/chat/completions";
 }
 };
+
+ // src/lib/isUrl.ts
+ function isUrl(modelOrUrl) {
+ return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
+ }
+
+ // src/providers/fal-ai.ts
+ var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
+ var FalAITask = class extends TaskProviderHelper {
+ constructor(url) {
+ super("fal-ai", url || "https://fal.run");
 }
- async getResponse(response) {
- if (Array.isArray(response) && response.every(
- (x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
- )) {
- return response;
- }
- throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
+ preparePayload(params) {
+ return params.args;
 }
- async getResponse(response) {
- return response;
+ makeRoute(params) {
+ return `/${params.model}`;
 }
- if (!response.every((elem) => {
- return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
- })) {
- throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
+ prepareHeaders(params, binary) {
+ const headers = {
+ Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}`
+ };
+ if (!binary) {
+ headers["Content-Type"] = "application/json";
 }
- return
+ return headers;
 }
 };
+ function buildLoraPath(modelId, adapterWeightsPath) {
+ return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
+ }
+ var FalAITextToImageTask = class extends FalAITask {
+ preparePayload(params) {
+ const payload = {
+ ...omit(params.args, ["inputs", "parameters"]),
+ ...params.args.parameters,
+ sync_mode: true,
+ prompt: params.args.inputs
+ };
+ if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
+ payload.loras = [
+ {
+ path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
+ scale: 1
+ }
+ ];
+ if (params.mapping.providerId === "fal-ai/lora") {
+ payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0";
+ }
 }
+ return payload;
 }
- if (curDepth > maxDepth)
- return false;
- if (arr.every((x) => Array.isArray(x))) {
- return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
- } else {
- return arr.every((x) => typeof x === "number");
+ async getResponse(response, outputType) {
+ if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images.length > 0 && "url" in response.images[0] && typeof response.images[0].url === "string") {
+ if (outputType === "url") {
+ return response.images[0].url;
 }
- return response;
+ const urlResponse = await fetch(response.images[0].url);
+ return await urlResponse.blob();
 }
- throw new InferenceOutputError("Expected
+ throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
 }
 };
- var
- return response;
- }
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
+ var FalAITextToVideoTask = class extends FalAITask {
+ constructor() {
+ super("https://queue.fal.run");
 }
- if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number")) {
- return response;
+ makeRoute(params) {
+ if (params.authMethod !== "provider-key") {
+ return `/${params.model}?_subdomain=queue`;
 }
+ return `/${params.model}`;
 }
- }
- return response;
+ preparePayload(params) {
+ return {
+ ...omit(params.args, ["inputs", "parameters"]),
+ ...params.args.parameters,
+ prompt: params.args.inputs
+ };
 }
- if (response instanceof Blob) {
- return response;
+ async getResponse(response, url, headers) {
+ if (!url || !headers) {
+ throw new InferenceOutputError("URL and headers are required for text-to-video task");
 }
- var HFInferenceObjectDetectionTask = class extends HFInferenceTask {
- async getResponse(response) {
- if (Array.isArray(response) && response.every(
- (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
- )) {
- return response;
+ const requestId = response.request_id;
+ if (!requestId) {
+ throw new InferenceOutputError("No request ID found in the response");
 }
+ let status = response.status;
+ const parsedUrl = new URL(url);
+ const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
+ const modelId = new URL(response.response_url).pathname;
+ const queryParams = parsedUrl.search;
+ const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
+ const resultUrl = `${baseUrl}${modelId}${queryParams}`;
+ while (status !== "COMPLETED") {
+ await delay(500);
+ const statusResponse = await fetch(statusUrl, { headers });
+ if (!statusResponse.ok) {
+ throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
+ }
+ try {
+ status = (await statusResponse.json()).status;
+ } catch (error) {
+ throw new InferenceOutputError("Failed to parse status response from fal-ai API");
+ }
+ }
+ const resultResponse = await fetch(resultUrl, { headers });
+ let result;
+ try {
+ result = await resultResponse.json();
+ } catch (error) {
+ throw new InferenceOutputError("Failed to parse result response from fal-ai API");
+ }
+ if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
+ const urlResponse = await fetch(result.video.url);
+ return await urlResponse.blob();
+ } else {
+ throw new InferenceOutputError(
+ "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
+ );
 }
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
 }
 };
- var
- const
- }
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
+ var FalAIAutomaticSpeechRecognitionTask = class extends FalAITask {
+ prepareHeaders(params, binary) {
+ const headers = super.prepareHeaders(params, binary);
+ headers["Content-Type"] = "application/json";
+ return headers;
 }
- };
- var HFInferenceQuestionAnsweringTask = class extends HFInferenceTask {
 async getResponse(response) {
+ const res = response;
+ if (typeof res?.text !== "string") {
+ throw new InferenceOutputError(
+ `Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
+ );
 }
+ return { text: res.text };
 }
 };
- var
- (
- }
- throw new InferenceOutputError(
- "Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
- );
+ var FalAITextToSpeechTask = class extends FalAITask {
+ preparePayload(params) {
+ return {
+ ...omit(params.args, ["inputs", "parameters"]),
+ ...params.args.parameters,
+ text: params.args.inputs
+ };
 }
- };
- var HFInferenceZeroShotClassificationTask = class extends HFInferenceTask {
 async getResponse(response) {
+ const res = response;
+ if (typeof res?.audio?.url !== "string") {
+ throw new InferenceOutputError(
+ `Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
+ );
 }
+ try {
+ const urlResponse = await fetch(res.audio.url);
+ if (!urlResponse.ok) {
+ throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`);
+ }
+ return await urlResponse.blob();
+ } catch (error) {
+ throw new InferenceOutputError(
+ `Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${error instanceof Error ? error.message : String(error)}`
+ );
 }
- throw new InferenceOutputError("Expected Array<number>");
 }
 };
- async getResponse(response) {
- if (Array.isArray(response) && Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response)) {
- return Array.isArray(response) ? response[0] : response;
- }
- throw new InferenceOutputError(
- "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
- );
+
+ // src/providers/featherless-ai.ts
+ var FEATHERLESS_API_BASE_URL = "https://api.featherless.ai";
+ var FeatherlessAIConversationalTask = class extends BaseConversationalTask {
+ constructor() {
+ super("featherless-ai", FEATHERLESS_API_BASE_URL);
 }
 };
- var
- (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
- )) {
- return response;
- }
- throw new InferenceOutputError(
- "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
- );
+ var FeatherlessAITextGenerationTask = class extends BaseTextGenerationTask {
+ constructor() {
+ super("featherless-ai", FEATHERLESS_API_BASE_URL);
 }
+ preparePayload(params) {
+ return {
+ ...params.args,
+ ...params.args.parameters,
+ model: params.model,
+ prompt: params.args.inputs
+ };
 }
- };
- var HFInferenceSummarizationTask = class extends HFInferenceTask {
 async getResponse(response) {
- if (
+ if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
+ const completion = response.choices[0];
+ return {
+ generated_text: completion.text
+ };
 }
- throw new InferenceOutputError("Expected
+ throw new InferenceOutputError("Expected Featherless AI text generation response format");
 }
 };
+
+ // src/providers/fireworks-ai.ts
+ var FireworksConversationalTask = class extends BaseConversationalTask {
+ constructor() {
+ super("fireworks-ai", "https://api.fireworks.ai");
 }
- async getResponse(response) {
- if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
- return response;
- }
- throw new InferenceOutputError("Expected Array<number>");
+ makeRoute() {
+ return "/inference/v1/chat/completions";
 }
 };
- }
- throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
+
+ // src/providers/groq.ts
+ var GROQ_API_BASE_URL = "https://api.groq.com";
+ var GroqTextGenerationTask = class extends BaseTextGenerationTask {
+ constructor() {
+ super("groq", GROQ_API_BASE_URL);
 }
- async getResponse(response) {
- if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
- return response;
- }
- throw new InferenceOutputError("Expected Array<number>");
+ makeRoute() {
+ return "/openai/v1/chat/completions";
 }
 };
- var
+ var GroqConversationalTask = class extends BaseConversationalTask {
+ constructor() {
+ super("groq", GROQ_API_BASE_URL);
+ }
+ makeRoute() {
+ return "/openai/v1/chat/completions";
 }
 };

@@ -1352,82 +1458,13 @@ function getProviderHelper(provider, task) {

 // package.json
 var name = "@huggingface/inference";
- var version = "3.
-
- // src/providers/consts.ts
- var HARDCODED_MODEL_INFERENCE_MAPPING = {
- /**
- * "HF model ID" => "Model ID on Inference Provider's side"
- *
- * Example:
- * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
- */
- "black-forest-labs": {},
- cerebras: {},
- cohere: {},
- "fal-ai": {},
- "featherless-ai": {},
- "fireworks-ai": {},
- groq: {},
- "hf-inference": {},
- hyperbolic: {},
- nebius: {},
- novita: {},
- nscale: {},
- openai: {},
- ovhcloud: {},
- replicate: {},
- sambanova: {},
- together: {}
- };
-
- // src/lib/getInferenceProviderMapping.ts
- var inferenceProviderMappingCache = /* @__PURE__ */ new Map();
- async function getInferenceProviderMapping(params, options) {
- if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
- return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
- }
- let inferenceProviderMapping;
- if (inferenceProviderMappingCache.has(params.modelId)) {
- inferenceProviderMapping = inferenceProviderMappingCache.get(params.modelId);
- } else {
- const resp = await (options?.fetch ?? fetch)(
- `${HF_HUB_URL}/api/models/${params.modelId}?expand[]=inferenceProviderMapping`,
- {
- headers: params.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${params.accessToken}` } : {}
- }
- );
- if (resp.status === 404) {
- throw new Error(`Model ${params.modelId} does not exist`);
- }
- inferenceProviderMapping = await resp.json().then((json) => json.inferenceProviderMapping).catch(() => null);
- }
- if (!inferenceProviderMapping) {
- throw new Error(`We have not been able to find inference provider information for model ${params.modelId}.`);
- }
- const providerMapping = inferenceProviderMapping[params.provider];
- if (providerMapping) {
- const equivalentTasks = params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task) ? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS : [params.task];
- if (!typedInclude(equivalentTasks, providerMapping.task)) {
- throw new Error(
- `Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
- );
- }
- if (providerMapping.status === "staging") {
- console.warn(
- `Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
- );
- }
- return { ...providerMapping, hfModelId: params.modelId };
- }
- return null;
- }
+ var version = "3.12.0";

 // src/lib/makeRequestOptions.ts
 var tasks = null;
 async function makeRequestOptions(args, providerHelper, options) {
- const {
- const provider =
+ const { model: maybeModel } = args;
+ const provider = providerHelper.provider;
 const { task } = options ?? {};
 if (args.endpointUrl && provider !== "hf-inference") {
 throw new Error(`Cannot use endpointUrl with a third-party provider.`);
@@ -1482,7 +1519,7 @@ async function makeRequestOptions(args, providerHelper, options) {
 }
 function makeRequestOptionsFromResolvedModel(resolvedModel, providerHelper, args, mapping, options) {
   const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
-  const provider =
+  const provider = providerHelper.provider;
   const { includeCredentials, task, signal, billTo } = options ?? {};
   const authMethod = (() => {
     if (providerHelper.clientSideRoutingOnly) {
@@ -1773,7 +1810,8 @@ async function request(args, options) {
   console.warn(
     "The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
   );
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, options?.task);
   const result = await innerRequest(args, providerHelper, options);
   return result.data;
 }
@@ -1783,7 +1821,8 @@ async function* streamingRequest(args, options) {
   console.warn(
     "The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
   );
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, options?.task);
   yield* innerStreamingRequest(args, providerHelper, options);
 }
 
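The deprecated request/streamingRequest wrappers above, and every task wrapper in the hunks that follow, now share the same two-step resolution: first resolve the provider (which can come from an explicit argument or a routing policy), then look up the provider helper for the task. A condensed sketch of that pattern; the helper names come from the diff, their signatures are approximated here, and "someTask" / "some-task" are illustrative placeholders:

    // Internal bundle helpers, declared only so the sketch type-checks; not public API.
    declare function resolveProvider(provider?: string, model?: string, endpointUrl?: string): Promise<string>;
    declare function getProviderHelper(provider: string, task: string): { getResponse(data: unknown): unknown };
    declare function innerRequest(args: unknown, helper: unknown, options: unknown): Promise<{ data: unknown }>;

    // The provider-resolution pattern repeated across the task wrappers in this release:
    async function someTask(args: { provider?: string; model?: string; endpointUrl?: string }, options?: object) {
      // 1. Turn args.provider (possibly undefined or a policy) into a concrete provider.
      const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
      // 2. Fetch the helper that knows this provider's routes and response shape.
      const providerHelper = getProviderHelper(provider, "some-task");
      // 3. Run the request and let the helper validate/shape the result.
      const { data } = await innerRequest(args, providerHelper, { ...options, task: "some-task" });
      return providerHelper.getResponse(data);
    }
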
@@ -1797,7 +1836,8 @@ function preparePayload(args) {
 
 // src/tasks/audio/audioClassification.ts
 async function audioClassification(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "audio-classification");
   const payload = preparePayload(args);
   const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
@@ -1808,7 +1848,9 @@ async function audioClassification(args, options) {
 
 // src/tasks/audio/audioToAudio.ts
 async function audioToAudio(args, options) {
-  const
+  const model = "inputs" in args ? args.model : void 0;
+  const provider = await resolveProvider(args.provider, model);
+  const providerHelper = getProviderHelper(provider, "audio-to-audio");
   const payload = preparePayload(args);
   const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
@@ -1832,7 +1874,8 @@ function base64FromBytes(arr) {
 
 // src/tasks/audio/automaticSpeechRecognition.ts
 async function automaticSpeechRecognition(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
   const payload = await buildPayload(args);
   const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
@@ -1872,7 +1915,7 @@ async function buildPayload(args) {
 
 // src/tasks/audio/textToSpeech.ts
 async function textToSpeech(args, options) {
-  const provider = args.provider
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
   const providerHelper = getProviderHelper(provider, "text-to-speech");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
@@ -1888,7 +1931,8 @@ function preparePayload2(args) {
 
 // src/tasks/cv/imageClassification.ts
 async function imageClassification(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "image-classification");
   const payload = preparePayload2(args);
   const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
@@ -1899,7 +1943,8 @@ async function imageClassification(args, options) {
 
 // src/tasks/cv/imageSegmentation.ts
 async function imageSegmentation(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "image-segmentation");
   const payload = preparePayload2(args);
   const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
@@ -1910,7 +1955,8 @@ async function imageSegmentation(args, options) {
 
 // src/tasks/cv/imageToImage.ts
 async function imageToImage(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "image-to-image");
   let reqArgs;
   if (!args.parameters) {
     reqArgs = {
@@ -1935,7 +1981,8 @@ async function imageToImage(args, options) {
 
 // src/tasks/cv/imageToText.ts
 async function imageToText(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "image-to-text");
   const payload = preparePayload2(args);
   const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
@@ -1946,7 +1993,8 @@ async function imageToText(args, options) {
 
 // src/tasks/cv/objectDetection.ts
 async function objectDetection(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "object-detection");
   const payload = preparePayload2(args);
   const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
@@ -1957,7 +2005,7 @@ async function objectDetection(args, options) {
 
 // src/tasks/cv/textToImage.ts
 async function textToImage(args, options) {
-  const provider = args.provider
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
   const providerHelper = getProviderHelper(provider, "text-to-image");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
@@ -1969,7 +2017,7 @@ async function textToImage(args, options) {
 
 // src/tasks/cv/textToVideo.ts
 async function textToVideo(args, options) {
-  const provider = args.provider
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
   const providerHelper = getProviderHelper(provider, "text-to-video");
   const { data: response } = await innerRequest(
     args,
@@ -2006,7 +2054,8 @@ async function preparePayload3(args) {
   }
 }
 async function zeroShotImageClassification(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "zero-shot-image-classification");
   const payload = await preparePayload3(args);
   const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
@@ -2017,7 +2066,8 @@ async function zeroShotImageClassification(args, options) {
 
 // src/tasks/nlp/chatCompletion.ts
 async function chatCompletion(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "conversational");
   const { data: response } = await innerRequest(args, providerHelper, {
     ...options,
     task: "conversational"
@@ -2027,7 +2077,8 @@ async function chatCompletion(args, options) {
 
 // src/tasks/nlp/chatCompletionStream.ts
 async function* chatCompletionStream(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "conversational");
   yield* innerStreamingRequest(args, providerHelper, {
     ...options,
     task: "conversational"
@@ -2036,7 +2087,8 @@ async function* chatCompletionStream(args, options) {
 
 // src/tasks/nlp/featureExtraction.ts
 async function featureExtraction(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "feature-extraction");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
     task: "feature-extraction"
@@ -2046,7 +2098,8 @@ async function featureExtraction(args, options) {
 
 // src/tasks/nlp/fillMask.ts
 async function fillMask(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "fill-mask");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
     task: "fill-mask"
@@ -2056,7 +2109,8 @@ async function fillMask(args, options) {
 
 // src/tasks/nlp/questionAnswering.ts
 async function questionAnswering(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "question-answering");
   const { data: res } = await innerRequest(
     args,
     providerHelper,
@@ -2070,7 +2124,8 @@ async function questionAnswering(args, options) {
 
 // src/tasks/nlp/sentenceSimilarity.ts
 async function sentenceSimilarity(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "sentence-similarity");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
     task: "sentence-similarity"
@@ -2080,7 +2135,8 @@ async function sentenceSimilarity(args, options) {
 
 // src/tasks/nlp/summarization.ts
 async function summarization(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "summarization");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
     task: "summarization"
@@ -2090,7 +2146,8 @@ async function summarization(args, options) {
 
 // src/tasks/nlp/tableQuestionAnswering.ts
 async function tableQuestionAnswering(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "table-question-answering");
   const { data: res } = await innerRequest(
     args,
     providerHelper,
@@ -2104,7 +2161,8 @@ async function tableQuestionAnswering(args, options) {
 
 // src/tasks/nlp/textClassification.ts
 async function textClassification(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "text-classification");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
     task: "text-classification"
@@ -2114,7 +2172,8 @@ async function textClassification(args, options) {
 
 // src/tasks/nlp/textGeneration.ts
 async function textGeneration(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "text-generation");
   const { data: response } = await innerRequest(args, providerHelper, {
     ...options,
     task: "text-generation"
@@ -2124,7 +2183,8 @@ async function textGeneration(args, options) {
 
 // src/tasks/nlp/textGenerationStream.ts
 async function* textGenerationStream(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "text-generation");
   yield* innerStreamingRequest(args, providerHelper, {
     ...options,
     task: "text-generation"
@@ -2133,7 +2193,8 @@ async function* textGenerationStream(args, options) {
 
 // src/tasks/nlp/tokenClassification.ts
 async function tokenClassification(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "token-classification");
   const { data: res } = await innerRequest(
     args,
     providerHelper,
@@ -2147,7 +2208,8 @@ async function tokenClassification(args, options) {
 
 // src/tasks/nlp/translation.ts
 async function translation(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "translation");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
     task: "translation"
@@ -2157,7 +2219,8 @@ async function translation(args, options) {
 
 // src/tasks/nlp/zeroShotClassification.ts
 async function zeroShotClassification(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "zero-shot-classification");
   const { data: res } = await innerRequest(
     args,
     providerHelper,
@@ -2171,7 +2234,8 @@ async function zeroShotClassification(args, options) {
 
 // src/tasks/multimodal/documentQuestionAnswering.ts
 async function documentQuestionAnswering(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "document-question-answering");
   const reqArgs = {
     ...args,
     inputs: {
@@ -2193,7 +2257,8 @@ async function documentQuestionAnswering(args, options) {
 
 // src/tasks/multimodal/visualQuestionAnswering.ts
 async function visualQuestionAnswering(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "visual-question-answering");
   const reqArgs = {
     ...args,
     inputs: {
@@ -2211,7 +2276,8 @@ async function visualQuestionAnswering(args, options) {
 
 // src/tasks/tabular/tabularClassification.ts
 async function tabularClassification(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "tabular-classification");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
     task: "tabular-classification"
@@ -2221,7 +2287,8 @@ async function tabularClassification(args, options) {
 
 // src/tasks/tabular/tabularRegression.ts
 async function tabularRegression(args, options) {
-  const
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "tabular-regression");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
     task: "tabular-regression"
@@ -2229,6 +2296,11 @@ async function tabularRegression(args, options) {
   return providerHelper.getResponse(res);
 }
 
+// src/utils/typedEntries.ts
+function typedEntries(obj) {
+  return Object.entries(obj);
+}
+
 // src/InferenceClient.ts
 var InferenceClient = class {
   accessToken;
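The typedEntries helper added above is, at runtime, a plain wrapper around Object.entries; the new dist/src/utils/typedEntries.d.ts in the file list is what gives it a typed signature. A sketch of what such a typed wrapper can look like in TypeScript (the exact generics here are an assumption, not copied from the package sources):

    // Typed wrapper around Object.entries (generic signature is an illustrative assumption):
    type Entries<T> = { [K in keyof T]: [K, T[K]] }[keyof T][];

    function typedEntries<T extends object>(obj: T): Entries<T> {
      return Object.entries(obj) as Entries<T>;
    }

    // Keys and values keep their types, so iterating over an exports object stays type-safe:
    const config = { retries: 3, verbose: true };
    for (const [key, value] of typedEntries(config)) {
      console.log(key, value); // key is "retries" | "verbose", value is number | boolean
    }
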
@@ -2236,40 +2308,36 @@ var InferenceClient = class {
   constructor(accessToken = "", defaultOptions = {}) {
     this.accessToken = accessToken;
     this.defaultOptions = defaultOptions;
-    for (const [name2, fn] of
+    for (const [name2, fn] of typedEntries(tasks_exports)) {
       Object.defineProperty(this, name2, {
         enumerable: false,
         value: (params, options) => (
           // eslint-disable-next-line @typescript-eslint/no-explicit-any
-          fn(
+          fn(
+            /// ^ The cast of fn to any is necessary, otherwise TS can't compile because the generated union type is too complex
+            { endpointUrl: defaultOptions.endpointUrl, accessToken, ...params },
+            {
+              ...omit(defaultOptions, ["endpointUrl"]),
+              ...options
+            }
+          )
         )
       });
     }
   }
   /**
-   * Returns
+   * Returns a new instance of InferenceClient tied to a specified endpoint.
+   *
+   * For backward compatibility mostly.
    */
   endpoint(endpointUrl) {
-    return new
-  }
-};
-var InferenceClientEndpoint = class {
-  constructor(endpointUrl, accessToken = "", defaultOptions = {}) {
-    accessToken;
-    defaultOptions;
-    for (const [name2, fn] of Object.entries(tasks_exports)) {
-      Object.defineProperty(this, name2, {
-        enumerable: false,
-        value: (params, options) => (
-          // eslint-disable-next-line @typescript-eslint/no-explicit-any
-          fn({ ...params, accessToken, endpointUrl }, { ...defaultOptions, ...options })
-        )
-      });
-    }
+    return new InferenceClient(this.accessToken, { ...this.defaultOptions, endpointUrl });
   }
 };
 var HfInference = class extends InferenceClient {
 };
+var InferenceClientEndpoint = class extends InferenceClient {
+};
 
 // src/types.ts
 var INFERENCE_PROVIDERS = [
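With the hunk above, endpoint() no longer builds a separate InferenceClientEndpoint object; it returns a regular InferenceClient whose default options carry the endpoint URL, and InferenceClientEndpoint survives only as an empty subclass for backward compatibility. A small usage sketch (the URL and token are placeholders):

    import { InferenceClient } from "@huggingface/inference";

    const client = new InferenceClient("hf_xxx"); // placeholder token
    // Returns another InferenceClient with endpointUrl baked into its default options.
    const endpointClient = client.endpoint("https://my-endpoint.example/v1"); // placeholder URL
    const out = await endpointClient.textGeneration({ inputs: "Once upon a time" });
    console.log(out.generated_text);
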
@@ -2291,6 +2359,7 @@ var INFERENCE_PROVIDERS = [
   "sambanova",
   "together"
 ];
+var PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"];
 
 // src/snippets/index.ts
 var snippets_exports = {};
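PROVIDERS_OR_POLICIES, also added to the public exports further down, widens the accepted provider values with the "auto" routing policy, which lets the client pick a concrete provider for the model at request time instead of requiring an explicit one. A hedged sketch of what that enables (the model ID and token are placeholders):

    import { textToImage } from "@huggingface/inference";

    // "auto" is a policy rather than a concrete provider: the client resolves it to one
    // of the model's available inference providers when the request is made.
    const image = await textToImage({
      accessToken: "hf_xxx", // placeholder token
      provider: "auto",
      model: "black-forest-labs/FLUX.1-dev", // placeholder HF model ID
      inputs: "An astronaut riding a horse",
    });
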
@@ -2619,7 +2688,7 @@ var prepareConversationalInput = (model, opts) => {
   return {
     messages: opts?.messages ?? (0, import_tasks.getModelInputSnippet)(model),
     ...opts?.temperature ? { temperature: opts?.temperature } : void 0,
-    max_tokens: opts?.max_tokens
+    ...opts?.max_tokens ? { max_tokens: opts?.max_tokens } : void 0,
     ...opts?.top_p ? { top_p: opts?.top_p } : void 0
   };
 };
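The snippet generator now only emits max_tokens when a value was actually provided, reusing the same conditional-spread idiom already applied to temperature and top_p. In isolation the idiom looks like this (the object names are illustrative):

    // Conditional spread: spreading undefined contributes no keys, so an option
    // only appears in the payload when it was actually set.
    const opts: { temperature?: number; max_tokens?: number } = { temperature: 0.7 };
    const body = {
      ...(opts.temperature ? { temperature: opts.temperature } : undefined),
      ...(opts.max_tokens ? { max_tokens: opts.max_tokens } : undefined),
    };
    console.log(body); // { temperature: 0.7 } (no max_tokens key)
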
@@ -2713,6 +2782,7 @@ function removeSuffix(str, suffix) {
   InferenceClient,
   InferenceClientEndpoint,
   InferenceOutputError,
+  PROVIDERS_OR_POLICIES,
   audioClassification,
   audioToAudio,
   automaticSpeechRecognition,