@huggingface/inference 3.10.0 → 3.12.0

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (86)
  1. package/dist/index.cjs +713 -643
  2. package/dist/index.js +712 -643
  3. package/dist/src/InferenceClient.d.ts +16 -17
  4. package/dist/src/InferenceClient.d.ts.map +1 -1
  5. package/dist/src/lib/getInferenceProviderMapping.d.ts +5 -1
  6. package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -1
  7. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  8. package/dist/src/providers/providerHelper.d.ts +1 -1
  9. package/dist/src/providers/providerHelper.d.ts.map +1 -1
  10. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
  11. package/dist/src/tasks/audio/audioToAudio.d.ts.map +1 -1
  12. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  13. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  14. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  15. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  16. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
  17. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
  18. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
  19. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
  20. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
  21. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  22. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
  23. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
  24. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  25. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
  26. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  27. package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
  28. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  29. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
  30. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  31. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  32. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
  33. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  34. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
  35. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  36. package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
  37. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  38. package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
  39. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  40. package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
  41. package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
  42. package/dist/src/types.d.ts +6 -4
  43. package/dist/src/types.d.ts.map +1 -1
  44. package/dist/src/utils/typedEntries.d.ts +4 -0
  45. package/dist/src/utils/typedEntries.d.ts.map +1 -0
  46. package/package.json +3 -3
  47. package/src/InferenceClient.ts +32 -43
  48. package/src/lib/getInferenceProviderMapping.ts +68 -19
  49. package/src/lib/makeRequestOptions.ts +4 -3
  50. package/src/providers/hf-inference.ts +1 -1
  51. package/src/providers/providerHelper.ts +1 -1
  52. package/src/snippets/getInferenceSnippets.ts +1 -1
  53. package/src/tasks/audio/audioClassification.ts +3 -1
  54. package/src/tasks/audio/audioToAudio.ts +4 -1
  55. package/src/tasks/audio/automaticSpeechRecognition.ts +3 -1
  56. package/src/tasks/audio/textToSpeech.ts +2 -1
  57. package/src/tasks/custom/request.ts +3 -1
  58. package/src/tasks/custom/streamingRequest.ts +3 -1
  59. package/src/tasks/cv/imageClassification.ts +3 -1
  60. package/src/tasks/cv/imageSegmentation.ts +3 -1
  61. package/src/tasks/cv/imageToImage.ts +3 -1
  62. package/src/tasks/cv/imageToText.ts +3 -1
  63. package/src/tasks/cv/objectDetection.ts +3 -1
  64. package/src/tasks/cv/textToImage.ts +2 -1
  65. package/src/tasks/cv/textToVideo.ts +2 -1
  66. package/src/tasks/cv/zeroShotImageClassification.ts +3 -1
  67. package/src/tasks/multimodal/documentQuestionAnswering.ts +3 -1
  68. package/src/tasks/multimodal/visualQuestionAnswering.ts +3 -1
  69. package/src/tasks/nlp/chatCompletion.ts +3 -1
  70. package/src/tasks/nlp/chatCompletionStream.ts +3 -1
  71. package/src/tasks/nlp/featureExtraction.ts +3 -1
  72. package/src/tasks/nlp/fillMask.ts +3 -1
  73. package/src/tasks/nlp/questionAnswering.ts +4 -1
  74. package/src/tasks/nlp/sentenceSimilarity.ts +3 -1
  75. package/src/tasks/nlp/summarization.ts +3 -1
  76. package/src/tasks/nlp/tableQuestionAnswering.ts +3 -1
  77. package/src/tasks/nlp/textClassification.ts +3 -1
  78. package/src/tasks/nlp/textGeneration.ts +3 -1
  79. package/src/tasks/nlp/textGenerationStream.ts +3 -1
  80. package/src/tasks/nlp/tokenClassification.ts +3 -1
  81. package/src/tasks/nlp/translation.ts +3 -1
  82. package/src/tasks/nlp/zeroShotClassification.ts +3 -1
  83. package/src/tasks/tabular/tabularClassification.ts +3 -1
  84. package/src/tasks/tabular/tabularRegression.ts +3 -1
  85. package/src/types.ts +8 -4
  86. package/src/utils/typedEntries.ts +5 -0
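The headline change in this release range is the provider-resolution rework visible throughout the diff below: the bundle newly exports PROVIDERS_OR_POLICIES, ships a (for now empty) HARDCODED_MODEL_INFERENCE_MAPPING table, resolves provider: "auto" by fetching a model's inferenceProviderMapping from the Hub, and moves the hf-inference feature-extraction/sentence-similarity route from pipeline/{task}/{model} to models/{model}/pipeline/{task}. A minimal TypeScript usage sketch of the "auto" policy; the token and model ID are placeholders, and the exact option shape is assumed from the package's public API rather than shown in this diff:

import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient("hf_xxx"); // placeholder token

// Omitting `provider` (or passing "auto") makes the client log the defaulting
// notice and pick the first entry of the model's inferenceProviderMapping,
// ordered per the user's settings at https://hf.co/settings/inference-providers.
const out = await client.chatCompletion({
  model: "meta-llama/Llama-3.1-8B-Instruct", // placeholder model ID
  provider: "auto",
  messages: [{ role: "user", content: "Hello!" }],
});
console.log(out.choices[0].message);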
package/dist/index.cjs CHANGED
@@ -25,6 +25,7 @@ __export(src_exports, {
   InferenceClient: () => InferenceClient,
   InferenceClientEndpoint: () => InferenceClientEndpoint,
   InferenceOutputError: () => InferenceOutputError,
+  PROVIDERS_OR_POLICIES: () => PROVIDERS_OR_POLICIES,
   audioClassification: () => audioClassification,
   audioToAudio: () => audioToAudio,
   automaticSpeechRecognition: () => automaticSpeechRecognition,
@@ -98,6 +99,38 @@ __export(tasks_exports, {
   zeroShotImageClassification: () => zeroShotImageClassification
 });
 
+// src/config.ts
+var HF_HUB_URL = "https://huggingface.co";
+var HF_ROUTER_URL = "https://router.huggingface.co";
+var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
+
+// src/providers/consts.ts
+var HARDCODED_MODEL_INFERENCE_MAPPING = {
+  /**
+   * "HF model ID" => "Model ID on Inference Provider's side"
+   *
+   * Example:
+   * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
+   */
+  "black-forest-labs": {},
+  cerebras: {},
+  cohere: {},
+  "fal-ai": {},
+  "featherless-ai": {},
+  "fireworks-ai": {},
+  groq: {},
+  "hf-inference": {},
+  hyperbolic: {},
+  nebius: {},
+  novita: {},
+  nscale: {},
+  openai: {},
+  ovhcloud: {},
+  replicate: {},
+  sambanova: {},
+  together: {}
+};
+
 // src/lib/InferenceOutputError.ts
 var InferenceOutputError = class extends TypeError {
   constructor(message) {
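The HARDCODED_MODEL_INFERENCE_MAPPING introduced above is consulted by getInferenceProviderMapping (later in this diff) before any network request, and every provider's table ships empty. A hypothetical populated entry might look like the sketch below; the provider choice is illustrative, the Qwen IDs come from the inline comment, and the value shape (providerId/hfModelId/task/status) is inferred from how mapping objects are consumed elsewhere in this diff, so treat it as an assumption:

// Hypothetical override: map an HF model ID to the provider-side model ID so the
// Hub lookup is skipped for this model.
HARDCODED_MODEL_INFERENCE_MAPPING["cerebras"]["Qwen/Qwen2.5-Coder-32B-Instruct"] = {
  providerId: "Qwen2.5-Coder-32B-Instruct",
  hfModelId: "Qwen/Qwen2.5-Coder-32B-Instruct",
  task: "conversational",
  status: "live",
};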
@@ -108,42 +141,6 @@ var InferenceOutputError = class extends TypeError {
   }
 };
 
-// src/utils/delay.ts
-function delay(ms) {
-  return new Promise((resolve) => {
-    setTimeout(() => resolve(), ms);
-  });
-}
-
-// src/utils/pick.ts
-function pick(o, props) {
-  return Object.assign(
-    {},
-    ...props.map((prop) => {
-      if (o[prop] !== void 0) {
-        return { [prop]: o[prop] };
-      }
-    })
-  );
-}
-
-// src/utils/typedInclude.ts
-function typedInclude(arr, v) {
-  return arr.includes(v);
-}
-
-// src/utils/omit.ts
-function omit(o, props) {
-  const propsArr = Array.isArray(props) ? props : [props];
-  const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
-  return pick(o, letsKeep);
-}
-
-// src/config.ts
-var HF_HUB_URL = "https://huggingface.co";
-var HF_ROUTER_URL = "https://router.huggingface.co";
-var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
-
 // src/utils/toArray.ts
 function toArray(obj) {
   if (Array.isArray(obj)) {
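The hunk above only relocates the delay/pick/typedInclude/omit helpers and the config constants within the bundle; their behavior is unchanged. For reference, the object helpers act as plain key filters, per the definitions above:

// pick keeps only the listed keys (skipping those whose value is undefined);
// omit drops the listed keys and keeps the rest by delegating to pick.
pick({ a: 1, b: 2, c: 3 }, ["a", "c"]); // => { a: 1, c: 3 }
omit({ a: 1, b: 2, c: 3 }, "b");        // => { a: 1, c: 3 }
omit({ a: 1, b: 2, c: 3 }, ["a", "c"]); // => { b: 2 }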
@@ -238,627 +235,736 @@ var BaseTextGenerationTask = class extends TaskProviderHelper {
238
235
  }
239
236
  };
240
237
 
241
- // src/providers/black-forest-labs.ts
242
- var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
243
- var BlackForestLabsTextToImageTask = class extends TaskProviderHelper {
238
+ // src/providers/hf-inference.ts
239
+ var EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = ["feature-extraction", "sentence-similarity"];
240
+ var HFInferenceTask = class extends TaskProviderHelper {
244
241
  constructor() {
245
- super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
242
+ super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
246
243
  }
247
244
  preparePayload(params) {
248
- return {
249
- ...omit(params.args, ["inputs", "parameters"]),
250
- ...params.args.parameters,
251
- prompt: params.args.inputs
252
- };
245
+ return params.args;
253
246
  }
254
- prepareHeaders(params, binary) {
255
- const headers = {
256
- Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`
257
- };
258
- if (!binary) {
259
- headers["Content-Type"] = "application/json";
247
+ makeUrl(params) {
248
+ if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
249
+ return params.model;
260
250
  }
261
- return headers;
251
+ return super.makeUrl(params);
262
252
  }
263
253
  makeRoute(params) {
264
- if (!params) {
265
- throw new Error("Params are required");
254
+ if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
255
+ return `models/${params.model}/pipeline/${params.task}`;
266
256
  }
267
- return `/v1/${params.model}`;
257
+ return `models/${params.model}`;
258
+ }
259
+ async getResponse(response) {
260
+ return response;
268
261
  }
262
+ };
263
+ var HFInferenceTextToImageTask = class extends HFInferenceTask {
269
264
  async getResponse(response, url, headers, outputType) {
270
- const urlObj = new URL(response.polling_url);
271
- for (let step = 0; step < 5; step++) {
272
- await delay(1e3);
273
- console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
274
- urlObj.searchParams.set("attempt", step.toString(10));
275
- const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
276
- if (!resp.ok) {
277
- throw new InferenceOutputError("Failed to fetch result from black forest labs API");
265
+ if (!response) {
266
+ throw new InferenceOutputError("response is undefined");
267
+ }
268
+ if (typeof response == "object") {
269
+ if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
270
+ const base64Data = response.data[0].b64_json;
271
+ if (outputType === "url") {
272
+ return `data:image/jpeg;base64,${base64Data}`;
273
+ }
274
+ const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
275
+ return await base64Response.blob();
278
276
  }
279
- const payload = await resp.json();
280
- if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
277
+ if ("output" in response && Array.isArray(response.output)) {
281
278
  if (outputType === "url") {
282
- return payload.result.sample;
279
+ return response.output[0];
283
280
  }
284
- const image = await fetch(payload.result.sample);
285
- return await image.blob();
281
+ const urlResponse = await fetch(response.output[0]);
282
+ const blob = await urlResponse.blob();
283
+ return blob;
286
284
  }
287
285
  }
288
- throw new InferenceOutputError("Failed to fetch result from black forest labs API");
286
+ if (response instanceof Blob) {
287
+ if (outputType === "url") {
288
+ const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
289
+ return `data:image/jpeg;base64,${b64}`;
290
+ }
291
+ return response;
292
+ }
293
+ throw new InferenceOutputError("Expected a Blob ");
289
294
  }
290
295
  };
291
-
292
- // src/providers/cerebras.ts
293
- var CerebrasConversationalTask = class extends BaseConversationalTask {
294
- constructor() {
295
- super("cerebras", "https://api.cerebras.ai");
296
+ var HFInferenceConversationalTask = class extends HFInferenceTask {
297
+ makeUrl(params) {
298
+ let url;
299
+ if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
300
+ url = params.model.trim();
301
+ } else {
302
+ url = `${this.makeBaseUrl(params)}/models/${params.model}`;
303
+ }
304
+ url = url.replace(/\/+$/, "");
305
+ if (url.endsWith("/v1")) {
306
+ url += "/chat/completions";
307
+ } else if (!url.endsWith("/chat/completions")) {
308
+ url += "/v1/chat/completions";
309
+ }
310
+ return url;
311
+ }
312
+ preparePayload(params) {
313
+ return {
314
+ ...params.args,
315
+ model: params.model
316
+ };
317
+ }
318
+ async getResponse(response) {
319
+ return response;
296
320
  }
297
321
  };
298
-
299
- // src/providers/cohere.ts
300
- var CohereConversationalTask = class extends BaseConversationalTask {
301
- constructor() {
302
- super("cohere", "https://api.cohere.com");
322
+ var HFInferenceTextGenerationTask = class extends HFInferenceTask {
323
+ async getResponse(response) {
324
+ const res = toArray(response);
325
+ if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) {
326
+ return res?.[0];
327
+ }
328
+ throw new InferenceOutputError("Expected Array<{generated_text: string}>");
303
329
  }
304
- makeRoute() {
305
- return "/compatibility/v1/chat/completions";
330
+ };
331
+ var HFInferenceAudioClassificationTask = class extends HFInferenceTask {
332
+ async getResponse(response) {
333
+ if (Array.isArray(response) && response.every(
334
+ (x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
335
+ )) {
336
+ return response;
337
+ }
338
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
306
339
  }
307
340
  };
308
-
309
- // src/lib/isUrl.ts
310
- function isUrl(modelOrUrl) {
311
- return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
312
- }
313
-
314
- // src/providers/fal-ai.ts
315
- var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
316
- var FalAITask = class extends TaskProviderHelper {
317
- constructor(url) {
318
- super("fal-ai", url || "https://fal.run");
341
+ var HFInferenceAutomaticSpeechRecognitionTask = class extends HFInferenceTask {
342
+ async getResponse(response) {
343
+ return response;
319
344
  }
320
- preparePayload(params) {
321
- return params.args;
345
+ };
346
+ var HFInferenceAudioToAudioTask = class extends HFInferenceTask {
347
+ async getResponse(response) {
348
+ if (!Array.isArray(response)) {
349
+ throw new InferenceOutputError("Expected Array");
350
+ }
351
+ if (!response.every((elem) => {
352
+ return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
353
+ })) {
354
+ throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
355
+ }
356
+ return response;
322
357
  }
323
- makeRoute(params) {
324
- return `/${params.model}`;
358
+ };
359
+ var HFInferenceDocumentQuestionAnsweringTask = class extends HFInferenceTask {
360
+ async getResponse(response) {
361
+ if (Array.isArray(response) && response.every(
362
+ (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
363
+ )) {
364
+ return response[0];
365
+ }
366
+ throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
325
367
  }
326
- prepareHeaders(params, binary) {
327
- const headers = {
328
- Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}`
368
+ };
369
+ var HFInferenceFeatureExtractionTask = class extends HFInferenceTask {
370
+ async getResponse(response) {
371
+ const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
372
+ if (curDepth > maxDepth)
373
+ return false;
374
+ if (arr.every((x) => Array.isArray(x))) {
375
+ return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
376
+ } else {
377
+ return arr.every((x) => typeof x === "number");
378
+ }
329
379
  };
330
- if (!binary) {
331
- headers["Content-Type"] = "application/json";
380
+ if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) {
381
+ return response;
332
382
  }
333
- return headers;
383
+ throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
334
384
  }
335
385
  };
336
- function buildLoraPath(modelId, adapterWeightsPath) {
337
- return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
338
- }
339
- var FalAITextToImageTask = class extends FalAITask {
340
- preparePayload(params) {
341
- const payload = {
342
- ...omit(params.args, ["inputs", "parameters"]),
343
- ...params.args.parameters,
344
- sync_mode: true,
345
- prompt: params.args.inputs
346
- };
347
- if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
348
- payload.loras = [
349
- {
350
- path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
351
- scale: 1
352
- }
353
- ];
354
- if (params.mapping.providerId === "fal-ai/lora") {
355
- payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0";
356
- }
386
+ var HFInferenceImageClassificationTask = class extends HFInferenceTask {
387
+ async getResponse(response) {
388
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
389
+ return response;
357
390
  }
358
- return payload;
391
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
359
392
  }
360
- async getResponse(response, outputType) {
361
- if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images.length > 0 && "url" in response.images[0] && typeof response.images[0].url === "string") {
362
- if (outputType === "url") {
363
- return response.images[0].url;
364
- }
365
- const urlResponse = await fetch(response.images[0].url);
366
- return await urlResponse.blob();
393
+ };
394
+ var HFInferenceImageSegmentationTask = class extends HFInferenceTask {
395
+ async getResponse(response) {
396
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number")) {
397
+ return response;
367
398
  }
368
- throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
399
+ throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
369
400
  }
370
401
  };
371
- var FalAITextToVideoTask = class extends FalAITask {
372
- constructor() {
373
- super("https://queue.fal.run");
374
- }
375
- makeRoute(params) {
376
- if (params.authMethod !== "provider-key") {
377
- return `/${params.model}?_subdomain=queue`;
402
+ var HFInferenceImageToTextTask = class extends HFInferenceTask {
403
+ async getResponse(response) {
404
+ if (typeof response?.generated_text !== "string") {
405
+ throw new InferenceOutputError("Expected {generated_text: string}");
378
406
  }
379
- return `/${params.model}`;
407
+ return response;
380
408
  }
381
- preparePayload(params) {
382
- return {
383
- ...omit(params.args, ["inputs", "parameters"]),
384
- ...params.args.parameters,
385
- prompt: params.args.inputs
386
- };
409
+ };
410
+ var HFInferenceImageToImageTask = class extends HFInferenceTask {
411
+ async getResponse(response) {
412
+ if (response instanceof Blob) {
413
+ return response;
414
+ }
415
+ throw new InferenceOutputError("Expected Blob");
387
416
  }
388
- async getResponse(response, url, headers) {
389
- if (!url || !headers) {
390
- throw new InferenceOutputError("URL and headers are required for text-to-video task");
417
+ };
418
+ var HFInferenceObjectDetectionTask = class extends HFInferenceTask {
419
+ async getResponse(response) {
420
+ if (Array.isArray(response) && response.every(
421
+ (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
422
+ )) {
423
+ return response;
391
424
  }
392
- const requestId = response.request_id;
393
- if (!requestId) {
394
- throw new InferenceOutputError("No request ID found in the response");
425
+ throw new InferenceOutputError(
426
+ "Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
427
+ );
428
+ }
429
+ };
430
+ var HFInferenceZeroShotImageClassificationTask = class extends HFInferenceTask {
431
+ async getResponse(response) {
432
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
433
+ return response;
395
434
  }
396
- let status = response.status;
397
- const parsedUrl = new URL(url);
398
- const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
399
- const modelId = new URL(response.response_url).pathname;
400
- const queryParams = parsedUrl.search;
401
- const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
402
- const resultUrl = `${baseUrl}${modelId}${queryParams}`;
403
- while (status !== "COMPLETED") {
404
- await delay(500);
405
- const statusResponse = await fetch(statusUrl, { headers });
406
- if (!statusResponse.ok) {
407
- throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
408
- }
409
- try {
410
- status = (await statusResponse.json()).status;
411
- } catch (error) {
412
- throw new InferenceOutputError("Failed to parse status response from fal-ai API");
413
- }
435
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
436
+ }
437
+ };
438
+ var HFInferenceTextClassificationTask = class extends HFInferenceTask {
439
+ async getResponse(response) {
440
+ const output = response?.[0];
441
+ if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
442
+ return output;
414
443
  }
415
- const resultResponse = await fetch(resultUrl, { headers });
416
- let result;
417
- try {
418
- result = await resultResponse.json();
419
- } catch (error) {
420
- throw new InferenceOutputError("Failed to parse result response from fal-ai API");
444
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
445
+ }
446
+ };
447
+ var HFInferenceQuestionAnsweringTask = class extends HFInferenceTask {
448
+ async getResponse(response) {
449
+ if (Array.isArray(response) ? response.every(
450
+ (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
451
+ ) : typeof response === "object" && !!response && typeof response.answer === "string" && typeof response.end === "number" && typeof response.score === "number" && typeof response.start === "number") {
452
+ return Array.isArray(response) ? response[0] : response;
421
453
  }
422
- if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
423
- const urlResponse = await fetch(result.video.url);
424
- return await urlResponse.blob();
425
- } else {
426
- throw new InferenceOutputError(
427
- "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
428
- );
454
+ throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
455
+ }
456
+ };
457
+ var HFInferenceFillMaskTask = class extends HFInferenceTask {
458
+ async getResponse(response) {
459
+ if (Array.isArray(response) && response.every(
460
+ (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
461
+ )) {
462
+ return response;
429
463
  }
464
+ throw new InferenceOutputError(
465
+ "Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
466
+ );
430
467
  }
431
468
  };
432
- var FalAIAutomaticSpeechRecognitionTask = class extends FalAITask {
433
- prepareHeaders(params, binary) {
434
- const headers = super.prepareHeaders(params, binary);
435
- headers["Content-Type"] = "application/json";
436
- return headers;
469
+ var HFInferenceZeroShotClassificationTask = class extends HFInferenceTask {
470
+ async getResponse(response) {
471
+ if (Array.isArray(response) && response.every(
472
+ (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
473
+ )) {
474
+ return response;
475
+ }
476
+ throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
437
477
  }
478
+ };
479
+ var HFInferenceSentenceSimilarityTask = class extends HFInferenceTask {
438
480
  async getResponse(response) {
439
- const res = response;
440
- if (typeof res?.text !== "string") {
441
- throw new InferenceOutputError(
442
- `Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
443
- );
481
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
482
+ return response;
444
483
  }
445
- return { text: res.text };
484
+ throw new InferenceOutputError("Expected Array<number>");
446
485
  }
447
486
  };
448
- var FalAITextToSpeechTask = class extends FalAITask {
449
- preparePayload(params) {
450
- return {
451
- ...omit(params.args, ["inputs", "parameters"]),
452
- ...params.args.parameters,
453
- text: params.args.inputs
454
- };
487
+ var HFInferenceTableQuestionAnsweringTask = class extends HFInferenceTask {
488
+ static validate(elem) {
489
+ return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
490
+ (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
491
+ );
455
492
  }
456
493
  async getResponse(response) {
457
- const res = response;
458
- if (typeof res?.audio?.url !== "string") {
459
- throw new InferenceOutputError(
460
- `Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
461
- );
494
+ if (Array.isArray(response) && Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response)) {
495
+ return Array.isArray(response) ? response[0] : response;
462
496
  }
463
- try {
464
- const urlResponse = await fetch(res.audio.url);
465
- if (!urlResponse.ok) {
466
- throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`);
467
- }
468
- return await urlResponse.blob();
469
- } catch (error) {
470
- throw new InferenceOutputError(
471
- `Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${error instanceof Error ? error.message : String(error)}`
472
- );
497
+ throw new InferenceOutputError(
498
+ "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
499
+ );
500
+ }
501
+ };
502
+ var HFInferenceTokenClassificationTask = class extends HFInferenceTask {
503
+ async getResponse(response) {
504
+ if (Array.isArray(response) && response.every(
505
+ (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
506
+ )) {
507
+ return response;
473
508
  }
509
+ throw new InferenceOutputError(
510
+ "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
511
+ );
474
512
  }
475
513
  };
476
-
477
- // src/providers/featherless-ai.ts
478
- var FEATHERLESS_API_BASE_URL = "https://api.featherless.ai";
479
- var FeatherlessAIConversationalTask = class extends BaseConversationalTask {
480
- constructor() {
481
- super("featherless-ai", FEATHERLESS_API_BASE_URL);
514
+ var HFInferenceTranslationTask = class extends HFInferenceTask {
515
+ async getResponse(response) {
516
+ if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) {
517
+ return response?.length === 1 ? response?.[0] : response;
518
+ }
519
+ throw new InferenceOutputError("Expected Array<{translation_text: string}>");
482
520
  }
483
521
  };
484
- var FeatherlessAITextGenerationTask = class extends BaseTextGenerationTask {
485
- constructor() {
486
- super("featherless-ai", FEATHERLESS_API_BASE_URL);
522
+ var HFInferenceSummarizationTask = class extends HFInferenceTask {
523
+ async getResponse(response) {
524
+ if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) {
525
+ return response?.[0];
526
+ }
527
+ throw new InferenceOutputError("Expected Array<{summary_text: string}>");
487
528
  }
488
- preparePayload(params) {
489
- return {
490
- ...params.args,
491
- ...params.args.parameters,
492
- model: params.model,
493
- prompt: params.args.inputs
494
- };
529
+ };
530
+ var HFInferenceTextToSpeechTask = class extends HFInferenceTask {
531
+ async getResponse(response) {
532
+ return response;
495
533
  }
534
+ };
535
+ var HFInferenceTabularClassificationTask = class extends HFInferenceTask {
496
536
  async getResponse(response) {
497
- if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
498
- const completion = response.choices[0];
499
- return {
500
- generated_text: completion.text
501
- };
537
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
538
+ return response;
502
539
  }
503
- throw new InferenceOutputError("Expected Featherless AI text generation response format");
540
+ throw new InferenceOutputError("Expected Array<number>");
504
541
  }
505
542
  };
506
-
507
- // src/providers/fireworks-ai.ts
508
- var FireworksConversationalTask = class extends BaseConversationalTask {
509
- constructor() {
510
- super("fireworks-ai", "https://api.fireworks.ai");
543
+ var HFInferenceVisualQuestionAnsweringTask = class extends HFInferenceTask {
544
+ async getResponse(response) {
545
+ if (Array.isArray(response) && response.every(
546
+ (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
547
+ )) {
548
+ return response[0];
549
+ }
550
+ throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
511
551
  }
512
- makeRoute() {
513
- return "/inference/v1/chat/completions";
552
+ };
553
+ var HFInferenceTabularRegressionTask = class extends HFInferenceTask {
554
+ async getResponse(response) {
555
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
556
+ return response;
557
+ }
558
+ throw new InferenceOutputError("Expected Array<number>");
559
+ }
560
+ };
561
+ var HFInferenceTextToAudioTask = class extends HFInferenceTask {
562
+ async getResponse(response) {
563
+ return response;
514
564
  }
515
565
  };
516
566
 
517
- // src/providers/groq.ts
518
- var GROQ_API_BASE_URL = "https://api.groq.com";
519
- var GroqTextGenerationTask = class extends BaseTextGenerationTask {
520
- constructor() {
521
- super("groq", GROQ_API_BASE_URL);
567
+ // src/utils/typedInclude.ts
568
+ function typedInclude(arr, v) {
569
+ return arr.includes(v);
570
+ }
571
+
572
+ // src/lib/getInferenceProviderMapping.ts
573
+ var inferenceProviderMappingCache = /* @__PURE__ */ new Map();
574
+ async function fetchInferenceProviderMappingForModel(modelId, accessToken, options) {
575
+ let inferenceProviderMapping;
576
+ if (inferenceProviderMappingCache.has(modelId)) {
577
+ inferenceProviderMapping = inferenceProviderMappingCache.get(modelId);
578
+ } else {
579
+ const resp = await (options?.fetch ?? fetch)(
580
+ `${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
581
+ {
582
+ headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {}
583
+ }
584
+ );
585
+ if (resp.status === 404) {
586
+ throw new Error(`Model ${modelId} does not exist`);
587
+ }
588
+ inferenceProviderMapping = await resp.json().then((json) => json.inferenceProviderMapping).catch(() => null);
589
+ if (inferenceProviderMapping) {
590
+ inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
591
+ }
522
592
  }
523
- makeRoute() {
524
- return "/openai/v1/chat/completions";
593
+ if (!inferenceProviderMapping) {
594
+ throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
525
595
  }
526
- };
527
- var GroqConversationalTask = class extends BaseConversationalTask {
528
- constructor() {
529
- super("groq", GROQ_API_BASE_URL);
596
+ return inferenceProviderMapping;
597
+ }
598
+ async function getInferenceProviderMapping(params, options) {
599
+ if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
600
+ return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
530
601
  }
531
- makeRoute() {
532
- return "/openai/v1/chat/completions";
602
+ const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
603
+ params.modelId,
604
+ params.accessToken,
605
+ options
606
+ );
607
+ const providerMapping = inferenceProviderMapping[params.provider];
608
+ if (providerMapping) {
609
+ const equivalentTasks = params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task) ? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS : [params.task];
610
+ if (!typedInclude(equivalentTasks, providerMapping.task)) {
611
+ throw new Error(
612
+ `Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
613
+ );
614
+ }
615
+ if (providerMapping.status === "staging") {
616
+ console.warn(
617
+ `Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
618
+ );
619
+ }
620
+ return { ...providerMapping, hfModelId: params.modelId };
533
621
  }
534
- };
622
+ return null;
623
+ }
624
+ async function resolveProvider(provider, modelId, endpointUrl) {
625
+ if (endpointUrl) {
626
+ if (provider) {
627
+ throw new Error("Specifying both endpointUrl and provider is not supported.");
628
+ }
629
+ return "hf-inference";
630
+ }
631
+ if (!provider) {
632
+ console.log(
633
+ "Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
634
+ );
635
+ provider = "auto";
636
+ }
637
+ if (provider === "auto") {
638
+ if (!modelId) {
639
+ throw new Error("Specifying a model is required when provider is 'auto'");
640
+ }
641
+ const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
642
+ provider = Object.keys(inferenceProviderMapping)[0];
643
+ }
644
+ if (!provider) {
645
+ throw new Error(`No Inference Provider available for model ${modelId}.`);
646
+ }
647
+ return provider;
648
+ }
535
649
 
536
- // src/providers/hf-inference.ts
537
- var EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = ["feature-extraction", "sentence-similarity"];
538
- var HFInferenceTask = class extends TaskProviderHelper {
650
+ // src/utils/delay.ts
651
+ function delay(ms) {
652
+ return new Promise((resolve) => {
653
+ setTimeout(() => resolve(), ms);
654
+ });
655
+ }
656
+
657
+ // src/utils/pick.ts
658
+ function pick(o, props) {
659
+ return Object.assign(
660
+ {},
661
+ ...props.map((prop) => {
662
+ if (o[prop] !== void 0) {
663
+ return { [prop]: o[prop] };
664
+ }
665
+ })
666
+ );
667
+ }
668
+
669
+ // src/utils/omit.ts
670
+ function omit(o, props) {
671
+ const propsArr = Array.isArray(props) ? props : [props];
672
+ const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
673
+ return pick(o, letsKeep);
674
+ }
675
+
676
+ // src/providers/black-forest-labs.ts
677
+ var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
678
+ var BlackForestLabsTextToImageTask = class extends TaskProviderHelper {
539
679
  constructor() {
540
- super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
680
+ super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
541
681
  }
542
682
  preparePayload(params) {
543
- return params.args;
683
+ return {
684
+ ...omit(params.args, ["inputs", "parameters"]),
685
+ ...params.args.parameters,
686
+ prompt: params.args.inputs
687
+ };
544
688
  }
545
- makeUrl(params) {
546
- if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
547
- return params.model;
689
+ prepareHeaders(params, binary) {
690
+ const headers = {
691
+ Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`
692
+ };
693
+ if (!binary) {
694
+ headers["Content-Type"] = "application/json";
548
695
  }
549
- return super.makeUrl(params);
696
+ return headers;
550
697
  }
551
698
  makeRoute(params) {
552
- if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
553
- return `pipeline/${params.task}/${params.model}`;
699
+ if (!params) {
700
+ throw new Error("Params are required");
554
701
  }
555
- return `models/${params.model}`;
556
- }
557
- async getResponse(response) {
558
- return response;
702
+ return `/v1/${params.model}`;
559
703
  }
560
- };
561
- var HFInferenceTextToImageTask = class extends HFInferenceTask {
562
704
  async getResponse(response, url, headers, outputType) {
563
- if (!response) {
564
- throw new InferenceOutputError("response is undefined");
565
- }
566
- if (typeof response == "object") {
567
- if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
568
- const base64Data = response.data[0].b64_json;
569
- if (outputType === "url") {
570
- return `data:image/jpeg;base64,${base64Data}`;
571
- }
572
- const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
573
- return await base64Response.blob();
705
+ const urlObj = new URL(response.polling_url);
706
+ for (let step = 0; step < 5; step++) {
707
+ await delay(1e3);
708
+ console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
709
+ urlObj.searchParams.set("attempt", step.toString(10));
710
+ const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
711
+ if (!resp.ok) {
712
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
574
713
  }
575
- if ("output" in response && Array.isArray(response.output)) {
714
+ const payload = await resp.json();
715
+ if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
576
716
  if (outputType === "url") {
577
- return response.output[0];
717
+ return payload.result.sample;
578
718
  }
579
- const urlResponse = await fetch(response.output[0]);
580
- const blob = await urlResponse.blob();
581
- return blob;
582
- }
583
- }
584
- if (response instanceof Blob) {
585
- if (outputType === "url") {
586
- const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
587
- return `data:image/jpeg;base64,${b64}`;
719
+ const image = await fetch(payload.result.sample);
720
+ return await image.blob();
588
721
  }
589
- return response;
590
722
  }
591
- throw new InferenceOutputError("Expected a Blob ");
723
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
592
724
  }
593
725
  };
594
- var HFInferenceConversationalTask = class extends HFInferenceTask {
595
- makeUrl(params) {
596
- let url;
597
- if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
598
- url = params.model.trim();
599
- } else {
600
- url = `${this.makeBaseUrl(params)}/models/${params.model}`;
601
- }
602
- url = url.replace(/\/+$/, "");
603
- if (url.endsWith("/v1")) {
604
- url += "/chat/completions";
605
- } else if (!url.endsWith("/chat/completions")) {
606
- url += "/v1/chat/completions";
607
- }
608
- return url;
726
+
727
+ // src/providers/cerebras.ts
728
+ var CerebrasConversationalTask = class extends BaseConversationalTask {
729
+ constructor() {
730
+ super("cerebras", "https://api.cerebras.ai");
609
731
  }
610
- preparePayload(params) {
611
- return {
612
- ...params.args,
613
- model: params.model
614
- };
732
+ };
733
+
734
+ // src/providers/cohere.ts
735
+ var CohereConversationalTask = class extends BaseConversationalTask {
736
+ constructor() {
737
+ super("cohere", "https://api.cohere.com");
615
738
  }
616
- async getResponse(response) {
617
- return response;
739
+ makeRoute() {
740
+ return "/compatibility/v1/chat/completions";
618
741
  }
619
742
  };
620
- var HFInferenceTextGenerationTask = class extends HFInferenceTask {
621
- async getResponse(response) {
622
- const res = toArray(response);
623
- if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) {
624
- return res?.[0];
625
- }
626
- throw new InferenceOutputError("Expected Array<{generated_text: string}>");
743
+
744
+ // src/lib/isUrl.ts
745
+ function isUrl(modelOrUrl) {
746
+ return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
747
+ }
748
+
749
+ // src/providers/fal-ai.ts
750
+ var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
751
+ var FalAITask = class extends TaskProviderHelper {
752
+ constructor(url) {
753
+ super("fal-ai", url || "https://fal.run");
627
754
  }
628
- };
629
- var HFInferenceAudioClassificationTask = class extends HFInferenceTask {
630
- async getResponse(response) {
631
- if (Array.isArray(response) && response.every(
632
- (x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
633
- )) {
634
- return response;
635
- }
636
- throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
755
+ preparePayload(params) {
756
+ return params.args;
637
757
  }
638
- };
639
- var HFInferenceAutomaticSpeechRecognitionTask = class extends HFInferenceTask {
640
- async getResponse(response) {
641
- return response;
758
+ makeRoute(params) {
759
+ return `/${params.model}`;
642
760
  }
643
- };
644
- var HFInferenceAudioToAudioTask = class extends HFInferenceTask {
645
- async getResponse(response) {
646
- if (!Array.isArray(response)) {
647
- throw new InferenceOutputError("Expected Array");
648
- }
649
- if (!response.every((elem) => {
650
- return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
651
- })) {
652
- throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
761
+ prepareHeaders(params, binary) {
762
+ const headers = {
763
+ Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}`
764
+ };
765
+ if (!binary) {
766
+ headers["Content-Type"] = "application/json";
653
767
  }
654
- return response;
768
+ return headers;
655
769
  }
656
770
  };
657
- var HFInferenceDocumentQuestionAnsweringTask = class extends HFInferenceTask {
658
- async getResponse(response) {
659
- if (Array.isArray(response) && response.every(
660
- (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
661
- )) {
662
- return response[0];
771
+ function buildLoraPath(modelId, adapterWeightsPath) {
772
+ return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
773
+ }
774
+ var FalAITextToImageTask = class extends FalAITask {
775
+ preparePayload(params) {
776
+ const payload = {
777
+ ...omit(params.args, ["inputs", "parameters"]),
778
+ ...params.args.parameters,
779
+ sync_mode: true,
780
+ prompt: params.args.inputs
781
+ };
782
+ if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
783
+ payload.loras = [
784
+ {
785
+ path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
786
+ scale: 1
787
+ }
788
+ ];
789
+ if (params.mapping.providerId === "fal-ai/lora") {
790
+ payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0";
791
+ }
663
792
  }
664
- throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
793
+ return payload;
665
794
  }
666
- };
667
- var HFInferenceFeatureExtractionTask = class extends HFInferenceTask {
668
- async getResponse(response) {
669
- const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
670
- if (curDepth > maxDepth)
671
- return false;
672
- if (arr.every((x) => Array.isArray(x))) {
673
- return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
674
- } else {
675
- return arr.every((x) => typeof x === "number");
795
+ async getResponse(response, outputType) {
796
+ if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images.length > 0 && "url" in response.images[0] && typeof response.images[0].url === "string") {
797
+ if (outputType === "url") {
798
+ return response.images[0].url;
676
799
  }
677
- };
678
- if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) {
679
- return response;
800
+ const urlResponse = await fetch(response.images[0].url);
801
+ return await urlResponse.blob();
680
802
  }
681
- throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
803
+ throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
682
804
  }
683
805
  };
684
- var HFInferenceImageClassificationTask = class extends HFInferenceTask {
685
- async getResponse(response) {
686
- if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
687
- return response;
688
- }
689
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
806
+ var FalAITextToVideoTask = class extends FalAITask {
807
+ constructor() {
808
+ super("https://queue.fal.run");
690
809
  }
691
- };
692
- var HFInferenceImageSegmentationTask = class extends HFInferenceTask {
693
- async getResponse(response) {
694
- if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number")) {
695
- return response;
810
+ makeRoute(params) {
811
+ if (params.authMethod !== "provider-key") {
812
+ return `/${params.model}?_subdomain=queue`;
696
813
  }
697
- throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
814
+ return `/${params.model}`;
698
815
  }
699
- };
700
- var HFInferenceImageToTextTask = class extends HFInferenceTask {
701
- async getResponse(response) {
702
- if (typeof response?.generated_text !== "string") {
703
- throw new InferenceOutputError("Expected {generated_text: string}");
704
- }
705
- return response;
816
+ preparePayload(params) {
817
+ return {
818
+ ...omit(params.args, ["inputs", "parameters"]),
819
+ ...params.args.parameters,
820
+ prompt: params.args.inputs
821
+ };
706
822
  }
707
- };
708
- var HFInferenceImageToImageTask = class extends HFInferenceTask {
709
- async getResponse(response) {
710
- if (response instanceof Blob) {
711
- return response;
823
+ async getResponse(response, url, headers) {
824
+ if (!url || !headers) {
825
+ throw new InferenceOutputError("URL and headers are required for text-to-video task");
712
826
  }
713
- throw new InferenceOutputError("Expected Blob");
714
- }
715
- };
716
- var HFInferenceObjectDetectionTask = class extends HFInferenceTask {
717
- async getResponse(response) {
718
- if (Array.isArray(response) && response.every(
719
- (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
720
- )) {
721
- return response;
827
+ const requestId = response.request_id;
828
+ if (!requestId) {
829
+ throw new InferenceOutputError("No request ID found in the response");
722
830
  }
723
- throw new InferenceOutputError(
724
- "Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
725
- );
726
- }
727
- };
728
- var HFInferenceZeroShotImageClassificationTask = class extends HFInferenceTask {
729
- async getResponse(response) {
730
- if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
731
- return response;
831
+ let status = response.status;
832
+ const parsedUrl = new URL(url);
833
+ const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
834
+ const modelId = new URL(response.response_url).pathname;
835
+ const queryParams = parsedUrl.search;
836
+ const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
837
+ const resultUrl = `${baseUrl}${modelId}${queryParams}`;
838
+ while (status !== "COMPLETED") {
839
+ await delay(500);
840
+ const statusResponse = await fetch(statusUrl, { headers });
841
+ if (!statusResponse.ok) {
842
+ throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
843
+ }
844
+ try {
845
+ status = (await statusResponse.json()).status;
846
+ } catch (error) {
847
+ throw new InferenceOutputError("Failed to parse status response from fal-ai API");
848
+ }
849
+ }
850
+ const resultResponse = await fetch(resultUrl, { headers });
851
+ let result;
852
+ try {
853
+ result = await resultResponse.json();
854
+ } catch (error) {
855
+ throw new InferenceOutputError("Failed to parse result response from fal-ai API");
856
+ }
857
+ if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
858
+ const urlResponse = await fetch(result.video.url);
859
+ return await urlResponse.blob();
860
+ } else {
861
+ throw new InferenceOutputError(
862
+ "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
863
+ );
732
864
  }
733
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
734
865
  }
735
866
  };
736
- var HFInferenceTextClassificationTask = class extends HFInferenceTask {
737
- async getResponse(response) {
738
- const output = response?.[0];
739
- if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
740
- return output;
741
- }
742
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
867
+ var FalAIAutomaticSpeechRecognitionTask = class extends FalAITask {
868
+ prepareHeaders(params, binary) {
869
+ const headers = super.prepareHeaders(params, binary);
870
+ headers["Content-Type"] = "application/json";
871
+ return headers;
743
872
  }
744
- };
745
- var HFInferenceQuestionAnsweringTask = class extends HFInferenceTask {
746
873
  async getResponse(response) {
747
- if (Array.isArray(response) ? response.every(
748
- (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
749
- ) : typeof response === "object" && !!response && typeof response.answer === "string" && typeof response.end === "number" && typeof response.score === "number" && typeof response.start === "number") {
750
- return Array.isArray(response) ? response[0] : response;
874
+ const res = response;
875
+ if (typeof res?.text !== "string") {
876
+ throw new InferenceOutputError(
877
+ `Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
878
+ );
751
879
  }
752
- throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
880
+ return { text: res.text };
753
881
  }
754
882
  };
755
- var HFInferenceFillMaskTask = class extends HFInferenceTask {
756
- async getResponse(response) {
757
- if (Array.isArray(response) && response.every(
758
- (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
759
- )) {
760
- return response;
761
- }
762
- throw new InferenceOutputError(
763
- "Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
764
- );
883
+ var FalAITextToSpeechTask = class extends FalAITask {
884
+ preparePayload(params) {
885
+ return {
886
+ ...omit(params.args, ["inputs", "parameters"]),
887
+ ...params.args.parameters,
888
+ text: params.args.inputs
889
+ };
765
890
  }
766
- };
767
- var HFInferenceZeroShotClassificationTask = class extends HFInferenceTask {
768
891
  async getResponse(response) {
769
- if (Array.isArray(response) && response.every(
770
- (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
771
- )) {
772
- return response;
892
+ const res = response;
893
+ if (typeof res?.audio?.url !== "string") {
894
+ throw new InferenceOutputError(
895
+ `Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
896
+ );
773
897
  }
774
- throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
775
- }
776
- };
777
- var HFInferenceSentenceSimilarityTask = class extends HFInferenceTask {
778
- async getResponse(response) {
779
- if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
780
- return response;
898
+ try {
899
+ const urlResponse = await fetch(res.audio.url);
900
+ if (!urlResponse.ok) {
901
+ throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`);
902
+ }
903
+ return await urlResponse.blob();
904
+ } catch (error) {
905
+ throw new InferenceOutputError(
906
+ `Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${error instanceof Error ? error.message : String(error)}`
907
+ );
781
908
  }
782
- throw new InferenceOutputError("Expected Array<number>");
783
909
  }
784
910
  };
- var HFInferenceTableQuestionAnsweringTask = class extends HFInferenceTask {
-   static validate(elem) {
-     return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
-       (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
-     );
-   }
-   async getResponse(response) {
-     if (Array.isArray(response) && Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response)) {
-       return Array.isArray(response) ? response[0] : response;
-     }
-     throw new InferenceOutputError(
-       "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
-     );
+
+ // src/providers/featherless-ai.ts
+ var FEATHERLESS_API_BASE_URL = "https://api.featherless.ai";
+ var FeatherlessAIConversationalTask = class extends BaseConversationalTask {
+   constructor() {
+     super("featherless-ai", FEATHERLESS_API_BASE_URL);
    }
  };
- var HFInferenceTokenClassificationTask = class extends HFInferenceTask {
-   async getResponse(response) {
-     if (Array.isArray(response) && response.every(
-       (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
-     )) {
-       return response;
-     }
-     throw new InferenceOutputError(
-       "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
-     );
+ var FeatherlessAITextGenerationTask = class extends BaseTextGenerationTask {
+   constructor() {
+     super("featherless-ai", FEATHERLESS_API_BASE_URL);
    }
- };
- var HFInferenceTranslationTask = class extends HFInferenceTask {
-   async getResponse(response) {
-     if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) {
-       return response?.length === 1 ? response?.[0] : response;
-     }
-     throw new InferenceOutputError("Expected Array<{translation_text: string}>");
+   preparePayload(params) {
+     return {
+       ...params.args,
+       ...params.args.parameters,
+       model: params.model,
+       prompt: params.args.inputs
+     };
    }
- };
- var HFInferenceSummarizationTask = class extends HFInferenceTask {
    async getResponse(response) {
-     if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) {
-       return response?.[0];
+     if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
+       const completion = response.choices[0];
+       return {
+         generated_text: completion.text
+       };
      }
-     throw new InferenceOutputError("Expected Array<{summary_text: string}>");
+     throw new InferenceOutputError("Expected Featherless AI text generation response format");
    }
  };
- var HFInferenceTextToSpeechTask = class extends HFInferenceTask {
-   async getResponse(response) {
-     return response;
+
+ // src/providers/fireworks-ai.ts
+ var FireworksConversationalTask = class extends BaseConversationalTask {
+   constructor() {
+     super("fireworks-ai", "https://api.fireworks.ai");
    }
- };
- var HFInferenceTabularClassificationTask = class extends HFInferenceTask {
-   async getResponse(response) {
-     if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
-       return response;
-     }
-     throw new InferenceOutputError("Expected Array<number>");
+   makeRoute() {
+     return "/inference/v1/chat/completions";
    }
  };
- var HFInferenceVisualQuestionAnsweringTask = class extends HFInferenceTask {
-   async getResponse(response) {
-     if (Array.isArray(response) && response.every(
-       (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
-     )) {
-       return response[0];
-     }
-     throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
+
+ // src/providers/groq.ts
+ var GROQ_API_BASE_URL = "https://api.groq.com";
+ var GroqTextGenerationTask = class extends BaseTextGenerationTask {
+   constructor() {
+     super("groq", GROQ_API_BASE_URL);
    }
- };
- var HFInferenceTabularRegressionTask = class extends HFInferenceTask {
-   async getResponse(response) {
-     if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
-       return response;
-     }
-     throw new InferenceOutputError("Expected Array<number>");
+   makeRoute() {
+     return "/openai/v1/chat/completions";
    }
  };
- var HFInferenceTextToAudioTask = class extends HFInferenceTask {
-   async getResponse(response) {
-     return response;
+ var GroqConversationalTask = class extends BaseConversationalTask {
+   constructor() {
+     super("groq", GROQ_API_BASE_URL);
+   }
+   makeRoute() {
+     return "/openai/v1/chat/completions";
    }
  };
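
Each of the three providers added above (Featherless AI, Fireworks, Groq) is a thin subclass that only pins the provider name, base URL, and, where needed, a non-default chat route. A hedged sketch of the pattern with placeholder names; BaseConversationalTask is assumed, as in this bundle, to take (providerName, baseUrl) and to call makeRoute() when building the request URL:

    // Sketch only: "example-provider" and the URL are placeholders.
    class ExampleConversationalTask extends BaseConversationalTask {
        constructor() {
            super("example-provider", "https://api.example.com");
        }
        override makeRoute(): string {
            // Groq and Fireworks override this because their OpenAI-compatible
            // endpoint is nested under a provider-specific prefix.
            return "/openai/v1/chat/completions";
        }
    }
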

@@ -1352,82 +1458,13 @@ function getProviderHelper(provider, task) {

  // package.json
  var name = "@huggingface/inference";
- var version = "3.10.0";
-
- // src/providers/consts.ts
- var HARDCODED_MODEL_INFERENCE_MAPPING = {
-   /**
-    * "HF model ID" => "Model ID on Inference Provider's side"
-    *
-    * Example:
-    * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
-    */
-   "black-forest-labs": {},
-   cerebras: {},
-   cohere: {},
-   "fal-ai": {},
-   "featherless-ai": {},
-   "fireworks-ai": {},
-   groq: {},
-   "hf-inference": {},
-   hyperbolic: {},
-   nebius: {},
-   novita: {},
-   nscale: {},
-   openai: {},
-   ovhcloud: {},
-   replicate: {},
-   sambanova: {},
-   together: {}
- };
-
- // src/lib/getInferenceProviderMapping.ts
- var inferenceProviderMappingCache = /* @__PURE__ */ new Map();
- async function getInferenceProviderMapping(params, options) {
-   if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
-     return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
-   }
-   let inferenceProviderMapping;
-   if (inferenceProviderMappingCache.has(params.modelId)) {
-     inferenceProviderMapping = inferenceProviderMappingCache.get(params.modelId);
-   } else {
-     const resp = await (options?.fetch ?? fetch)(
-       `${HF_HUB_URL}/api/models/${params.modelId}?expand[]=inferenceProviderMapping`,
-       {
-         headers: params.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${params.accessToken}` } : {}
-       }
-     );
-     if (resp.status === 404) {
-       throw new Error(`Model ${params.modelId} does not exist`);
-     }
-     inferenceProviderMapping = await resp.json().then((json) => json.inferenceProviderMapping).catch(() => null);
-   }
-   if (!inferenceProviderMapping) {
-     throw new Error(`We have not been able to find inference provider information for model ${params.modelId}.`);
-   }
-   const providerMapping = inferenceProviderMapping[params.provider];
-   if (providerMapping) {
-     const equivalentTasks = params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task) ? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS : [params.task];
-     if (!typedInclude(equivalentTasks, providerMapping.task)) {
-       throw new Error(
-         `Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
-       );
-     }
-     if (providerMapping.status === "staging") {
-       console.warn(
-         `Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
-       );
-     }
-     return { ...providerMapping, hfModelId: params.modelId };
-   }
-   return null;
- }
+ var version = "3.12.0";
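
The provider-mapping block removed above implemented a memoized lookup against the Hub API; per the file list, that logic now lives in src/lib/getInferenceProviderMapping.ts, which this release reworks. Reduced to its core, the removed fetch-and-cache pattern looked like this (sketch; the real function also consults HARDCODED_MODEL_INFERENCE_MAPPING and checks task compatibility and staging status, and HF_HUB_URL is assumed in scope as it is in the bundle):

    const mappingCache = new Map<string, unknown>();

    async function fetchInferenceProviderMapping(modelId: string): Promise<unknown> {
        if (!mappingCache.has(modelId)) {
            const resp = await fetch(
                `${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`
            );
            if (resp.status === 404) {
                throw new Error(`Model ${modelId} does not exist`);
            }
            const json = await resp.json().catch(() => null);
            mappingCache.set(modelId, json?.inferenceProviderMapping ?? null);
        }
        return mappingCache.get(modelId);
    }
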

  // src/lib/makeRequestOptions.ts
  var tasks = null;
  async function makeRequestOptions(args, providerHelper, options) {
-   const { provider: maybeProvider, model: maybeModel } = args;
-   const provider = maybeProvider ?? "hf-inference";
+   const { model: maybeModel } = args;
+   const provider = providerHelper.provider;
    const { task } = options ?? {};
    if (args.endpointUrl && provider !== "hf-inference") {
      throw new Error(`Cannot use endpointUrl with a third-party provider.`);
@@ -1482,7 +1519,7 @@ async function makeRequestOptions(args, providerHelper, options) {
  }
  function makeRequestOptionsFromResolvedModel(resolvedModel, providerHelper, args, mapping, options) {
    const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
-   const provider = maybeProvider ?? "hf-inference";
+   const provider = providerHelper.provider;
    const { includeCredentials, task, signal, billTo } = options ?? {};
    const authMethod = (() => {
      if (providerHelper.clientSideRoutingOnly) {
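
Both hunks above replace the scattered `args.provider ?? "hf-inference"` fallback with the provider recorded on the helper itself, so the option builders can no longer disagree with the helper that was actually selected upstream. Schematically:

    // Before: two call sites could each apply their own default.
    //   const provider = args.provider ?? "hf-inference";
    // After: the helper chosen upstream is the single source of truth.
    interface ProviderHelperLike {
        provider: string; // e.g. "fal-ai", "groq", "hf-inference"
    }
    function providerOf(helper: ProviderHelperLike): string {
        return helper.provider;
    }
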
@@ -1773,7 +1810,8 @@ async function request(args, options) {
    console.warn(
      "The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
    );
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, options?.task);
    const result = await innerRequest(args, providerHelper, options);
    return result.data;
  }
@@ -1783,7 +1821,8 @@ async function* streamingRequest(args, options) {
    console.warn(
      "The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
    );
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, options?.task);
    yield* innerStreamingRequest(args, providerHelper, options);
  }
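
request and streamingRequest now go through resolveProvider as well, but they remain deprecated; the task functions are the supported path. A migration sketch (model ID illustrative):

    import { chatCompletion } from "@huggingface/inference";

    // Instead of the deprecated generic request({ ... }, { task: "conversational" }),
    // call the task function: it resolves the provider from (provider, model,
    // endpointUrl) and validates the response shape for you.
    const out = await chatCompletion({
        accessToken: process.env.HF_TOKEN,
        model: "meta-llama/Llama-3.1-8B-Instruct", // illustrative model ID
        messages: [{ role: "user", content: "Hello!" }],
    });
    console.log(out.choices[0].message.content);
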

@@ -1797,7 +1836,8 @@ function preparePayload(args) {

  // src/tasks/audio/audioClassification.ts
  async function audioClassification(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "audio-classification");
    const payload = preparePayload(args);
    const { data: res } = await innerRequest(payload, providerHelper, {
      ...options,
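
The same two-line change repeats through every task function below: first resolve the provider from (args.provider, args.model, args.endpointUrl), then pick the helper for the task. Callers need not change anything, and leaving provider unset now defers the choice to resolveProvider (presumably via the new "auto" policy; see PROVIDERS_OR_POLICIES near the end of this file). Sketch with an illustrative model:

    import { readFile } from "node:fs/promises";
    import { audioClassification } from "@huggingface/inference";

    // No explicit provider: resolveProvider picks one for this model.
    const labels = await audioClassification({
        accessToken: process.env.HF_TOKEN,
        model: "superb/hubert-base-superb-er", // illustrative model ID
        data: new Blob([await readFile("sample.flac")]),
    });
    console.log(labels);
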
@@ -1808,7 +1848,9 @@ async function audioClassification(args, options) {

  // src/tasks/audio/audioToAudio.ts
  async function audioToAudio(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
+   const model = "inputs" in args ? args.model : void 0;
+   const provider = await resolveProvider(args.provider, model);
+   const providerHelper = getProviderHelper(provider, "audio-to-audio");
    const payload = preparePayload(args);
    const { data: res } = await innerRequest(payload, providerHelper, {
      ...options,
@@ -1832,7 +1874,8 @@ function base64FromBytes(arr) {

  // src/tasks/audio/automaticSpeechRecognition.ts
  async function automaticSpeechRecognition(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
    const payload = await buildPayload(args);
    const { data: res } = await innerRequest(payload, providerHelper, {
      ...options,
@@ -1872,7 +1915,7 @@ async function buildPayload(args) {

  // src/tasks/audio/textToSpeech.ts
  async function textToSpeech(args, options) {
-   const provider = args.provider ?? "hf-inference";
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
    const providerHelper = getProviderHelper(provider, "text-to-speech");
    const { data: res } = await innerRequest(args, providerHelper, {
      ...options,
@@ -1888,7 +1931,8 @@ function preparePayload2(args) {

  // src/tasks/cv/imageClassification.ts
  async function imageClassification(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "image-classification");
    const payload = preparePayload2(args);
    const { data: res } = await innerRequest(payload, providerHelper, {
      ...options,
@@ -1899,7 +1943,8 @@ async function imageClassification(args, options) {

  // src/tasks/cv/imageSegmentation.ts
  async function imageSegmentation(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "image-segmentation");
    const payload = preparePayload2(args);
    const { data: res } = await innerRequest(payload, providerHelper, {
      ...options,
@@ -1910,7 +1955,8 @@ async function imageSegmentation(args, options) {

  // src/tasks/cv/imageToImage.ts
  async function imageToImage(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "image-to-image");
    let reqArgs;
    if (!args.parameters) {
      reqArgs = {
@@ -1935,7 +1981,8 @@ async function imageToImage(args, options) {

  // src/tasks/cv/imageToText.ts
  async function imageToText(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "image-to-text");
    const payload = preparePayload2(args);
    const { data: res } = await innerRequest(payload, providerHelper, {
      ...options,
@@ -1946,7 +1993,8 @@ async function imageToText(args, options) {

  // src/tasks/cv/objectDetection.ts
  async function objectDetection(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "object-detection");
    const payload = preparePayload2(args);
    const { data: res } = await innerRequest(payload, providerHelper, {
      ...options,
@@ -1957,7 +2005,7 @@ async function objectDetection(args, options) {

  // src/tasks/cv/textToImage.ts
  async function textToImage(args, options) {
-   const provider = args.provider ?? "hf-inference";
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
    const providerHelper = getProviderHelper(provider, "text-to-image");
    const { data: res } = await innerRequest(args, providerHelper, {
      ...options,
@@ -1969,7 +2017,7 @@ async function textToImage(args, options) {

  // src/tasks/cv/textToVideo.ts
  async function textToVideo(args, options) {
-   const provider = args.provider ?? "hf-inference";
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
    const providerHelper = getProviderHelper(provider, "text-to-video");
    const { data: response } = await innerRequest(
      args,
@@ -2006,7 +2054,8 @@ async function preparePayload3(args) {
    }
  }
  async function zeroShotImageClassification(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "zero-shot-image-classification");
    const payload = await preparePayload3(args);
    const { data: res } = await innerRequest(payload, providerHelper, {
      ...options,
@@ -2017,7 +2066,8 @@ async function zeroShotImageClassification(args, options) {

  // src/tasks/nlp/chatCompletion.ts
  async function chatCompletion(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "conversational");
    const { data: response } = await innerRequest(args, providerHelper, {
      ...options,
      task: "conversational"
@@ -2027,7 +2077,8 @@ async function chatCompletion(args, options) {

  // src/tasks/nlp/chatCompletionStream.ts
  async function* chatCompletionStream(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "conversational");
    yield* innerStreamingRequest(args, providerHelper, {
      ...options,
      task: "conversational"
@@ -2036,7 +2087,8 @@ async function* chatCompletionStream(args, options) {

  // src/tasks/nlp/featureExtraction.ts
  async function featureExtraction(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "feature-extraction");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "feature-extraction");
    const { data: res } = await innerRequest(args, providerHelper, {
      ...options,
      task: "feature-extraction"
@@ -2046,7 +2098,8 @@ async function featureExtraction(args, options) {

  // src/tasks/nlp/fillMask.ts
  async function fillMask(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "fill-mask");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "fill-mask");
    const { data: res } = await innerRequest(args, providerHelper, {
      ...options,
      task: "fill-mask"
@@ -2056,7 +2109,8 @@ async function fillMask(args, options) {

  // src/tasks/nlp/questionAnswering.ts
  async function questionAnswering(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "question-answering");
    const { data: res } = await innerRequest(
      args,
      providerHelper,
@@ -2070,7 +2124,8 @@ async function questionAnswering(args, options) {

  // src/tasks/nlp/sentenceSimilarity.ts
  async function sentenceSimilarity(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "sentence-similarity");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "sentence-similarity");
    const { data: res } = await innerRequest(args, providerHelper, {
      ...options,
      task: "sentence-similarity"
@@ -2080,7 +2135,8 @@ async function sentenceSimilarity(args, options) {

  // src/tasks/nlp/summarization.ts
  async function summarization(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "summarization");
    const { data: res } = await innerRequest(args, providerHelper, {
      ...options,
      task: "summarization"
@@ -2090,7 +2146,8 @@ async function summarization(args, options) {

  // src/tasks/nlp/tableQuestionAnswering.ts
  async function tableQuestionAnswering(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "table-question-answering");
    const { data: res } = await innerRequest(
      args,
      providerHelper,
@@ -2104,7 +2161,8 @@ async function tableQuestionAnswering(args, options) {

  // src/tasks/nlp/textClassification.ts
  async function textClassification(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "text-classification");
    const { data: res } = await innerRequest(args, providerHelper, {
      ...options,
      task: "text-classification"
@@ -2114,7 +2172,8 @@ async function textClassification(args, options) {

  // src/tasks/nlp/textGeneration.ts
  async function textGeneration(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "text-generation");
    const { data: response } = await innerRequest(args, providerHelper, {
      ...options,
      task: "text-generation"
@@ -2124,7 +2183,8 @@ async function textGeneration(args, options) {

  // src/tasks/nlp/textGenerationStream.ts
  async function* textGenerationStream(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "text-generation");
    yield* innerStreamingRequest(args, providerHelper, {
      ...options,
      task: "text-generation"
@@ -2133,7 +2193,8 @@ async function* textGenerationStream(args, options) {

  // src/tasks/nlp/tokenClassification.ts
  async function tokenClassification(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "token-classification");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "token-classification");
    const { data: res } = await innerRequest(
      args,
      providerHelper,
@@ -2147,7 +2208,8 @@ async function tokenClassification(args, options) {

  // src/tasks/nlp/translation.ts
  async function translation(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "translation");
    const { data: res } = await innerRequest(args, providerHelper, {
      ...options,
      task: "translation"
@@ -2157,7 +2219,8 @@ async function translation(args, options) {

  // src/tasks/nlp/zeroShotClassification.ts
  async function zeroShotClassification(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "zero-shot-classification");
    const { data: res } = await innerRequest(
      args,
      providerHelper,
@@ -2171,7 +2234,8 @@ async function zeroShotClassification(args, options) {

  // src/tasks/multimodal/documentQuestionAnswering.ts
  async function documentQuestionAnswering(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "document-question-answering");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "document-question-answering");
    const reqArgs = {
      ...args,
      inputs: {
@@ -2193,7 +2257,8 @@ async function documentQuestionAnswering(args, options) {

  // src/tasks/multimodal/visualQuestionAnswering.ts
  async function visualQuestionAnswering(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "visual-question-answering");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "visual-question-answering");
    const reqArgs = {
      ...args,
      inputs: {
@@ -2211,7 +2276,8 @@ async function visualQuestionAnswering(args, options) {

  // src/tasks/tabular/tabularClassification.ts
  async function tabularClassification(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "tabular-classification");
    const { data: res } = await innerRequest(args, providerHelper, {
      ...options,
      task: "tabular-classification"
@@ -2221,7 +2287,8 @@ async function tabularClassification(args, options) {

  // src/tasks/tabular/tabularRegression.ts
  async function tabularRegression(args, options) {
-   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression");
+   const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+   const providerHelper = getProviderHelper(provider, "tabular-regression");
    const { data: res } = await innerRequest(args, providerHelper, {
      ...options,
      task: "tabular-regression"
@@ -2229,6 +2296,11 @@ async function tabularRegression(args, options) {
    return providerHelper.getResponse(res);
  }

+ // src/utils/typedEntries.ts
+ function typedEntries(obj) {
+   return Object.entries(obj);
+ }
+
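
typedEntries is an identity wrapper around Object.entries whose value is purely type-level (hence the new src/utils/typedEntries.d.ts in the file list): it keeps the key union that Object.entries widens to string. Presumably declared along these lines:

    // Assumed signature; at runtime this is exactly Object.entries.
    function typedEntries<T extends object>(obj: T): [keyof T, T[keyof T]][] {
        return Object.entries(obj) as [keyof T, T[keyof T]][];
    }

    const routes = { groq: "/openai/v1/chat/completions" } as const;
    for (const [provider, route] of typedEntries(routes)) {
        // provider is typed "groq" rather than string
        console.log(provider, route);
    }
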
  // src/InferenceClient.ts
  var InferenceClient = class {
    accessToken;
@@ -2236,40 +2308,36 @@ var InferenceClient = class {
    constructor(accessToken = "", defaultOptions = {}) {
      this.accessToken = accessToken;
      this.defaultOptions = defaultOptions;
-     for (const [name2, fn] of Object.entries(tasks_exports)) {
+     for (const [name2, fn] of typedEntries(tasks_exports)) {
        Object.defineProperty(this, name2, {
          enumerable: false,
          value: (params, options) => (
            // eslint-disable-next-line @typescript-eslint/no-explicit-any
-           fn({ ...params, accessToken }, { ...defaultOptions, ...options })
+           fn(
+             /// ^ The cast of fn to any is necessary, otherwise TS can't compile because the generated union type is too complex
+             { endpointUrl: defaultOptions.endpointUrl, accessToken, ...params },
+             {
+               ...omit(defaultOptions, ["endpointUrl"]),
+               ...options
+             }
+           )
          )
        });
      }
    }
    /**
-    * Returns copy of InferenceClient tied to a specified endpoint.
+    * Returns a new instance of InferenceClient tied to a specified endpoint.
+    *
+    * For backward compatibility mostly.
     */
    endpoint(endpointUrl) {
-     return new InferenceClientEndpoint(endpointUrl, this.accessToken, this.defaultOptions);
-   }
- };
- var InferenceClientEndpoint = class {
-   constructor(endpointUrl, accessToken = "", defaultOptions = {}) {
-     accessToken;
-     defaultOptions;
-     for (const [name2, fn] of Object.entries(tasks_exports)) {
-       Object.defineProperty(this, name2, {
-         enumerable: false,
-         value: (params, options) => (
-           // eslint-disable-next-line @typescript-eslint/no-explicit-any
-           fn({ ...params, accessToken, endpointUrl }, { ...defaultOptions, ...options })
-         )
-       });
-     }
+     return new InferenceClient(this.accessToken, { ...this.defaultOptions, endpointUrl });
    }
  };
  var HfInference = class extends InferenceClient {
  };
+ var InferenceClientEndpoint = class extends InferenceClient {
+ };
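
The standalone InferenceClientEndpoint implementation is gone: endpoint() now returns a plain InferenceClient with endpointUrl folded into its default options, and the old class survives only as an empty alias for backward compatibility. Because the constructor spreads params after the defaults, a per-call endpointUrl still wins. Usage is unchanged (URL illustrative):

    import { InferenceClient } from "@huggingface/inference";

    const client = new InferenceClient(process.env.HF_TOKEN);

    // Returns a new InferenceClient whose defaultOptions carry endpointUrl.
    const endpoint = client.endpoint("https://my-endpoint.example.com");
    const out = await endpoint.textGeneration({ inputs: "Once upon a time" });
    console.log(out.generated_text);
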
 
  // src/types.ts
  var INFERENCE_PROVIDERS = [
@@ -2291,6 +2359,7 @@ var INFERENCE_PROVIDERS = [
    "sambanova",
    "together"
  ];
+ var PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"];
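
PROVIDERS_OR_POLICIES widens the set of accepted provider values with the "auto" routing policy that the resolveProvider calls above fall back on. As a type-level sketch (assumed shape, mirroring the runtime array; the provider list is abridged here):

    const INFERENCE_PROVIDERS = ["fal-ai", "groq", "hf-inference" /* ... */] as const;
    type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
    type InferenceProviderOrPolicy = InferenceProvider | "auto";

    const choice: InferenceProviderOrPolicy = "auto"; // let the library pick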
 
  // src/snippets/index.ts
  var snippets_exports = {};
@@ -2619,7 +2688,7 @@ var prepareConversationalInput = (model, opts) => {
    return {
      messages: opts?.messages ?? (0, import_tasks.getModelInputSnippet)(model),
      ...opts?.temperature ? { temperature: opts?.temperature } : void 0,
-     max_tokens: opts?.max_tokens ?? 512,
+     ...opts?.max_tokens ? { max_tokens: opts?.max_tokens } : void 0,
      ...opts?.top_p ? { top_p: opts?.top_p } : void 0
    };
  };
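
For generated code snippets, max_tokens is no longer forced to a 512 default: it is emitted only when the caller passes one, the same conditional-spread treatment temperature and top_p already receive. The effect on the produced payload:

    // Sketch of the resulting snippet payloads (messages elided):
    prepareConversationalInput(model, { max_tokens: 256 });
    // -> { messages: [...], max_tokens: 256 }

    prepareConversationalInput(model, {});
    // -> { messages: [...] }        (3.10.0 would emit max_tokens: 512)
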
@@ -2713,6 +2782,7 @@ function removeSuffix(str, suffix) {
    InferenceClient,
    InferenceClientEndpoint,
    InferenceOutputError,
+   PROVIDERS_OR_POLICIES,
    audioClassification,
    audioToAudio,
    automaticSpeechRecognition,