@huggingface/inference 3.10.0 → 3.12.0

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (86)
  1. package/dist/index.cjs +713 -643
  2. package/dist/index.js +712 -643
  3. package/dist/src/InferenceClient.d.ts +16 -17
  4. package/dist/src/InferenceClient.d.ts.map +1 -1
  5. package/dist/src/lib/getInferenceProviderMapping.d.ts +5 -1
  6. package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -1
  7. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  8. package/dist/src/providers/providerHelper.d.ts +1 -1
  9. package/dist/src/providers/providerHelper.d.ts.map +1 -1
  10. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
  11. package/dist/src/tasks/audio/audioToAudio.d.ts.map +1 -1
  12. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  13. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  14. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  15. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  16. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
  17. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
  18. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
  19. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
  20. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
  21. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  22. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
  23. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
  24. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  25. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
  26. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  27. package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
  28. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  29. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
  30. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  31. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  32. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
  33. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  34. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
  35. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  36. package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
  37. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  38. package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
  39. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  40. package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
  41. package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
  42. package/dist/src/types.d.ts +6 -4
  43. package/dist/src/types.d.ts.map +1 -1
  44. package/dist/src/utils/typedEntries.d.ts +4 -0
  45. package/dist/src/utils/typedEntries.d.ts.map +1 -0
  46. package/package.json +3 -3
  47. package/src/InferenceClient.ts +32 -43
  48. package/src/lib/getInferenceProviderMapping.ts +68 -19
  49. package/src/lib/makeRequestOptions.ts +4 -3
  50. package/src/providers/hf-inference.ts +1 -1
  51. package/src/providers/providerHelper.ts +1 -1
  52. package/src/snippets/getInferenceSnippets.ts +1 -1
  53. package/src/tasks/audio/audioClassification.ts +3 -1
  54. package/src/tasks/audio/audioToAudio.ts +4 -1
  55. package/src/tasks/audio/automaticSpeechRecognition.ts +3 -1
  56. package/src/tasks/audio/textToSpeech.ts +2 -1
  57. package/src/tasks/custom/request.ts +3 -1
  58. package/src/tasks/custom/streamingRequest.ts +3 -1
  59. package/src/tasks/cv/imageClassification.ts +3 -1
  60. package/src/tasks/cv/imageSegmentation.ts +3 -1
  61. package/src/tasks/cv/imageToImage.ts +3 -1
  62. package/src/tasks/cv/imageToText.ts +3 -1
  63. package/src/tasks/cv/objectDetection.ts +3 -1
  64. package/src/tasks/cv/textToImage.ts +2 -1
  65. package/src/tasks/cv/textToVideo.ts +2 -1
  66. package/src/tasks/cv/zeroShotImageClassification.ts +3 -1
  67. package/src/tasks/multimodal/documentQuestionAnswering.ts +3 -1
  68. package/src/tasks/multimodal/visualQuestionAnswering.ts +3 -1
  69. package/src/tasks/nlp/chatCompletion.ts +3 -1
  70. package/src/tasks/nlp/chatCompletionStream.ts +3 -1
  71. package/src/tasks/nlp/featureExtraction.ts +3 -1
  72. package/src/tasks/nlp/fillMask.ts +3 -1
  73. package/src/tasks/nlp/questionAnswering.ts +4 -1
  74. package/src/tasks/nlp/sentenceSimilarity.ts +3 -1
  75. package/src/tasks/nlp/summarization.ts +3 -1
  76. package/src/tasks/nlp/tableQuestionAnswering.ts +3 -1
  77. package/src/tasks/nlp/textClassification.ts +3 -1
  78. package/src/tasks/nlp/textGeneration.ts +3 -1
  79. package/src/tasks/nlp/textGenerationStream.ts +3 -1
  80. package/src/tasks/nlp/tokenClassification.ts +3 -1
  81. package/src/tasks/nlp/translation.ts +3 -1
  82. package/src/tasks/nlp/zeroShotClassification.ts +3 -1
  83. package/src/tasks/tabular/tabularClassification.ts +3 -1
  84. package/src/tasks/tabular/tabularRegression.ts +3 -1
  85. package/src/types.ts +8 -4
  86. package/src/utils/typedEntries.ts +5 -0
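The headline behavioral change in this range is visible in the `src/lib/getInferenceProviderMapping.ts` hunk below: when no provider is specified, the client now logs a notice and defaults to "auto", selecting the first provider available for the model, ordered by the user's settings at https://hf.co/settings/inference-providers. A minimal usage sketch of that behavior; the token placeholder and model ID are illustrative (the model ID is borrowed from the comment in `src/providers/consts.ts`), not taken verbatim from this diff:

```ts
import { InferenceClient } from "@huggingface/inference";

// No `provider` option passed: resolveProvider() falls back to "auto" and
// picks the first inference provider mapped to this model on the HF Hub.
const client = new InferenceClient("hf_xxx"); // your HF access token

const out = await client.chatCompletion({
	model: "Qwen/Qwen2.5-Coder-32B-Instruct", // illustrative model ID
	messages: [{ role: "user", content: "Hello!" }],
});
console.log(out.choices[0].message);
```

Note also (per `resolveProvider` below) that passing `endpointUrl` together with an explicit `provider` now throws, and an endpoint URL always resolves to the "hf-inference" helper.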
package/dist/index.js CHANGED
@@ -41,6 +41,38 @@ __export(tasks_exports, {
  zeroShotImageClassification: () => zeroShotImageClassification
  });

+ // src/config.ts
+ var HF_HUB_URL = "https://huggingface.co";
+ var HF_ROUTER_URL = "https://router.huggingface.co";
+ var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
+
+ // src/providers/consts.ts
+ var HARDCODED_MODEL_INFERENCE_MAPPING = {
+ /**
+ * "HF model ID" => "Model ID on Inference Provider's side"
+ *
+ * Example:
+ * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
+ */
+ "black-forest-labs": {},
+ cerebras: {},
+ cohere: {},
+ "fal-ai": {},
+ "featherless-ai": {},
+ "fireworks-ai": {},
+ groq: {},
+ "hf-inference": {},
+ hyperbolic: {},
+ nebius: {},
+ novita: {},
+ nscale: {},
+ openai: {},
+ ovhcloud: {},
+ replicate: {},
+ sambanova: {},
+ together: {}
+ };
+
  // src/lib/InferenceOutputError.ts
  var InferenceOutputError = class extends TypeError {
  constructor(message) {
@@ -51,42 +83,6 @@ var InferenceOutputError = class extends TypeError {
  }
  };

- // src/utils/delay.ts
- function delay(ms) {
- return new Promise((resolve) => {
- setTimeout(() => resolve(), ms);
- });
- }
-
- // src/utils/pick.ts
- function pick(o, props) {
- return Object.assign(
- {},
- ...props.map((prop) => {
- if (o[prop] !== void 0) {
- return { [prop]: o[prop] };
- }
- })
- );
- }
-
- // src/utils/typedInclude.ts
- function typedInclude(arr, v) {
- return arr.includes(v);
- }
-
- // src/utils/omit.ts
- function omit(o, props) {
- const propsArr = Array.isArray(props) ? props : [props];
- const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
- return pick(o, letsKeep);
- }
-
- // src/config.ts
- var HF_HUB_URL = "https://huggingface.co";
- var HF_ROUTER_URL = "https://router.huggingface.co";
- var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
-
  // src/utils/toArray.ts
  function toArray(obj) {
  if (Array.isArray(obj)) {
@@ -181,627 +177,736 @@ var BaseTextGenerationTask = class extends TaskProviderHelper {
181
177
  }
182
178
  };
183
179
 
184
- // src/providers/black-forest-labs.ts
185
- var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
186
- var BlackForestLabsTextToImageTask = class extends TaskProviderHelper {
180
+ // src/providers/hf-inference.ts
181
+ var EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = ["feature-extraction", "sentence-similarity"];
182
+ var HFInferenceTask = class extends TaskProviderHelper {
187
183
  constructor() {
188
- super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
184
+ super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
189
185
  }
190
186
  preparePayload(params) {
191
- return {
192
- ...omit(params.args, ["inputs", "parameters"]),
193
- ...params.args.parameters,
194
- prompt: params.args.inputs
195
- };
187
+ return params.args;
196
188
  }
197
- prepareHeaders(params, binary) {
198
- const headers = {
199
- Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`
200
- };
201
- if (!binary) {
202
- headers["Content-Type"] = "application/json";
189
+ makeUrl(params) {
190
+ if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
191
+ return params.model;
203
192
  }
204
- return headers;
193
+ return super.makeUrl(params);
205
194
  }
206
195
  makeRoute(params) {
207
- if (!params) {
208
- throw new Error("Params are required");
196
+ if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
197
+ return `models/${params.model}/pipeline/${params.task}`;
209
198
  }
210
- return `/v1/${params.model}`;
199
+ return `models/${params.model}`;
200
+ }
201
+ async getResponse(response) {
202
+ return response;
211
203
  }
204
+ };
205
+ var HFInferenceTextToImageTask = class extends HFInferenceTask {
212
206
  async getResponse(response, url, headers, outputType) {
213
- const urlObj = new URL(response.polling_url);
214
- for (let step = 0; step < 5; step++) {
215
- await delay(1e3);
216
- console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
217
- urlObj.searchParams.set("attempt", step.toString(10));
218
- const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
219
- if (!resp.ok) {
220
- throw new InferenceOutputError("Failed to fetch result from black forest labs API");
207
+ if (!response) {
208
+ throw new InferenceOutputError("response is undefined");
209
+ }
210
+ if (typeof response == "object") {
211
+ if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
212
+ const base64Data = response.data[0].b64_json;
213
+ if (outputType === "url") {
214
+ return `data:image/jpeg;base64,${base64Data}`;
215
+ }
216
+ const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
217
+ return await base64Response.blob();
221
218
  }
222
- const payload = await resp.json();
223
- if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
219
+ if ("output" in response && Array.isArray(response.output)) {
224
220
  if (outputType === "url") {
225
- return payload.result.sample;
221
+ return response.output[0];
226
222
  }
227
- const image = await fetch(payload.result.sample);
228
- return await image.blob();
223
+ const urlResponse = await fetch(response.output[0]);
224
+ const blob = await urlResponse.blob();
225
+ return blob;
229
226
  }
230
227
  }
231
- throw new InferenceOutputError("Failed to fetch result from black forest labs API");
228
+ if (response instanceof Blob) {
229
+ if (outputType === "url") {
230
+ const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
231
+ return `data:image/jpeg;base64,${b64}`;
232
+ }
233
+ return response;
234
+ }
235
+ throw new InferenceOutputError("Expected a Blob ");
232
236
  }
233
237
  };
234
-
235
- // src/providers/cerebras.ts
236
- var CerebrasConversationalTask = class extends BaseConversationalTask {
237
- constructor() {
238
- super("cerebras", "https://api.cerebras.ai");
238
+ var HFInferenceConversationalTask = class extends HFInferenceTask {
239
+ makeUrl(params) {
240
+ let url;
241
+ if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
242
+ url = params.model.trim();
243
+ } else {
244
+ url = `${this.makeBaseUrl(params)}/models/${params.model}`;
245
+ }
246
+ url = url.replace(/\/+$/, "");
247
+ if (url.endsWith("/v1")) {
248
+ url += "/chat/completions";
249
+ } else if (!url.endsWith("/chat/completions")) {
250
+ url += "/v1/chat/completions";
251
+ }
252
+ return url;
253
+ }
254
+ preparePayload(params) {
255
+ return {
256
+ ...params.args,
257
+ model: params.model
258
+ };
259
+ }
260
+ async getResponse(response) {
261
+ return response;
239
262
  }
240
263
  };
241
-
242
- // src/providers/cohere.ts
243
- var CohereConversationalTask = class extends BaseConversationalTask {
244
- constructor() {
245
- super("cohere", "https://api.cohere.com");
264
+ var HFInferenceTextGenerationTask = class extends HFInferenceTask {
265
+ async getResponse(response) {
266
+ const res = toArray(response);
267
+ if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) {
268
+ return res?.[0];
269
+ }
270
+ throw new InferenceOutputError("Expected Array<{generated_text: string}>");
246
271
  }
247
- makeRoute() {
248
- return "/compatibility/v1/chat/completions";
272
+ };
273
+ var HFInferenceAudioClassificationTask = class extends HFInferenceTask {
274
+ async getResponse(response) {
275
+ if (Array.isArray(response) && response.every(
276
+ (x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
277
+ )) {
278
+ return response;
279
+ }
280
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
249
281
  }
250
282
  };
251
-
252
- // src/lib/isUrl.ts
253
- function isUrl(modelOrUrl) {
254
- return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
255
- }
256
-
257
- // src/providers/fal-ai.ts
258
- var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
259
- var FalAITask = class extends TaskProviderHelper {
260
- constructor(url) {
261
- super("fal-ai", url || "https://fal.run");
283
+ var HFInferenceAutomaticSpeechRecognitionTask = class extends HFInferenceTask {
284
+ async getResponse(response) {
285
+ return response;
262
286
  }
263
- preparePayload(params) {
264
- return params.args;
287
+ };
288
+ var HFInferenceAudioToAudioTask = class extends HFInferenceTask {
289
+ async getResponse(response) {
290
+ if (!Array.isArray(response)) {
291
+ throw new InferenceOutputError("Expected Array");
292
+ }
293
+ if (!response.every((elem) => {
294
+ return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
295
+ })) {
296
+ throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
297
+ }
298
+ return response;
265
299
  }
266
- makeRoute(params) {
267
- return `/${params.model}`;
300
+ };
301
+ var HFInferenceDocumentQuestionAnsweringTask = class extends HFInferenceTask {
302
+ async getResponse(response) {
303
+ if (Array.isArray(response) && response.every(
304
+ (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
305
+ )) {
306
+ return response[0];
307
+ }
308
+ throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
268
309
  }
269
- prepareHeaders(params, binary) {
270
- const headers = {
271
- Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}`
310
+ };
311
+ var HFInferenceFeatureExtractionTask = class extends HFInferenceTask {
312
+ async getResponse(response) {
313
+ const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
314
+ if (curDepth > maxDepth)
315
+ return false;
316
+ if (arr.every((x) => Array.isArray(x))) {
317
+ return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
318
+ } else {
319
+ return arr.every((x) => typeof x === "number");
320
+ }
272
321
  };
273
- if (!binary) {
274
- headers["Content-Type"] = "application/json";
322
+ if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) {
323
+ return response;
275
324
  }
276
- return headers;
325
+ throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
277
326
  }
278
327
  };
279
- function buildLoraPath(modelId, adapterWeightsPath) {
280
- return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
281
- }
282
- var FalAITextToImageTask = class extends FalAITask {
283
- preparePayload(params) {
284
- const payload = {
285
- ...omit(params.args, ["inputs", "parameters"]),
286
- ...params.args.parameters,
287
- sync_mode: true,
288
- prompt: params.args.inputs
289
- };
290
- if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
291
- payload.loras = [
292
- {
293
- path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
294
- scale: 1
295
- }
296
- ];
297
- if (params.mapping.providerId === "fal-ai/lora") {
298
- payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0";
299
- }
328
+ var HFInferenceImageClassificationTask = class extends HFInferenceTask {
329
+ async getResponse(response) {
330
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
331
+ return response;
300
332
  }
301
- return payload;
333
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
302
334
  }
303
- async getResponse(response, outputType) {
304
- if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images.length > 0 && "url" in response.images[0] && typeof response.images[0].url === "string") {
305
- if (outputType === "url") {
306
- return response.images[0].url;
307
- }
308
- const urlResponse = await fetch(response.images[0].url);
309
- return await urlResponse.blob();
335
+ };
336
+ var HFInferenceImageSegmentationTask = class extends HFInferenceTask {
337
+ async getResponse(response) {
338
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number")) {
339
+ return response;
310
340
  }
311
- throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
341
+ throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
312
342
  }
313
343
  };
314
- var FalAITextToVideoTask = class extends FalAITask {
315
- constructor() {
316
- super("https://queue.fal.run");
317
- }
318
- makeRoute(params) {
319
- if (params.authMethod !== "provider-key") {
320
- return `/${params.model}?_subdomain=queue`;
344
+ var HFInferenceImageToTextTask = class extends HFInferenceTask {
345
+ async getResponse(response) {
346
+ if (typeof response?.generated_text !== "string") {
347
+ throw new InferenceOutputError("Expected {generated_text: string}");
321
348
  }
322
- return `/${params.model}`;
349
+ return response;
323
350
  }
324
- preparePayload(params) {
325
- return {
326
- ...omit(params.args, ["inputs", "parameters"]),
327
- ...params.args.parameters,
328
- prompt: params.args.inputs
329
- };
351
+ };
352
+ var HFInferenceImageToImageTask = class extends HFInferenceTask {
353
+ async getResponse(response) {
354
+ if (response instanceof Blob) {
355
+ return response;
356
+ }
357
+ throw new InferenceOutputError("Expected Blob");
330
358
  }
331
- async getResponse(response, url, headers) {
332
- if (!url || !headers) {
333
- throw new InferenceOutputError("URL and headers are required for text-to-video task");
359
+ };
360
+ var HFInferenceObjectDetectionTask = class extends HFInferenceTask {
361
+ async getResponse(response) {
362
+ if (Array.isArray(response) && response.every(
363
+ (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
364
+ )) {
365
+ return response;
334
366
  }
335
- const requestId = response.request_id;
336
- if (!requestId) {
337
- throw new InferenceOutputError("No request ID found in the response");
367
+ throw new InferenceOutputError(
368
+ "Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
369
+ );
370
+ }
371
+ };
372
+ var HFInferenceZeroShotImageClassificationTask = class extends HFInferenceTask {
373
+ async getResponse(response) {
374
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
375
+ return response;
338
376
  }
339
- let status = response.status;
340
- const parsedUrl = new URL(url);
341
- const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
342
- const modelId = new URL(response.response_url).pathname;
343
- const queryParams = parsedUrl.search;
344
- const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
345
- const resultUrl = `${baseUrl}${modelId}${queryParams}`;
346
- while (status !== "COMPLETED") {
347
- await delay(500);
348
- const statusResponse = await fetch(statusUrl, { headers });
349
- if (!statusResponse.ok) {
350
- throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
351
- }
352
- try {
353
- status = (await statusResponse.json()).status;
354
- } catch (error) {
355
- throw new InferenceOutputError("Failed to parse status response from fal-ai API");
356
- }
377
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
378
+ }
379
+ };
380
+ var HFInferenceTextClassificationTask = class extends HFInferenceTask {
381
+ async getResponse(response) {
382
+ const output = response?.[0];
383
+ if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
384
+ return output;
357
385
  }
358
- const resultResponse = await fetch(resultUrl, { headers });
359
- let result;
360
- try {
361
- result = await resultResponse.json();
362
- } catch (error) {
363
- throw new InferenceOutputError("Failed to parse result response from fal-ai API");
386
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
387
+ }
388
+ };
389
+ var HFInferenceQuestionAnsweringTask = class extends HFInferenceTask {
390
+ async getResponse(response) {
391
+ if (Array.isArray(response) ? response.every(
392
+ (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
393
+ ) : typeof response === "object" && !!response && typeof response.answer === "string" && typeof response.end === "number" && typeof response.score === "number" && typeof response.start === "number") {
394
+ return Array.isArray(response) ? response[0] : response;
364
395
  }
365
- if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
366
- const urlResponse = await fetch(result.video.url);
367
- return await urlResponse.blob();
368
- } else {
369
- throw new InferenceOutputError(
370
- "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
371
- );
396
+ throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
397
+ }
398
+ };
399
+ var HFInferenceFillMaskTask = class extends HFInferenceTask {
400
+ async getResponse(response) {
401
+ if (Array.isArray(response) && response.every(
402
+ (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
403
+ )) {
404
+ return response;
372
405
  }
406
+ throw new InferenceOutputError(
407
+ "Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
408
+ );
373
409
  }
374
410
  };
375
- var FalAIAutomaticSpeechRecognitionTask = class extends FalAITask {
376
- prepareHeaders(params, binary) {
377
- const headers = super.prepareHeaders(params, binary);
378
- headers["Content-Type"] = "application/json";
379
- return headers;
411
+ var HFInferenceZeroShotClassificationTask = class extends HFInferenceTask {
412
+ async getResponse(response) {
413
+ if (Array.isArray(response) && response.every(
414
+ (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
415
+ )) {
416
+ return response;
417
+ }
418
+ throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
380
419
  }
420
+ };
421
+ var HFInferenceSentenceSimilarityTask = class extends HFInferenceTask {
381
422
  async getResponse(response) {
382
- const res = response;
383
- if (typeof res?.text !== "string") {
384
- throw new InferenceOutputError(
385
- `Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
386
- );
423
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
424
+ return response;
387
425
  }
388
- return { text: res.text };
426
+ throw new InferenceOutputError("Expected Array<number>");
389
427
  }
390
428
  };
391
- var FalAITextToSpeechTask = class extends FalAITask {
392
- preparePayload(params) {
393
- return {
394
- ...omit(params.args, ["inputs", "parameters"]),
395
- ...params.args.parameters,
396
- text: params.args.inputs
397
- };
429
+ var HFInferenceTableQuestionAnsweringTask = class extends HFInferenceTask {
430
+ static validate(elem) {
431
+ return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
432
+ (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
433
+ );
398
434
  }
399
435
  async getResponse(response) {
400
- const res = response;
401
- if (typeof res?.audio?.url !== "string") {
402
- throw new InferenceOutputError(
403
- `Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
404
- );
436
+ if (Array.isArray(response) && Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response)) {
437
+ return Array.isArray(response) ? response[0] : response;
405
438
  }
406
- try {
407
- const urlResponse = await fetch(res.audio.url);
408
- if (!urlResponse.ok) {
409
- throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`);
410
- }
411
- return await urlResponse.blob();
412
- } catch (error) {
413
- throw new InferenceOutputError(
414
- `Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${error instanceof Error ? error.message : String(error)}`
415
- );
439
+ throw new InferenceOutputError(
440
+ "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
441
+ );
442
+ }
443
+ };
444
+ var HFInferenceTokenClassificationTask = class extends HFInferenceTask {
445
+ async getResponse(response) {
446
+ if (Array.isArray(response) && response.every(
447
+ (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
448
+ )) {
449
+ return response;
416
450
  }
451
+ throw new InferenceOutputError(
452
+ "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
453
+ );
417
454
  }
418
455
  };
419
-
420
- // src/providers/featherless-ai.ts
421
- var FEATHERLESS_API_BASE_URL = "https://api.featherless.ai";
422
- var FeatherlessAIConversationalTask = class extends BaseConversationalTask {
423
- constructor() {
424
- super("featherless-ai", FEATHERLESS_API_BASE_URL);
456
+ var HFInferenceTranslationTask = class extends HFInferenceTask {
457
+ async getResponse(response) {
458
+ if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) {
459
+ return response?.length === 1 ? response?.[0] : response;
460
+ }
461
+ throw new InferenceOutputError("Expected Array<{translation_text: string}>");
425
462
  }
426
463
  };
427
- var FeatherlessAITextGenerationTask = class extends BaseTextGenerationTask {
428
- constructor() {
429
- super("featherless-ai", FEATHERLESS_API_BASE_URL);
464
+ var HFInferenceSummarizationTask = class extends HFInferenceTask {
465
+ async getResponse(response) {
466
+ if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) {
467
+ return response?.[0];
468
+ }
469
+ throw new InferenceOutputError("Expected Array<{summary_text: string}>");
430
470
  }
431
- preparePayload(params) {
432
- return {
433
- ...params.args,
434
- ...params.args.parameters,
435
- model: params.model,
436
- prompt: params.args.inputs
437
- };
471
+ };
472
+ var HFInferenceTextToSpeechTask = class extends HFInferenceTask {
473
+ async getResponse(response) {
474
+ return response;
438
475
  }
476
+ };
477
+ var HFInferenceTabularClassificationTask = class extends HFInferenceTask {
439
478
  async getResponse(response) {
440
- if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
441
- const completion = response.choices[0];
442
- return {
443
- generated_text: completion.text
444
- };
479
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
480
+ return response;
445
481
  }
446
- throw new InferenceOutputError("Expected Featherless AI text generation response format");
482
+ throw new InferenceOutputError("Expected Array<number>");
447
483
  }
448
484
  };
449
-
450
- // src/providers/fireworks-ai.ts
451
- var FireworksConversationalTask = class extends BaseConversationalTask {
452
- constructor() {
453
- super("fireworks-ai", "https://api.fireworks.ai");
485
+ var HFInferenceVisualQuestionAnsweringTask = class extends HFInferenceTask {
486
+ async getResponse(response) {
487
+ if (Array.isArray(response) && response.every(
488
+ (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
489
+ )) {
490
+ return response[0];
491
+ }
492
+ throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
454
493
  }
455
- makeRoute() {
456
- return "/inference/v1/chat/completions";
494
+ };
495
+ var HFInferenceTabularRegressionTask = class extends HFInferenceTask {
496
+ async getResponse(response) {
497
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
498
+ return response;
499
+ }
500
+ throw new InferenceOutputError("Expected Array<number>");
501
+ }
502
+ };
503
+ var HFInferenceTextToAudioTask = class extends HFInferenceTask {
504
+ async getResponse(response) {
505
+ return response;
457
506
  }
458
507
  };
459
508
 
460
- // src/providers/groq.ts
461
- var GROQ_API_BASE_URL = "https://api.groq.com";
462
- var GroqTextGenerationTask = class extends BaseTextGenerationTask {
463
- constructor() {
464
- super("groq", GROQ_API_BASE_URL);
509
+ // src/utils/typedInclude.ts
510
+ function typedInclude(arr, v) {
511
+ return arr.includes(v);
512
+ }
513
+
514
+ // src/lib/getInferenceProviderMapping.ts
515
+ var inferenceProviderMappingCache = /* @__PURE__ */ new Map();
516
+ async function fetchInferenceProviderMappingForModel(modelId, accessToken, options) {
517
+ let inferenceProviderMapping;
518
+ if (inferenceProviderMappingCache.has(modelId)) {
519
+ inferenceProviderMapping = inferenceProviderMappingCache.get(modelId);
520
+ } else {
521
+ const resp = await (options?.fetch ?? fetch)(
522
+ `${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
523
+ {
524
+ headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {}
525
+ }
526
+ );
527
+ if (resp.status === 404) {
528
+ throw new Error(`Model ${modelId} does not exist`);
529
+ }
530
+ inferenceProviderMapping = await resp.json().then((json) => json.inferenceProviderMapping).catch(() => null);
531
+ if (inferenceProviderMapping) {
532
+ inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
533
+ }
465
534
  }
466
- makeRoute() {
467
- return "/openai/v1/chat/completions";
535
+ if (!inferenceProviderMapping) {
536
+ throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
468
537
  }
469
- };
470
- var GroqConversationalTask = class extends BaseConversationalTask {
471
- constructor() {
472
- super("groq", GROQ_API_BASE_URL);
538
+ return inferenceProviderMapping;
539
+ }
540
+ async function getInferenceProviderMapping(params, options) {
541
+ if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
542
+ return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
473
543
  }
474
- makeRoute() {
475
- return "/openai/v1/chat/completions";
544
+ const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
545
+ params.modelId,
546
+ params.accessToken,
547
+ options
548
+ );
549
+ const providerMapping = inferenceProviderMapping[params.provider];
550
+ if (providerMapping) {
551
+ const equivalentTasks = params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task) ? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS : [params.task];
552
+ if (!typedInclude(equivalentTasks, providerMapping.task)) {
553
+ throw new Error(
554
+ `Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
555
+ );
556
+ }
557
+ if (providerMapping.status === "staging") {
558
+ console.warn(
559
+ `Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
560
+ );
561
+ }
562
+ return { ...providerMapping, hfModelId: params.modelId };
476
563
  }
477
- };
564
+ return null;
565
+ }
566
+ async function resolveProvider(provider, modelId, endpointUrl) {
567
+ if (endpointUrl) {
568
+ if (provider) {
569
+ throw new Error("Specifying both endpointUrl and provider is not supported.");
570
+ }
571
+ return "hf-inference";
572
+ }
573
+ if (!provider) {
574
+ console.log(
575
+ "Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
576
+ );
577
+ provider = "auto";
578
+ }
579
+ if (provider === "auto") {
580
+ if (!modelId) {
581
+ throw new Error("Specifying a model is required when provider is 'auto'");
582
+ }
583
+ const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
584
+ provider = Object.keys(inferenceProviderMapping)[0];
585
+ }
586
+ if (!provider) {
587
+ throw new Error(`No Inference Provider available for model ${modelId}.`);
588
+ }
589
+ return provider;
590
+ }
478
591
 
479
- // src/providers/hf-inference.ts
480
- var EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = ["feature-extraction", "sentence-similarity"];
481
- var HFInferenceTask = class extends TaskProviderHelper {
592
+ // src/utils/delay.ts
593
+ function delay(ms) {
594
+ return new Promise((resolve) => {
595
+ setTimeout(() => resolve(), ms);
596
+ });
597
+ }
598
+
599
+ // src/utils/pick.ts
600
+ function pick(o, props) {
601
+ return Object.assign(
602
+ {},
603
+ ...props.map((prop) => {
604
+ if (o[prop] !== void 0) {
605
+ return { [prop]: o[prop] };
606
+ }
607
+ })
608
+ );
609
+ }
610
+
611
+ // src/utils/omit.ts
612
+ function omit(o, props) {
613
+ const propsArr = Array.isArray(props) ? props : [props];
614
+ const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
615
+ return pick(o, letsKeep);
616
+ }
617
+
618
+ // src/providers/black-forest-labs.ts
619
+ var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
620
+ var BlackForestLabsTextToImageTask = class extends TaskProviderHelper {
482
621
  constructor() {
483
- super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
622
+ super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
484
623
  }
485
624
  preparePayload(params) {
486
- return params.args;
625
+ return {
626
+ ...omit(params.args, ["inputs", "parameters"]),
627
+ ...params.args.parameters,
628
+ prompt: params.args.inputs
629
+ };
487
630
  }
488
- makeUrl(params) {
489
- if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
490
- return params.model;
631
+ prepareHeaders(params, binary) {
632
+ const headers = {
633
+ Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`
634
+ };
635
+ if (!binary) {
636
+ headers["Content-Type"] = "application/json";
491
637
  }
492
- return super.makeUrl(params);
638
+ return headers;
493
639
  }
494
640
  makeRoute(params) {
495
- if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
496
- return `pipeline/${params.task}/${params.model}`;
641
+ if (!params) {
642
+ throw new Error("Params are required");
497
643
  }
498
- return `models/${params.model}`;
499
- }
500
- async getResponse(response) {
501
- return response;
644
+ return `/v1/${params.model}`;
502
645
  }
503
- };
504
- var HFInferenceTextToImageTask = class extends HFInferenceTask {
505
646
  async getResponse(response, url, headers, outputType) {
506
- if (!response) {
507
- throw new InferenceOutputError("response is undefined");
508
- }
509
- if (typeof response == "object") {
510
- if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
511
- const base64Data = response.data[0].b64_json;
512
- if (outputType === "url") {
513
- return `data:image/jpeg;base64,${base64Data}`;
514
- }
515
- const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
516
- return await base64Response.blob();
647
+ const urlObj = new URL(response.polling_url);
648
+ for (let step = 0; step < 5; step++) {
649
+ await delay(1e3);
650
+ console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
651
+ urlObj.searchParams.set("attempt", step.toString(10));
652
+ const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
653
+ if (!resp.ok) {
654
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
517
655
  }
518
- if ("output" in response && Array.isArray(response.output)) {
656
+ const payload = await resp.json();
657
+ if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
519
658
  if (outputType === "url") {
520
- return response.output[0];
659
+ return payload.result.sample;
521
660
  }
522
- const urlResponse = await fetch(response.output[0]);
523
- const blob = await urlResponse.blob();
524
- return blob;
525
- }
526
- }
527
- if (response instanceof Blob) {
528
- if (outputType === "url") {
529
- const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
530
- return `data:image/jpeg;base64,${b64}`;
661
+ const image = await fetch(payload.result.sample);
662
+ return await image.blob();
531
663
  }
532
- return response;
533
664
  }
534
- throw new InferenceOutputError("Expected a Blob ");
665
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
535
666
  }
536
667
  };
537
- var HFInferenceConversationalTask = class extends HFInferenceTask {
538
- makeUrl(params) {
539
- let url;
540
- if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
541
- url = params.model.trim();
542
- } else {
543
- url = `${this.makeBaseUrl(params)}/models/${params.model}`;
544
- }
545
- url = url.replace(/\/+$/, "");
546
- if (url.endsWith("/v1")) {
547
- url += "/chat/completions";
548
- } else if (!url.endsWith("/chat/completions")) {
549
- url += "/v1/chat/completions";
550
- }
551
- return url;
668
+
669
+ // src/providers/cerebras.ts
670
+ var CerebrasConversationalTask = class extends BaseConversationalTask {
671
+ constructor() {
672
+ super("cerebras", "https://api.cerebras.ai");
552
673
  }
553
- preparePayload(params) {
554
- return {
555
- ...params.args,
556
- model: params.model
557
- };
674
+ };
675
+
676
+ // src/providers/cohere.ts
677
+ var CohereConversationalTask = class extends BaseConversationalTask {
678
+ constructor() {
679
+ super("cohere", "https://api.cohere.com");
558
680
  }
559
- async getResponse(response) {
560
- return response;
681
+ makeRoute() {
682
+ return "/compatibility/v1/chat/completions";
561
683
  }
562
684
  };
563
- var HFInferenceTextGenerationTask = class extends HFInferenceTask {
564
- async getResponse(response) {
565
- const res = toArray(response);
566
- if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) {
567
- return res?.[0];
568
- }
569
- throw new InferenceOutputError("Expected Array<{generated_text: string}>");
685
+
686
+ // src/lib/isUrl.ts
687
+ function isUrl(modelOrUrl) {
688
+ return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
689
+ }
690
+
691
+ // src/providers/fal-ai.ts
692
+ var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
693
+ var FalAITask = class extends TaskProviderHelper {
694
+ constructor(url) {
695
+ super("fal-ai", url || "https://fal.run");
570
696
  }
571
- };
572
- var HFInferenceAudioClassificationTask = class extends HFInferenceTask {
573
- async getResponse(response) {
574
- if (Array.isArray(response) && response.every(
575
- (x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
576
- )) {
577
- return response;
578
- }
579
- throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
697
+ preparePayload(params) {
698
+ return params.args;
580
699
  }
581
- };
582
- var HFInferenceAutomaticSpeechRecognitionTask = class extends HFInferenceTask {
583
- async getResponse(response) {
584
- return response;
700
+ makeRoute(params) {
701
+ return `/${params.model}`;
585
702
  }
586
- };
587
- var HFInferenceAudioToAudioTask = class extends HFInferenceTask {
588
- async getResponse(response) {
589
- if (!Array.isArray(response)) {
590
- throw new InferenceOutputError("Expected Array");
591
- }
592
- if (!response.every((elem) => {
593
- return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
594
- })) {
595
- throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
703
+ prepareHeaders(params, binary) {
704
+ const headers = {
705
+ Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}`
706
+ };
707
+ if (!binary) {
708
+ headers["Content-Type"] = "application/json";
596
709
  }
597
- return response;
710
+ return headers;
598
711
  }
599
712
  };
600
- var HFInferenceDocumentQuestionAnsweringTask = class extends HFInferenceTask {
601
- async getResponse(response) {
602
- if (Array.isArray(response) && response.every(
603
- (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
604
- )) {
605
- return response[0];
713
+ function buildLoraPath(modelId, adapterWeightsPath) {
714
+ return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
715
+ }
716
+ var FalAITextToImageTask = class extends FalAITask {
717
+ preparePayload(params) {
718
+ const payload = {
719
+ ...omit(params.args, ["inputs", "parameters"]),
720
+ ...params.args.parameters,
721
+ sync_mode: true,
722
+ prompt: params.args.inputs
723
+ };
724
+ if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
725
+ payload.loras = [
726
+ {
727
+ path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
728
+ scale: 1
729
+ }
730
+ ];
731
+ if (params.mapping.providerId === "fal-ai/lora") {
732
+ payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0";
733
+ }
606
734
  }
607
- throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
735
+ return payload;
608
736
  }
609
- };
610
- var HFInferenceFeatureExtractionTask = class extends HFInferenceTask {
611
- async getResponse(response) {
612
- const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
613
- if (curDepth > maxDepth)
614
- return false;
615
- if (arr.every((x) => Array.isArray(x))) {
616
- return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
617
- } else {
618
- return arr.every((x) => typeof x === "number");
737
+ async getResponse(response, outputType) {
738
+ if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images.length > 0 && "url" in response.images[0] && typeof response.images[0].url === "string") {
739
+ if (outputType === "url") {
740
+ return response.images[0].url;
619
741
  }
620
- };
621
- if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) {
622
- return response;
742
+ const urlResponse = await fetch(response.images[0].url);
743
+ return await urlResponse.blob();
623
744
  }
624
- throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
745
+ throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
625
746
  }
626
747
  };
627
- var HFInferenceImageClassificationTask = class extends HFInferenceTask {
628
- async getResponse(response) {
629
- if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
630
- return response;
631
- }
632
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
748
+ var FalAITextToVideoTask = class extends FalAITask {
749
+ constructor() {
750
+ super("https://queue.fal.run");
633
751
  }
634
- };
635
- var HFInferenceImageSegmentationTask = class extends HFInferenceTask {
636
- async getResponse(response) {
637
- if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number")) {
638
- return response;
752
+ makeRoute(params) {
753
+ if (params.authMethod !== "provider-key") {
754
+ return `/${params.model}?_subdomain=queue`;
639
755
  }
640
- throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
756
+ return `/${params.model}`;
641
757
  }
642
- };
643
- var HFInferenceImageToTextTask = class extends HFInferenceTask {
644
- async getResponse(response) {
645
- if (typeof response?.generated_text !== "string") {
646
- throw new InferenceOutputError("Expected {generated_text: string}");
647
- }
648
- return response;
758
+ preparePayload(params) {
759
+ return {
760
+ ...omit(params.args, ["inputs", "parameters"]),
761
+ ...params.args.parameters,
762
+ prompt: params.args.inputs
763
+ };
649
764
  }
650
- };
651
- var HFInferenceImageToImageTask = class extends HFInferenceTask {
652
- async getResponse(response) {
653
- if (response instanceof Blob) {
654
- return response;
765
+ async getResponse(response, url, headers) {
766
+ if (!url || !headers) {
767
+ throw new InferenceOutputError("URL and headers are required for text-to-video task");
655
768
  }
656
- throw new InferenceOutputError("Expected Blob");
657
- }
658
- };
659
- var HFInferenceObjectDetectionTask = class extends HFInferenceTask {
660
- async getResponse(response) {
661
- if (Array.isArray(response) && response.every(
662
- (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
663
- )) {
664
- return response;
769
+ const requestId = response.request_id;
770
+ if (!requestId) {
771
+ throw new InferenceOutputError("No request ID found in the response");
665
772
  }
666
- throw new InferenceOutputError(
667
- "Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
668
- );
669
- }
670
- };
671
- var HFInferenceZeroShotImageClassificationTask = class extends HFInferenceTask {
672
- async getResponse(response) {
673
- if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
674
- return response;
773
+ let status = response.status;
774
+ const parsedUrl = new URL(url);
775
+ const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
776
+ const modelId = new URL(response.response_url).pathname;
777
+ const queryParams = parsedUrl.search;
778
+ const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
779
+ const resultUrl = `${baseUrl}${modelId}${queryParams}`;
780
+ while (status !== "COMPLETED") {
781
+ await delay(500);
782
+ const statusResponse = await fetch(statusUrl, { headers });
783
+ if (!statusResponse.ok) {
784
+ throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
785
+ }
786
+ try {
787
+ status = (await statusResponse.json()).status;
788
+ } catch (error) {
789
+ throw new InferenceOutputError("Failed to parse status response from fal-ai API");
790
+ }
791
+ }
792
+ const resultResponse = await fetch(resultUrl, { headers });
793
+ let result;
794
+ try {
795
+ result = await resultResponse.json();
796
+ } catch (error) {
797
+ throw new InferenceOutputError("Failed to parse result response from fal-ai API");
798
+ }
799
+ if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
800
+ const urlResponse = await fetch(result.video.url);
801
+ return await urlResponse.blob();
802
+ } else {
803
+ throw new InferenceOutputError(
804
+ "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
805
+ );
675
806
  }
676
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
677
807
  }
678
808
  };
679
- var HFInferenceTextClassificationTask = class extends HFInferenceTask {
680
- async getResponse(response) {
681
- const output = response?.[0];
682
- if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
683
- return output;
684
- }
685
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
809
+ var FalAIAutomaticSpeechRecognitionTask = class extends FalAITask {
810
+ prepareHeaders(params, binary) {
811
+ const headers = super.prepareHeaders(params, binary);
812
+ headers["Content-Type"] = "application/json";
813
+ return headers;
686
814
  }
687
- };
688
- var HFInferenceQuestionAnsweringTask = class extends HFInferenceTask {
689
815
  async getResponse(response) {
690
- if (Array.isArray(response) ? response.every(
691
- (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
692
- ) : typeof response === "object" && !!response && typeof response.answer === "string" && typeof response.end === "number" && typeof response.score === "number" && typeof response.start === "number") {
693
- return Array.isArray(response) ? response[0] : response;
816
+ const res = response;
817
+ if (typeof res?.text !== "string") {
818
+ throw new InferenceOutputError(
819
+ `Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
820
+ );
694
821
  }
695
- throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
822
+ return { text: res.text };
696
823
  }
697
824
  };
698
- var HFInferenceFillMaskTask = class extends HFInferenceTask {
699
- async getResponse(response) {
700
- if (Array.isArray(response) && response.every(
701
- (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
702
- )) {
703
- return response;
704
- }
705
- throw new InferenceOutputError(
706
- "Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
707
- );
825
+ var FalAITextToSpeechTask = class extends FalAITask {
826
+ preparePayload(params) {
827
+ return {
828
+ ...omit(params.args, ["inputs", "parameters"]),
829
+ ...params.args.parameters,
830
+ text: params.args.inputs
831
+ };
708
832
  }
709
- };
710
- var HFInferenceZeroShotClassificationTask = class extends HFInferenceTask {
711
833
  async getResponse(response) {
712
- if (Array.isArray(response) && response.every(
713
- (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
714
- )) {
715
- return response;
834
+ const res = response;
835
+ if (typeof res?.audio?.url !== "string") {
836
+ throw new InferenceOutputError(
837
+ `Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
838
+ );
716
839
  }
717
- throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
718
- }
719
- };
720
- var HFInferenceSentenceSimilarityTask = class extends HFInferenceTask {
721
- async getResponse(response) {
722
- if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
723
- return response;
840
+ try {
841
+ const urlResponse = await fetch(res.audio.url);
842
+ if (!urlResponse.ok) {
843
+ throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`);
844
+ }
845
+ return await urlResponse.blob();
846
+ } catch (error) {
847
+ throw new InferenceOutputError(
848
+ `Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${error instanceof Error ? error.message : String(error)}`
849
+ );
724
850
  }
725
- throw new InferenceOutputError("Expected Array<number>");
726
851
  }
727
852
  };
-var HFInferenceTableQuestionAnsweringTask = class extends HFInferenceTask {
-  static validate(elem) {
-    return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
-      (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
-    );
-  }
-  async getResponse(response) {
-    if (Array.isArray(response) && Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response)) {
-      return Array.isArray(response) ? response[0] : response;
-    }
-    throw new InferenceOutputError(
-      "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
-    );
+
+// src/providers/featherless-ai.ts
+var FEATHERLESS_API_BASE_URL = "https://api.featherless.ai";
+var FeatherlessAIConversationalTask = class extends BaseConversationalTask {
+  constructor() {
+    super("featherless-ai", FEATHERLESS_API_BASE_URL);
   }
 };
-var HFInferenceTokenClassificationTask = class extends HFInferenceTask {
-  async getResponse(response) {
-    if (Array.isArray(response) && response.every(
-      (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
-    )) {
-      return response;
-    }
-    throw new InferenceOutputError(
-      "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
-    );
+var FeatherlessAITextGenerationTask = class extends BaseTextGenerationTask {
+  constructor() {
+    super("featherless-ai", FEATHERLESS_API_BASE_URL);
   }
-};
-var HFInferenceTranslationTask = class extends HFInferenceTask {
-  async getResponse(response) {
-    if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) {
-      return response?.length === 1 ? response?.[0] : response;
-    }
-    throw new InferenceOutputError("Expected Array<{translation_text: string}>");
+  preparePayload(params) {
+    return {
+      ...params.args,
+      ...params.args.parameters,
+      model: params.model,
+      prompt: params.args.inputs
+    };
   }
-};
-var HFInferenceSummarizationTask = class extends HFInferenceTask {
   async getResponse(response) {
-    if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) {
-      return response?.[0];
+    if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
+      const completion = response.choices[0];
+      return {
+        generated_text: completion.text
+      };
     }
-    throw new InferenceOutputError("Expected Array<{summary_text: string}>");
+    throw new InferenceOutputError("Expected Featherless AI text generation response format");
   }
 };
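Featherless AI is wired up as an OpenAI-compatible provider: `preparePayload` maps the HF-style `inputs`/`parameters` onto a `prompt` completion request, and `getResponse` narrows `choices[0].text` back into the HF `generated_text` shape. A hedged usage sketch (the model ID is a placeholder for one mapped to featherless-ai):

```ts
import { textGeneration } from "@huggingface/inference";

// Usage sketch; "org/some-text-model" is a placeholder.
const { generated_text } = await textGeneration({
  accessToken: "hf_...",
  provider: "featherless-ai",
  model: "org/some-text-model",
  inputs: "Once upon a time,",
  parameters: { max_new_tokens: 50 },
});
```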
-var HFInferenceTextToSpeechTask = class extends HFInferenceTask {
-  async getResponse(response) {
-    return response;
+
+// src/providers/fireworks-ai.ts
+var FireworksConversationalTask = class extends BaseConversationalTask {
+  constructor() {
+    super("fireworks-ai", "https://api.fireworks.ai");
   }
-};
-var HFInferenceTabularClassificationTask = class extends HFInferenceTask {
-  async getResponse(response) {
-    if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
-      return response;
-    }
-    throw new InferenceOutputError("Expected Array<number>");
+  makeRoute() {
+    return "/inference/v1/chat/completions";
   }
 };
-var HFInferenceVisualQuestionAnsweringTask = class extends HFInferenceTask {
-  async getResponse(response) {
-    if (Array.isArray(response) && response.every(
-      (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
-    )) {
-      return response[0];
-    }
-    throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
+
+// src/providers/groq.ts
+var GROQ_API_BASE_URL = "https://api.groq.com";
+var GroqTextGenerationTask = class extends BaseTextGenerationTask {
+  constructor() {
+    super("groq", GROQ_API_BASE_URL);
   }
-};
-var HFInferenceTabularRegressionTask = class extends HFInferenceTask {
-  async getResponse(response) {
-    if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
-      return response;
-    }
-    throw new InferenceOutputError("Expected Array<number>");
+  makeRoute() {
+    return "/openai/v1/chat/completions";
   }
 };
-var HFInferenceTextToAudioTask = class extends HFInferenceTask {
-  async getResponse(response) {
-    return response;
+var GroqConversationalTask = class extends BaseConversationalTask {
+  constructor() {
+    super("groq", GROQ_API_BASE_URL);
+  }
+  makeRoute() {
+    return "/openai/v1/chat/completions";
   }
 };
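Fireworks AI and Groq both reuse `BaseConversationalTask` and only override `makeRoute()` to point at their OpenAI-compatible chat-completions paths. A usage sketch (the model ID is a placeholder for one mapped to groq):

```ts
import { chatCompletion } from "@huggingface/inference";

// Usage sketch; "org/some-chat-model" is a placeholder.
const res = await chatCompletion({
  accessToken: "hf_...",
  provider: "groq",
  model: "org/some-chat-model",
  messages: [{ role: "user", content: "Hello!" }],
});
console.log(res.choices[0].message);
```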
 
@@ -1295,82 +1400,13 @@ function getProviderHelper(provider, task) {
 
 // package.json
 var name = "@huggingface/inference";
-var version = "3.10.0";
-
-// src/providers/consts.ts
-var HARDCODED_MODEL_INFERENCE_MAPPING = {
-  /**
-   * "HF model ID" => "Model ID on Inference Provider's side"
-   *
-   * Example:
-   * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
-   */
-  "black-forest-labs": {},
-  cerebras: {},
-  cohere: {},
-  "fal-ai": {},
-  "featherless-ai": {},
-  "fireworks-ai": {},
-  groq: {},
-  "hf-inference": {},
-  hyperbolic: {},
-  nebius: {},
-  novita: {},
-  nscale: {},
-  openai: {},
-  ovhcloud: {},
-  replicate: {},
-  sambanova: {},
-  together: {}
-};
-
-// src/lib/getInferenceProviderMapping.ts
-var inferenceProviderMappingCache = /* @__PURE__ */ new Map();
-async function getInferenceProviderMapping(params, options) {
-  if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
-    return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
-  }
-  let inferenceProviderMapping;
-  if (inferenceProviderMappingCache.has(params.modelId)) {
-    inferenceProviderMapping = inferenceProviderMappingCache.get(params.modelId);
-  } else {
-    const resp = await (options?.fetch ?? fetch)(
-      `${HF_HUB_URL}/api/models/${params.modelId}?expand[]=inferenceProviderMapping`,
-      {
-        headers: params.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${params.accessToken}` } : {}
-      }
-    );
-    if (resp.status === 404) {
-      throw new Error(`Model ${params.modelId} does not exist`);
-    }
-    inferenceProviderMapping = await resp.json().then((json) => json.inferenceProviderMapping).catch(() => null);
-  }
-  if (!inferenceProviderMapping) {
-    throw new Error(`We have not been able to find inference provider information for model ${params.modelId}.`);
-  }
-  const providerMapping = inferenceProviderMapping[params.provider];
-  if (providerMapping) {
-    const equivalentTasks = params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task) ? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS : [params.task];
-    if (!typedInclude(equivalentTasks, providerMapping.task)) {
-      throw new Error(
-        `Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
-      );
-    }
-    if (providerMapping.status === "staging") {
-      console.warn(
-        `Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
-      );
-    }
-    return { ...providerMapping, hfModelId: params.modelId };
-  }
-  return null;
-}
+var version = "3.12.0";
 
 // src/lib/makeRequestOptions.ts
 var tasks = null;
 async function makeRequestOptions(args, providerHelper, options) {
-  const { provider: maybeProvider, model: maybeModel } = args;
-  const provider = maybeProvider ?? "hf-inference";
+  const { model: maybeModel } = args;
+  const provider = providerHelper.provider;
   const { task } = options ?? {};
   if (args.endpointUrl && provider !== "hf-inference") {
     throw new Error(`Cannot use endpointUrl with a third-party provider.`);
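Worth noting: the bundled copy of the mapping lookup deleted above read its cache but never populated it (`inferenceProviderMappingCache.get` with no matching `.set`), and the lookup is no longer inlined at this spot in the bundle; per the `// src/lib/getInferenceProviderMapping.ts` banner it was reworked in that module. For reference, a condensed TypeScript rendering of the deleted logic, with the cache actually written; names and signatures are illustrative, not the package's real API:

```ts
// Illustrative condensation of the deleted lookup (not the package's exact code).
const cache = new Map<string, Record<string, unknown> | null>();

async function fetchInferenceProviderMapping(modelId: string, accessToken?: string) {
  if (!cache.has(modelId)) {
    const resp = await fetch(
      `https://huggingface.co/api/models/${modelId}?expand[]=inferenceProviderMapping`,
      { headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {} }
    );
    if (resp.status === 404) throw new Error(`Model ${modelId} does not exist`);
    const json = await resp.json().catch(() => null);
    cache.set(modelId, json?.inferenceProviderMapping ?? null); // the deleted copy skipped this step
  }
  return cache.get(modelId) ?? null;
}
```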
@@ -1425,7 +1461,7 @@ async function makeRequestOptions(args, providerHelper, options) {
 }
 function makeRequestOptionsFromResolvedModel(resolvedModel, providerHelper, args, mapping, options) {
   const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
-  const provider = maybeProvider ?? "hf-inference";
+  const provider = providerHelper.provider;
   const { includeCredentials, task, signal, billTo } = options ?? {};
   const authMethod = (() => {
     if (providerHelper.clientSideRoutingOnly) {
@@ -1716,7 +1752,8 @@ async function request(args, options) {
   console.warn(
     "The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
   );
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, options?.task);
   const result = await innerRequest(args, providerHelper, options);
   return result.data;
 }
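Every call site below now resolves the provider asynchronously via `resolveProvider` before picking a helper, instead of defaulting to `hf-inference`. The function's body is not part of any hunk in this diff, so the sketch below is hypothetical, inferred only from the call sites and from the `PROVIDERS_OR_POLICIES` constant added further down; `firstMappedProvider` is a stand-in name:

```ts
// Hypothetical sketch of resolveProvider; its real body is not shown in this diff.
declare function firstMappedProvider(modelId: string): Promise<string | undefined>; // stand-in

async function resolveProvider(
  provider: string | undefined,
  modelId: string | undefined,
  endpointUrl?: string
): Promise<string> {
  if (endpointUrl) return "hf-inference";           // custom endpoints route like hf-inference
  if (provider && provider !== "auto") return provider; // an explicit provider wins
  if (!modelId) throw new Error("Cannot select a provider automatically without a model ID");
  const first = await firstMappedProvider(modelId); // assumed: consult the model's Hub mapping
  if (!first) throw new Error(`No inference provider available for model ${modelId}`);
  return first;
}
```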
@@ -1726,7 +1763,8 @@ async function* streamingRequest(args, options) {
   console.warn(
     "The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
   );
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, options?.task);
   yield* innerStreamingRequest(args, providerHelper, options);
 }
 
@@ -1740,7 +1778,8 @@ function preparePayload(args) {
 
 // src/tasks/audio/audioClassification.ts
 async function audioClassification(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "audio-classification");
   const payload = preparePayload(args);
   const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
@@ -1751,7 +1790,9 @@ async function audioClassification(args, options) {
 
 // src/tasks/audio/audioToAudio.ts
 async function audioToAudio(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
+  const model = "inputs" in args ? args.model : void 0;
+  const provider = await resolveProvider(args.provider, model);
+  const providerHelper = getProviderHelper(provider, "audio-to-audio");
   const payload = preparePayload(args);
   const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
@@ -1775,7 +1816,8 @@ function base64FromBytes(arr) {
 
 // src/tasks/audio/automaticSpeechRecognition.ts
 async function automaticSpeechRecognition(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
   const payload = await buildPayload(args);
   const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
@@ -1815,7 +1857,7 @@ async function buildPayload(args) {
 
 // src/tasks/audio/textToSpeech.ts
 async function textToSpeech(args, options) {
-  const provider = args.provider ?? "hf-inference";
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
   const providerHelper = getProviderHelper(provider, "text-to-speech");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
@@ -1831,7 +1873,8 @@ function preparePayload2(args) {
 
 // src/tasks/cv/imageClassification.ts
 async function imageClassification(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "image-classification");
   const payload = preparePayload2(args);
   const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
@@ -1842,7 +1885,8 @@ async function imageClassification(args, options) {
 
 // src/tasks/cv/imageSegmentation.ts
 async function imageSegmentation(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "image-segmentation");
   const payload = preparePayload2(args);
   const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
@@ -1853,7 +1897,8 @@ async function imageSegmentation(args, options) {
 
 // src/tasks/cv/imageToImage.ts
 async function imageToImage(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "image-to-image");
   let reqArgs;
   if (!args.parameters) {
     reqArgs = {
@@ -1878,7 +1923,8 @@ async function imageToImage(args, options) {
 
 // src/tasks/cv/imageToText.ts
 async function imageToText(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "image-to-text");
   const payload = preparePayload2(args);
   const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
@@ -1889,7 +1935,8 @@ async function imageToText(args, options) {
 
 // src/tasks/cv/objectDetection.ts
 async function objectDetection(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "object-detection");
   const payload = preparePayload2(args);
   const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
@@ -1900,7 +1947,7 @@ async function objectDetection(args, options) {
 
 // src/tasks/cv/textToImage.ts
 async function textToImage(args, options) {
-  const provider = args.provider ?? "hf-inference";
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
   const providerHelper = getProviderHelper(provider, "text-to-image");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
@@ -1912,7 +1959,7 @@ async function textToImage(args, options) {
 
 // src/tasks/cv/textToVideo.ts
 async function textToVideo(args, options) {
-  const provider = args.provider ?? "hf-inference";
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
   const providerHelper = getProviderHelper(provider, "text-to-video");
   const { data: response } = await innerRequest(
     args,
@@ -1949,7 +1996,8 @@ async function preparePayload3(args) {
   }
 }
 async function zeroShotImageClassification(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "zero-shot-image-classification");
   const payload = await preparePayload3(args);
   const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
@@ -1960,7 +2008,8 @@ async function zeroShotImageClassification(args, options) {
 
 // src/tasks/nlp/chatCompletion.ts
 async function chatCompletion(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "conversational");
   const { data: response } = await innerRequest(args, providerHelper, {
     ...options,
     task: "conversational"
@@ -1970,7 +2019,8 @@ async function chatCompletion(args, options) {
 
 // src/tasks/nlp/chatCompletionStream.ts
 async function* chatCompletionStream(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "conversational");
   yield* innerStreamingRequest(args, providerHelper, {
     ...options,
     task: "conversational"
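From here on, every task wrapper repeats the same two-step pattern: resolve the provider, then fetch the matching helper. In practice the `provider` argument can now be a policy rather than a concrete provider. A hedged usage sketch (placeholder model ID, and assuming `"auto"` is accepted per the `PROVIDERS_OR_POLICIES` constant added near the end of this diff):

```ts
import { chatCompletion } from "@huggingface/inference";

// Usage sketch; "org/some-chat-model" is a placeholder for any mapped model.
const out = await chatCompletion({
  accessToken: "hf_...",
  provider: "auto", // defer the choice to the model's inference-provider mapping
  model: "org/some-chat-model",
  messages: [{ role: "user", content: "Summarize this diff in one sentence." }],
});
console.log(out.choices[0].message.content);
```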
@@ -1979,7 +2029,8 @@ async function* chatCompletionStream(args, options) {
 
 // src/tasks/nlp/featureExtraction.ts
 async function featureExtraction(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "feature-extraction");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "feature-extraction");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
     task: "feature-extraction"
@@ -1989,7 +2040,8 @@ async function featureExtraction(args, options) {
 
 // src/tasks/nlp/fillMask.ts
 async function fillMask(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "fill-mask");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "fill-mask");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
     task: "fill-mask"
@@ -1999,7 +2051,8 @@ async function fillMask(args, options) {
 
 // src/tasks/nlp/questionAnswering.ts
 async function questionAnswering(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "question-answering");
   const { data: res } = await innerRequest(
     args,
     providerHelper,
@@ -2013,7 +2066,8 @@ async function questionAnswering(args, options) {
 
 // src/tasks/nlp/sentenceSimilarity.ts
 async function sentenceSimilarity(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "sentence-similarity");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "sentence-similarity");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
     task: "sentence-similarity"
@@ -2023,7 +2077,8 @@ async function sentenceSimilarity(args, options) {
 
 // src/tasks/nlp/summarization.ts
 async function summarization(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "summarization");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
     task: "summarization"
@@ -2033,7 +2088,8 @@ async function summarization(args, options) {
 
 // src/tasks/nlp/tableQuestionAnswering.ts
 async function tableQuestionAnswering(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "table-question-answering");
   const { data: res } = await innerRequest(
     args,
     providerHelper,
@@ -2047,7 +2103,8 @@ async function tableQuestionAnswering(args, options) {
 
 // src/tasks/nlp/textClassification.ts
 async function textClassification(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "text-classification");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
     task: "text-classification"
@@ -2057,7 +2114,8 @@ async function textClassification(args, options) {
 
 // src/tasks/nlp/textGeneration.ts
 async function textGeneration(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "text-generation");
   const { data: response } = await innerRequest(args, providerHelper, {
     ...options,
     task: "text-generation"
@@ -2067,7 +2125,8 @@ async function textGeneration(args, options) {
 
 // src/tasks/nlp/textGenerationStream.ts
 async function* textGenerationStream(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "text-generation");
   yield* innerStreamingRequest(args, providerHelper, {
     ...options,
     task: "text-generation"
@@ -2076,7 +2135,8 @@ async function* textGenerationStream(args, options) {
 
 // src/tasks/nlp/tokenClassification.ts
 async function tokenClassification(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "token-classification");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "token-classification");
   const { data: res } = await innerRequest(
     args,
     providerHelper,
@@ -2090,7 +2150,8 @@ async function tokenClassification(args, options) {
 
 // src/tasks/nlp/translation.ts
 async function translation(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "translation");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
     task: "translation"
@@ -2100,7 +2161,8 @@ async function translation(args, options) {
 
 // src/tasks/nlp/zeroShotClassification.ts
 async function zeroShotClassification(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "zero-shot-classification");
   const { data: res } = await innerRequest(
     args,
     providerHelper,
@@ -2114,7 +2176,8 @@ async function zeroShotClassification(args, options) {
 
 // src/tasks/multimodal/documentQuestionAnswering.ts
 async function documentQuestionAnswering(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "document-question-answering");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "document-question-answering");
   const reqArgs = {
     ...args,
     inputs: {
@@ -2136,7 +2199,8 @@ async function documentQuestionAnswering(args, options) {
 
 // src/tasks/multimodal/visualQuestionAnswering.ts
 async function visualQuestionAnswering(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "visual-question-answering");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "visual-question-answering");
   const reqArgs = {
     ...args,
     inputs: {
@@ -2154,7 +2218,8 @@ async function visualQuestionAnswering(args, options) {
 
 // src/tasks/tabular/tabularClassification.ts
 async function tabularClassification(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "tabular-classification");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
     task: "tabular-classification"
@@ -2164,7 +2229,8 @@ async function tabularClassification(args, options) {
 
 // src/tasks/tabular/tabularRegression.ts
 async function tabularRegression(args, options) {
-  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression");
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+  const providerHelper = getProviderHelper(provider, "tabular-regression");
   const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
     task: "tabular-regression"
@@ -2172,6 +2238,11 @@ async function tabularRegression(args, options) {
   return providerHelper.getResponse(res);
 }
 
+// src/utils/typedEntries.ts
+function typedEntries(obj) {
+  return Object.entries(obj);
+}
+
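The bundler erases the type-level part of `typedEntries`; the TypeScript source presumably looks roughly like the following, giving `Object.entries` a key-aware return type for the `InferenceClient` constructor below:

```ts
// Presumable TypeScript source (the bundle above erases the types).
function typedEntries<T extends object>(obj: T): [keyof T, T[keyof T]][] {
  return Object.entries(obj) as [keyof T, T[keyof T]][];
}
```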
 // src/InferenceClient.ts
 var InferenceClient = class {
   accessToken;
@@ -2179,40 +2250,36 @@ var InferenceClient = class {
   constructor(accessToken = "", defaultOptions = {}) {
     this.accessToken = accessToken;
     this.defaultOptions = defaultOptions;
-    for (const [name2, fn] of Object.entries(tasks_exports)) {
+    for (const [name2, fn] of typedEntries(tasks_exports)) {
       Object.defineProperty(this, name2, {
         enumerable: false,
         value: (params, options) => (
           // eslint-disable-next-line @typescript-eslint/no-explicit-any
-          fn({ ...params, accessToken }, { ...defaultOptions, ...options })
+          fn(
+            /// ^ The cast of fn to any is necessary, otherwise TS can't compile because the generated union type is too complex
+            { endpointUrl: defaultOptions.endpointUrl, accessToken, ...params },
+            {
+              ...omit(defaultOptions, ["endpointUrl"]),
+              ...options
+            }
+          )
         )
       });
     }
   }
   /**
-   * Returns copy of InferenceClient tied to a specified endpoint.
+   * Returns a new instance of InferenceClient tied to a specified endpoint.
+   *
+   * For backward compatibility mostly.
    */
   endpoint(endpointUrl) {
-    return new InferenceClientEndpoint(endpointUrl, this.accessToken, this.defaultOptions);
-  }
-};
-var InferenceClientEndpoint = class {
-  constructor(endpointUrl, accessToken = "", defaultOptions = {}) {
-    accessToken;
-    defaultOptions;
-    for (const [name2, fn] of Object.entries(tasks_exports)) {
-      Object.defineProperty(this, name2, {
-        enumerable: false,
-        value: (params, options) => (
-          // eslint-disable-next-line @typescript-eslint/no-explicit-any
-          fn({ ...params, accessToken, endpointUrl }, { ...defaultOptions, ...options })
-        )
-      });
-    }
+    return new InferenceClient(this.accessToken, { ...this.defaultOptions, endpointUrl });
   }
 };
 var HfInference = class extends InferenceClient {
 };
+var InferenceClientEndpoint = class extends InferenceClient {
+};
 
 // src/types.ts
 var INFERENCE_PROVIDERS = [
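`endpoint()` is now plain sugar over `defaultOptions.endpointUrl`, and `InferenceClientEndpoint` survives only as an alias subclass for backward compatibility. Per the constructor above, the two clients in this sketch should behave identically (the URL is a placeholder):

```ts
import { InferenceClient } from "@huggingface/inference";

// Equivalent ways to pin a client to a dedicated endpoint.
const a = new InferenceClient("hf_...").endpoint("https://my-endpoint.example/v1");
const b = new InferenceClient("hf_...", { endpointUrl: "https://my-endpoint.example/v1" });
```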
@@ -2234,6 +2301,7 @@ var INFERENCE_PROVIDERS = [
   "sambanova",
   "together"
 ];
+var PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"];
 
 // src/snippets/index.ts
 var snippets_exports = {};
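The new constant widens the accepted values from concrete providers to providers-or-policies. A type-level sketch of the intent (the provider list is abridged and the published typings may differ):

```ts
// Sketch only; list abridged, actual .d.ts may differ.
const INFERENCE_PROVIDERS = ["fal-ai", "featherless-ai", "groq", "hf-inference"] as const;
const PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"] as const;
type InferenceProviderOrPolicy = (typeof PROVIDERS_OR_POLICIES)[number];
// => "fal-ai" | "featherless-ai" | "groq" | "hf-inference" | "auto"
```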
@@ -2565,7 +2633,7 @@ var prepareConversationalInput = (model, opts) => {
   return {
     messages: opts?.messages ?? getModelInputSnippet(model),
     ...opts?.temperature ? { temperature: opts?.temperature } : void 0,
-    max_tokens: opts?.max_tokens ?? 512,
+    ...opts?.max_tokens ? { max_tokens: opts?.max_tokens } : void 0,
     ...opts?.top_p ? { top_p: opts?.top_p } : void 0
   };
 };
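Snippet generation no longer injects a default `max_tokens: 512`; the field is emitted only when the caller asks for it, using the same conditional-spread pattern already applied to `temperature` and `top_p`. A self-contained illustration of the pattern:

```ts
// A field is included only when the caller sets it; spreading undefined is a no-op.
type Opts = { temperature?: number; max_tokens?: number; top_p?: number };
const body = (opts?: Opts) => ({
  ...(opts?.temperature ? { temperature: opts.temperature } : undefined),
  ...(opts?.max_tokens ? { max_tokens: opts.max_tokens } : undefined),
  ...(opts?.top_p ? { top_p: opts.top_p } : undefined),
});
console.log(body({ max_tokens: 256 })); // { max_tokens: 256 }, no implicit 512 anymore
```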
@@ -2658,6 +2726,7 @@ export {
   InferenceClient,
   InferenceClientEndpoint,
   InferenceOutputError,
+  PROVIDERS_OR_POLICIES,
   audioClassification,
   audioToAudio,
   automaticSpeechRecognition,