@huggingface/inference 3.9.2 → 3.11.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. package/README.md +9 -7
  2. package/dist/index.cjs +771 -646
  3. package/dist/index.js +770 -646
  4. package/dist/src/InferenceClient.d.ts +16 -17
  5. package/dist/src/InferenceClient.d.ts.map +1 -1
  6. package/dist/src/lib/getInferenceProviderMapping.d.ts +6 -2
  7. package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -1
  8. package/dist/src/lib/getProviderHelper.d.ts.map +1 -1
  9. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  10. package/dist/src/providers/consts.d.ts.map +1 -1
  11. package/dist/src/providers/ovhcloud.d.ts +38 -0
  12. package/dist/src/providers/ovhcloud.d.ts.map +1 -0
  13. package/dist/src/providers/providerHelper.d.ts +1 -1
  14. package/dist/src/providers/providerHelper.d.ts.map +1 -1
  15. package/dist/src/snippets/getInferenceSnippets.d.ts +1 -1
  16. package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
  17. package/dist/src/snippets/templates.exported.d.ts.map +1 -1
  18. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
  19. package/dist/src/tasks/audio/audioToAudio.d.ts.map +1 -1
  20. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  21. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  22. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  23. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  24. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
  25. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
  26. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
  27. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
  28. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
  29. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  30. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
  31. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
  32. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  33. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
  34. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  35. package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
  36. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  37. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
  38. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  39. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  40. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
  41. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  42. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
  43. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  44. package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
  45. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  46. package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
  47. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  48. package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
  49. package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
  50. package/dist/src/types.d.ts +7 -5
  51. package/dist/src/types.d.ts.map +1 -1
  52. package/dist/src/utils/typedEntries.d.ts +4 -0
  53. package/dist/src/utils/typedEntries.d.ts.map +1 -0
  54. package/package.json +3 -3
  55. package/src/InferenceClient.ts +32 -43
  56. package/src/lib/getInferenceProviderMapping.ts +68 -19
  57. package/src/lib/getProviderHelper.ts +5 -0
  58. package/src/lib/makeRequestOptions.ts +4 -3
  59. package/src/providers/consts.ts +1 -0
  60. package/src/providers/ovhcloud.ts +75 -0
  61. package/src/providers/providerHelper.ts +1 -1
  62. package/src/snippets/getInferenceSnippets.ts +5 -4
  63. package/src/snippets/templates.exported.ts +7 -3
  64. package/src/tasks/audio/audioClassification.ts +3 -1
  65. package/src/tasks/audio/audioToAudio.ts +4 -1
  66. package/src/tasks/audio/automaticSpeechRecognition.ts +3 -1
  67. package/src/tasks/audio/textToSpeech.ts +2 -1
  68. package/src/tasks/custom/request.ts +3 -1
  69. package/src/tasks/custom/streamingRequest.ts +3 -1
  70. package/src/tasks/cv/imageClassification.ts +3 -1
  71. package/src/tasks/cv/imageSegmentation.ts +3 -1
  72. package/src/tasks/cv/imageToImage.ts +3 -1
  73. package/src/tasks/cv/imageToText.ts +3 -1
  74. package/src/tasks/cv/objectDetection.ts +3 -1
  75. package/src/tasks/cv/textToImage.ts +2 -1
  76. package/src/tasks/cv/textToVideo.ts +2 -1
  77. package/src/tasks/cv/zeroShotImageClassification.ts +3 -1
  78. package/src/tasks/multimodal/documentQuestionAnswering.ts +3 -1
  79. package/src/tasks/multimodal/visualQuestionAnswering.ts +3 -1
  80. package/src/tasks/nlp/chatCompletion.ts +3 -1
  81. package/src/tasks/nlp/chatCompletionStream.ts +3 -1
  82. package/src/tasks/nlp/featureExtraction.ts +3 -1
  83. package/src/tasks/nlp/fillMask.ts +3 -1
  84. package/src/tasks/nlp/questionAnswering.ts +4 -1
  85. package/src/tasks/nlp/sentenceSimilarity.ts +3 -1
  86. package/src/tasks/nlp/summarization.ts +3 -1
  87. package/src/tasks/nlp/tableQuestionAnswering.ts +3 -1
  88. package/src/tasks/nlp/textClassification.ts +3 -1
  89. package/src/tasks/nlp/textGeneration.ts +3 -1
  90. package/src/tasks/nlp/textGenerationStream.ts +3 -1
  91. package/src/tasks/nlp/tokenClassification.ts +3 -1
  92. package/src/tasks/nlp/translation.ts +3 -1
  93. package/src/tasks/nlp/zeroShotClassification.ts +3 -1
  94. package/src/tasks/tabular/tabularClassification.ts +3 -1
  95. package/src/tasks/tabular/tabularRegression.ts +3 -1
  96. package/src/types.ts +9 -4
  97. package/src/utils/typedEntries.ts +5 -0
package/dist/index.js CHANGED
@@ -41,6 +41,38 @@ __export(tasks_exports, {
41
41
  zeroShotImageClassification: () => zeroShotImageClassification
42
42
  });
43
43
 
44
+ // src/config.ts
45
+ var HF_HUB_URL = "https://huggingface.co";
46
+ var HF_ROUTER_URL = "https://router.huggingface.co";
47
+ var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
48
+
49
+ // src/providers/consts.ts
50
+ var HARDCODED_MODEL_INFERENCE_MAPPING = {
51
+ /**
52
+ * "HF model ID" => "Model ID on Inference Provider's side"
53
+ *
54
+ * Example:
55
+ * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
56
+ */
57
+ "black-forest-labs": {},
58
+ cerebras: {},
59
+ cohere: {},
60
+ "fal-ai": {},
61
+ "featherless-ai": {},
62
+ "fireworks-ai": {},
63
+ groq: {},
64
+ "hf-inference": {},
65
+ hyperbolic: {},
66
+ nebius: {},
67
+ novita: {},
68
+ nscale: {},
69
+ openai: {},
70
+ ovhcloud: {},
71
+ replicate: {},
72
+ sambanova: {},
73
+ together: {}
74
+ };
75
+
44
76
  // src/lib/InferenceOutputError.ts
45
77
  var InferenceOutputError = class extends TypeError {
46
78
  constructor(message) {
@@ -51,42 +83,6 @@ var InferenceOutputError = class extends TypeError {
51
83
  }
52
84
  };
53
85
 
54
- // src/utils/delay.ts
55
- function delay(ms) {
56
- return new Promise((resolve) => {
57
- setTimeout(() => resolve(), ms);
58
- });
59
- }
60
-
61
- // src/utils/pick.ts
62
- function pick(o, props) {
63
- return Object.assign(
64
- {},
65
- ...props.map((prop) => {
66
- if (o[prop] !== void 0) {
67
- return { [prop]: o[prop] };
68
- }
69
- })
70
- );
71
- }
72
-
73
- // src/utils/typedInclude.ts
74
- function typedInclude(arr, v) {
75
- return arr.includes(v);
76
- }
77
-
78
- // src/utils/omit.ts
79
- function omit(o, props) {
80
- const propsArr = Array.isArray(props) ? props : [props];
81
- const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
82
- return pick(o, letsKeep);
83
- }
84
-
85
- // src/config.ts
86
- var HF_HUB_URL = "https://huggingface.co";
87
- var HF_ROUTER_URL = "https://router.huggingface.co";
88
- var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
89
-
90
86
  // src/utils/toArray.ts
91
87
  function toArray(obj) {
92
88
  if (Array.isArray(obj)) {
@@ -181,627 +177,736 @@ var BaseTextGenerationTask = class extends TaskProviderHelper {
181
177
  }
182
178
  };
183
179
 
184
- // src/providers/black-forest-labs.ts
185
- var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
186
- var BlackForestLabsTextToImageTask = class extends TaskProviderHelper {
180
+ // src/providers/hf-inference.ts
181
+ var EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = ["feature-extraction", "sentence-similarity"];
182
+ var HFInferenceTask = class extends TaskProviderHelper {
187
183
  constructor() {
188
- super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
184
+ super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
189
185
  }
190
186
  preparePayload(params) {
191
- return {
192
- ...omit(params.args, ["inputs", "parameters"]),
193
- ...params.args.parameters,
194
- prompt: params.args.inputs
195
- };
187
+ return params.args;
196
188
  }
197
- prepareHeaders(params, binary) {
198
- const headers = {
199
- Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`
200
- };
201
- if (!binary) {
202
- headers["Content-Type"] = "application/json";
189
+ makeUrl(params) {
190
+ if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
191
+ return params.model;
203
192
  }
204
- return headers;
193
+ return super.makeUrl(params);
205
194
  }
206
195
  makeRoute(params) {
207
- if (!params) {
208
- throw new Error("Params are required");
196
+ if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
197
+ return `pipeline/${params.task}/${params.model}`;
209
198
  }
210
- return `/v1/${params.model}`;
199
+ return `models/${params.model}`;
200
+ }
201
+ async getResponse(response) {
202
+ return response;
211
203
  }
204
+ };
205
+ var HFInferenceTextToImageTask = class extends HFInferenceTask {
212
206
  async getResponse(response, url, headers, outputType) {
213
- const urlObj = new URL(response.polling_url);
214
- for (let step = 0; step < 5; step++) {
215
- await delay(1e3);
216
- console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
217
- urlObj.searchParams.set("attempt", step.toString(10));
218
- const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
219
- if (!resp.ok) {
220
- throw new InferenceOutputError("Failed to fetch result from black forest labs API");
207
+ if (!response) {
208
+ throw new InferenceOutputError("response is undefined");
209
+ }
210
+ if (typeof response == "object") {
211
+ if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
212
+ const base64Data = response.data[0].b64_json;
213
+ if (outputType === "url") {
214
+ return `data:image/jpeg;base64,${base64Data}`;
215
+ }
216
+ const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
217
+ return await base64Response.blob();
221
218
  }
222
- const payload = await resp.json();
223
- if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
219
+ if ("output" in response && Array.isArray(response.output)) {
224
220
  if (outputType === "url") {
225
- return payload.result.sample;
221
+ return response.output[0];
226
222
  }
227
- const image = await fetch(payload.result.sample);
228
- return await image.blob();
223
+ const urlResponse = await fetch(response.output[0]);
224
+ const blob = await urlResponse.blob();
225
+ return blob;
229
226
  }
230
227
  }
231
- throw new InferenceOutputError("Failed to fetch result from black forest labs API");
228
+ if (response instanceof Blob) {
229
+ if (outputType === "url") {
230
+ const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
231
+ return `data:image/jpeg;base64,${b64}`;
232
+ }
233
+ return response;
234
+ }
235
+ throw new InferenceOutputError("Expected a Blob ");
232
236
  }
233
237
  };
234
-
235
- // src/providers/cerebras.ts
236
- var CerebrasConversationalTask = class extends BaseConversationalTask {
237
- constructor() {
238
- super("cerebras", "https://api.cerebras.ai");
238
+ var HFInferenceConversationalTask = class extends HFInferenceTask {
239
+ makeUrl(params) {
240
+ let url;
241
+ if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
242
+ url = params.model.trim();
243
+ } else {
244
+ url = `${this.makeBaseUrl(params)}/models/${params.model}`;
245
+ }
246
+ url = url.replace(/\/+$/, "");
247
+ if (url.endsWith("/v1")) {
248
+ url += "/chat/completions";
249
+ } else if (!url.endsWith("/chat/completions")) {
250
+ url += "/v1/chat/completions";
251
+ }
252
+ return url;
253
+ }
254
+ preparePayload(params) {
255
+ return {
256
+ ...params.args,
257
+ model: params.model
258
+ };
259
+ }
260
+ async getResponse(response) {
261
+ return response;
239
262
  }
240
263
  };
241
-
242
- // src/providers/cohere.ts
243
- var CohereConversationalTask = class extends BaseConversationalTask {
244
- constructor() {
245
- super("cohere", "https://api.cohere.com");
264
+ var HFInferenceTextGenerationTask = class extends HFInferenceTask {
265
+ async getResponse(response) {
266
+ const res = toArray(response);
267
+ if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) {
268
+ return res?.[0];
269
+ }
270
+ throw new InferenceOutputError("Expected Array<{generated_text: string}>");
246
271
  }
247
- makeRoute() {
248
- return "/compatibility/v1/chat/completions";
272
+ };
273
+ var HFInferenceAudioClassificationTask = class extends HFInferenceTask {
274
+ async getResponse(response) {
275
+ if (Array.isArray(response) && response.every(
276
+ (x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
277
+ )) {
278
+ return response;
279
+ }
280
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
249
281
  }
250
282
  };
251
-
252
- // src/lib/isUrl.ts
253
- function isUrl(modelOrUrl) {
254
- return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
255
- }
256
-
257
- // src/providers/fal-ai.ts
258
- var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
259
- var FalAITask = class extends TaskProviderHelper {
260
- constructor(url) {
261
- super("fal-ai", url || "https://fal.run");
283
+ var HFInferenceAutomaticSpeechRecognitionTask = class extends HFInferenceTask {
284
+ async getResponse(response) {
285
+ return response;
262
286
  }
263
- preparePayload(params) {
264
- return params.args;
287
+ };
288
+ var HFInferenceAudioToAudioTask = class extends HFInferenceTask {
289
+ async getResponse(response) {
290
+ if (!Array.isArray(response)) {
291
+ throw new InferenceOutputError("Expected Array");
292
+ }
293
+ if (!response.every((elem) => {
294
+ return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
295
+ })) {
296
+ throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
297
+ }
298
+ return response;
265
299
  }
266
- makeRoute(params) {
267
- return `/${params.model}`;
300
+ };
301
+ var HFInferenceDocumentQuestionAnsweringTask = class extends HFInferenceTask {
302
+ async getResponse(response) {
303
+ if (Array.isArray(response) && response.every(
304
+ (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
305
+ )) {
306
+ return response[0];
307
+ }
308
+ throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
268
309
  }
269
- prepareHeaders(params, binary) {
270
- const headers = {
271
- Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}`
310
+ };
311
+ var HFInferenceFeatureExtractionTask = class extends HFInferenceTask {
312
+ async getResponse(response) {
313
+ const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
314
+ if (curDepth > maxDepth)
315
+ return false;
316
+ if (arr.every((x) => Array.isArray(x))) {
317
+ return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
318
+ } else {
319
+ return arr.every((x) => typeof x === "number");
320
+ }
272
321
  };
273
- if (!binary) {
274
- headers["Content-Type"] = "application/json";
322
+ if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) {
323
+ return response;
275
324
  }
276
- return headers;
325
+ throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
277
326
  }
278
327
  };
279
- function buildLoraPath(modelId, adapterWeightsPath) {
280
- return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
281
- }
282
- var FalAITextToImageTask = class extends FalAITask {
283
- preparePayload(params) {
284
- const payload = {
285
- ...omit(params.args, ["inputs", "parameters"]),
286
- ...params.args.parameters,
287
- sync_mode: true,
288
- prompt: params.args.inputs
289
- };
290
- if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
291
- payload.loras = [
292
- {
293
- path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
294
- scale: 1
295
- }
296
- ];
297
- if (params.mapping.providerId === "fal-ai/lora") {
298
- payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0";
299
- }
328
+ var HFInferenceImageClassificationTask = class extends HFInferenceTask {
329
+ async getResponse(response) {
330
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
331
+ return response;
300
332
  }
301
- return payload;
333
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
302
334
  }
303
- async getResponse(response, outputType) {
304
- if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images.length > 0 && "url" in response.images[0] && typeof response.images[0].url === "string") {
305
- if (outputType === "url") {
306
- return response.images[0].url;
307
- }
308
- const urlResponse = await fetch(response.images[0].url);
309
- return await urlResponse.blob();
335
+ };
336
+ var HFInferenceImageSegmentationTask = class extends HFInferenceTask {
337
+ async getResponse(response) {
338
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number")) {
339
+ return response;
310
340
  }
311
- throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
341
+ throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
312
342
  }
313
343
  };
314
- var FalAITextToVideoTask = class extends FalAITask {
315
- constructor() {
316
- super("https://queue.fal.run");
317
- }
318
- makeRoute(params) {
319
- if (params.authMethod !== "provider-key") {
320
- return `/${params.model}?_subdomain=queue`;
344
+ var HFInferenceImageToTextTask = class extends HFInferenceTask {
345
+ async getResponse(response) {
346
+ if (typeof response?.generated_text !== "string") {
347
+ throw new InferenceOutputError("Expected {generated_text: string}");
321
348
  }
322
- return `/${params.model}`;
349
+ return response;
323
350
  }
324
- preparePayload(params) {
325
- return {
326
- ...omit(params.args, ["inputs", "parameters"]),
327
- ...params.args.parameters,
328
- prompt: params.args.inputs
329
- };
351
+ };
352
+ var HFInferenceImageToImageTask = class extends HFInferenceTask {
353
+ async getResponse(response) {
354
+ if (response instanceof Blob) {
355
+ return response;
356
+ }
357
+ throw new InferenceOutputError("Expected Blob");
330
358
  }
331
- async getResponse(response, url, headers) {
332
- if (!url || !headers) {
333
- throw new InferenceOutputError("URL and headers are required for text-to-video task");
359
+ };
360
+ var HFInferenceObjectDetectionTask = class extends HFInferenceTask {
361
+ async getResponse(response) {
362
+ if (Array.isArray(response) && response.every(
363
+ (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
364
+ )) {
365
+ return response;
334
366
  }
335
- const requestId = response.request_id;
336
- if (!requestId) {
337
- throw new InferenceOutputError("No request ID found in the response");
367
+ throw new InferenceOutputError(
368
+ "Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
369
+ );
370
+ }
371
+ };
372
+ var HFInferenceZeroShotImageClassificationTask = class extends HFInferenceTask {
373
+ async getResponse(response) {
374
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
375
+ return response;
338
376
  }
339
- let status = response.status;
340
- const parsedUrl = new URL(url);
341
- const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
342
- const modelId = new URL(response.response_url).pathname;
343
- const queryParams = parsedUrl.search;
344
- const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
345
- const resultUrl = `${baseUrl}${modelId}${queryParams}`;
346
- while (status !== "COMPLETED") {
347
- await delay(500);
348
- const statusResponse = await fetch(statusUrl, { headers });
349
- if (!statusResponse.ok) {
350
- throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
351
- }
352
- try {
353
- status = (await statusResponse.json()).status;
354
- } catch (error) {
355
- throw new InferenceOutputError("Failed to parse status response from fal-ai API");
356
- }
377
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
378
+ }
379
+ };
380
+ var HFInferenceTextClassificationTask = class extends HFInferenceTask {
381
+ async getResponse(response) {
382
+ const output = response?.[0];
383
+ if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
384
+ return output;
357
385
  }
358
- const resultResponse = await fetch(resultUrl, { headers });
359
- let result;
360
- try {
361
- result = await resultResponse.json();
362
- } catch (error) {
363
- throw new InferenceOutputError("Failed to parse result response from fal-ai API");
386
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
387
+ }
388
+ };
389
+ var HFInferenceQuestionAnsweringTask = class extends HFInferenceTask {
390
+ async getResponse(response) {
391
+ if (Array.isArray(response) ? response.every(
392
+ (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
393
+ ) : typeof response === "object" && !!response && typeof response.answer === "string" && typeof response.end === "number" && typeof response.score === "number" && typeof response.start === "number") {
394
+ return Array.isArray(response) ? response[0] : response;
364
395
  }
365
- if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
366
- const urlResponse = await fetch(result.video.url);
367
- return await urlResponse.blob();
368
- } else {
369
- throw new InferenceOutputError(
370
- "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
371
- );
396
+ throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
397
+ }
398
+ };
399
+ var HFInferenceFillMaskTask = class extends HFInferenceTask {
400
+ async getResponse(response) {
401
+ if (Array.isArray(response) && response.every(
402
+ (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
403
+ )) {
404
+ return response;
372
405
  }
406
+ throw new InferenceOutputError(
407
+ "Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
408
+ );
373
409
  }
374
410
  };
375
- var FalAIAutomaticSpeechRecognitionTask = class extends FalAITask {
376
- prepareHeaders(params, binary) {
377
- const headers = super.prepareHeaders(params, binary);
378
- headers["Content-Type"] = "application/json";
379
- return headers;
411
+ var HFInferenceZeroShotClassificationTask = class extends HFInferenceTask {
412
+ async getResponse(response) {
413
+ if (Array.isArray(response) && response.every(
414
+ (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
415
+ )) {
416
+ return response;
417
+ }
418
+ throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
380
419
  }
420
+ };
421
+ var HFInferenceSentenceSimilarityTask = class extends HFInferenceTask {
381
422
  async getResponse(response) {
382
- const res = response;
383
- if (typeof res?.text !== "string") {
384
- throw new InferenceOutputError(
385
- `Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
386
- );
423
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
424
+ return response;
387
425
  }
388
- return { text: res.text };
426
+ throw new InferenceOutputError("Expected Array<number>");
389
427
  }
390
428
  };
391
- var FalAITextToSpeechTask = class extends FalAITask {
392
- preparePayload(params) {
393
- return {
394
- ...omit(params.args, ["inputs", "parameters"]),
395
- ...params.args.parameters,
396
- text: params.args.inputs
397
- };
429
+ var HFInferenceTableQuestionAnsweringTask = class extends HFInferenceTask {
430
+ static validate(elem) {
431
+ return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
432
+ (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
433
+ );
398
434
  }
399
435
  async getResponse(response) {
400
- const res = response;
401
- if (typeof res?.audio?.url !== "string") {
402
- throw new InferenceOutputError(
403
- `Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
404
- );
436
+ if (Array.isArray(response) && Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response)) {
437
+ return Array.isArray(response) ? response[0] : response;
405
438
  }
406
- try {
407
- const urlResponse = await fetch(res.audio.url);
408
- if (!urlResponse.ok) {
409
- throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`);
410
- }
411
- return await urlResponse.blob();
412
- } catch (error) {
413
- throw new InferenceOutputError(
414
- `Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${error instanceof Error ? error.message : String(error)}`
415
- );
439
+ throw new InferenceOutputError(
440
+ "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
441
+ );
442
+ }
443
+ };
444
+ var HFInferenceTokenClassificationTask = class extends HFInferenceTask {
445
+ async getResponse(response) {
446
+ if (Array.isArray(response) && response.every(
447
+ (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
448
+ )) {
449
+ return response;
416
450
  }
451
+ throw new InferenceOutputError(
452
+ "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
453
+ );
417
454
  }
418
455
  };
419
-
420
- // src/providers/featherless-ai.ts
421
- var FEATHERLESS_API_BASE_URL = "https://api.featherless.ai";
422
- var FeatherlessAIConversationalTask = class extends BaseConversationalTask {
423
- constructor() {
424
- super("featherless-ai", FEATHERLESS_API_BASE_URL);
456
+ var HFInferenceTranslationTask = class extends HFInferenceTask {
457
+ async getResponse(response) {
458
+ if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) {
459
+ return response?.length === 1 ? response?.[0] : response;
460
+ }
461
+ throw new InferenceOutputError("Expected Array<{translation_text: string}>");
425
462
  }
426
463
  };
427
- var FeatherlessAITextGenerationTask = class extends BaseTextGenerationTask {
428
- constructor() {
429
- super("featherless-ai", FEATHERLESS_API_BASE_URL);
464
+ var HFInferenceSummarizationTask = class extends HFInferenceTask {
465
+ async getResponse(response) {
466
+ if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) {
467
+ return response?.[0];
468
+ }
469
+ throw new InferenceOutputError("Expected Array<{summary_text: string}>");
430
470
  }
431
- preparePayload(params) {
432
- return {
433
- ...params.args,
434
- ...params.args.parameters,
435
- model: params.model,
436
- prompt: params.args.inputs
437
- };
471
+ };
472
+ var HFInferenceTextToSpeechTask = class extends HFInferenceTask {
473
+ async getResponse(response) {
474
+ return response;
438
475
  }
476
+ };
477
+ var HFInferenceTabularClassificationTask = class extends HFInferenceTask {
439
478
  async getResponse(response) {
440
- if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
441
- const completion = response.choices[0];
442
- return {
443
- generated_text: completion.text
444
- };
479
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
480
+ return response;
445
481
  }
446
- throw new InferenceOutputError("Expected Featherless AI text generation response format");
482
+ throw new InferenceOutputError("Expected Array<number>");
447
483
  }
448
484
  };
449
-
450
- // src/providers/fireworks-ai.ts
451
- var FireworksConversationalTask = class extends BaseConversationalTask {
452
- constructor() {
453
- super("fireworks-ai", "https://api.fireworks.ai");
485
+ var HFInferenceVisualQuestionAnsweringTask = class extends HFInferenceTask {
486
+ async getResponse(response) {
487
+ if (Array.isArray(response) && response.every(
488
+ (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
489
+ )) {
490
+ return response[0];
491
+ }
492
+ throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
454
493
  }
455
- makeRoute() {
456
- return "/inference/v1/chat/completions";
494
+ };
495
+ var HFInferenceTabularRegressionTask = class extends HFInferenceTask {
496
+ async getResponse(response) {
497
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
498
+ return response;
499
+ }
500
+ throw new InferenceOutputError("Expected Array<number>");
501
+ }
502
+ };
503
+ var HFInferenceTextToAudioTask = class extends HFInferenceTask {
504
+ async getResponse(response) {
505
+ return response;
457
506
  }
458
507
  };
459
508
 
460
- // src/providers/groq.ts
461
- var GROQ_API_BASE_URL = "https://api.groq.com";
462
- var GroqTextGenerationTask = class extends BaseTextGenerationTask {
463
- constructor() {
464
- super("groq", GROQ_API_BASE_URL);
509
+ // src/utils/typedInclude.ts
510
+ function typedInclude(arr, v) {
511
+ return arr.includes(v);
512
+ }
513
+
514
+ // src/lib/getInferenceProviderMapping.ts
515
+ var inferenceProviderMappingCache = /* @__PURE__ */ new Map();
516
+ async function fetchInferenceProviderMappingForModel(modelId, accessToken, options) {
517
+ let inferenceProviderMapping;
518
+ if (inferenceProviderMappingCache.has(modelId)) {
519
+ inferenceProviderMapping = inferenceProviderMappingCache.get(modelId);
520
+ } else {
521
+ const resp = await (options?.fetch ?? fetch)(
522
+ `${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
523
+ {
524
+ headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {}
525
+ }
526
+ );
527
+ if (resp.status === 404) {
528
+ throw new Error(`Model ${modelId} does not exist`);
529
+ }
530
+ inferenceProviderMapping = await resp.json().then((json) => json.inferenceProviderMapping).catch(() => null);
531
+ if (inferenceProviderMapping) {
532
+ inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
533
+ }
465
534
  }
466
- makeRoute() {
467
- return "/openai/v1/chat/completions";
535
+ if (!inferenceProviderMapping) {
536
+ throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
468
537
  }
469
- };
470
- var GroqConversationalTask = class extends BaseConversationalTask {
471
- constructor() {
472
- super("groq", GROQ_API_BASE_URL);
538
+ return inferenceProviderMapping;
539
+ }
540
+ async function getInferenceProviderMapping(params, options) {
541
+ if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
542
+ return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
473
543
  }
474
- makeRoute() {
475
- return "/openai/v1/chat/completions";
544
+ const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
545
+ params.modelId,
546
+ params.accessToken,
547
+ options
548
+ );
549
+ const providerMapping = inferenceProviderMapping[params.provider];
550
+ if (providerMapping) {
551
+ const equivalentTasks = params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task) ? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS : [params.task];
552
+ if (!typedInclude(equivalentTasks, providerMapping.task)) {
553
+ throw new Error(
554
+ `Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
555
+ );
556
+ }
557
+ if (providerMapping.status === "staging") {
558
+ console.warn(
559
+ `Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
560
+ );
561
+ }
562
+ return { ...providerMapping, hfModelId: params.modelId };
476
563
  }
477
- };
564
+ return null;
565
+ }
566
+ async function resolveProvider(provider, modelId, endpointUrl) {
567
+ if (endpointUrl) {
568
+ if (provider) {
569
+ throw new Error("Specifying both endpointUrl and provider is not supported.");
570
+ }
571
+ return "hf-inference";
572
+ }
573
+ if (!provider) {
574
+ console.log(
575
+ "Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
576
+ );
577
+ provider = "auto";
578
+ }
579
+ if (provider === "auto") {
580
+ if (!modelId) {
581
+ throw new Error("Specifying a model is required when provider is 'auto'");
582
+ }
583
+ const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
584
+ provider = Object.keys(inferenceProviderMapping)[0];
585
+ }
586
+ if (!provider) {
587
+ throw new Error(`No Inference Provider available for model ${modelId}.`);
588
+ }
589
+ return provider;
590
+ }
478
591
 
479
- // src/providers/hf-inference.ts
480
- var EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = ["feature-extraction", "sentence-similarity"];
481
- var HFInferenceTask = class extends TaskProviderHelper {
592
+ // src/utils/delay.ts
593
+ function delay(ms) {
594
+ return new Promise((resolve) => {
595
+ setTimeout(() => resolve(), ms);
596
+ });
597
+ }
598
+
599
+ // src/utils/pick.ts
600
+ function pick(o, props) {
601
+ return Object.assign(
602
+ {},
603
+ ...props.map((prop) => {
604
+ if (o[prop] !== void 0) {
605
+ return { [prop]: o[prop] };
606
+ }
607
+ })
608
+ );
609
+ }
610
+
611
+ // src/utils/omit.ts
612
+ function omit(o, props) {
613
+ const propsArr = Array.isArray(props) ? props : [props];
614
+ const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
615
+ return pick(o, letsKeep);
616
+ }
617
+
618
+ // src/providers/black-forest-labs.ts
619
+ var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
620
+ var BlackForestLabsTextToImageTask = class extends TaskProviderHelper {
482
621
  constructor() {
483
- super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
622
+ super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
484
623
  }
485
624
  preparePayload(params) {
486
- return params.args;
625
+ return {
626
+ ...omit(params.args, ["inputs", "parameters"]),
627
+ ...params.args.parameters,
628
+ prompt: params.args.inputs
629
+ };
487
630
  }
488
- makeUrl(params) {
489
- if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
490
- return params.model;
631
+ prepareHeaders(params, binary) {
632
+ const headers = {
633
+ Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`
634
+ };
635
+ if (!binary) {
636
+ headers["Content-Type"] = "application/json";
491
637
  }
492
- return super.makeUrl(params);
638
+ return headers;
493
639
  }
494
640
  makeRoute(params) {
495
- if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
496
- return `pipeline/${params.task}/${params.model}`;
641
+ if (!params) {
642
+ throw new Error("Params are required");
497
643
  }
498
- return `models/${params.model}`;
499
- }
500
- async getResponse(response) {
501
- return response;
644
+ return `/v1/${params.model}`;
502
645
  }
503
- };
504
- var HFInferenceTextToImageTask = class extends HFInferenceTask {
505
646
  async getResponse(response, url, headers, outputType) {
506
- if (!response) {
507
- throw new InferenceOutputError("response is undefined");
508
- }
509
- if (typeof response == "object") {
510
- if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
511
- const base64Data = response.data[0].b64_json;
512
- if (outputType === "url") {
513
- return `data:image/jpeg;base64,${base64Data}`;
514
- }
515
- const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
516
- return await base64Response.blob();
647
+ const urlObj = new URL(response.polling_url);
648
+ for (let step = 0; step < 5; step++) {
649
+ await delay(1e3);
650
+ console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
651
+ urlObj.searchParams.set("attempt", step.toString(10));
652
+ const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
653
+ if (!resp.ok) {
654
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
517
655
  }
518
- if ("output" in response && Array.isArray(response.output)) {
656
+ const payload = await resp.json();
657
+ if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
519
658
  if (outputType === "url") {
520
- return response.output[0];
659
+ return payload.result.sample;
521
660
  }
522
- const urlResponse = await fetch(response.output[0]);
523
- const blob = await urlResponse.blob();
524
- return blob;
525
- }
526
- }
527
- if (response instanceof Blob) {
528
- if (outputType === "url") {
529
- const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
530
- return `data:image/jpeg;base64,${b64}`;
661
+ const image = await fetch(payload.result.sample);
662
+ return await image.blob();
531
663
  }
532
- return response;
533
664
  }
534
- throw new InferenceOutputError("Expected a Blob ");
665
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
535
666
  }
536
667
  };
537
- var HFInferenceConversationalTask = class extends HFInferenceTask {
538
- makeUrl(params) {
539
- let url;
540
- if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
541
- url = params.model.trim();
542
- } else {
543
- url = `${this.makeBaseUrl(params)}/models/${params.model}`;
544
- }
545
- url = url.replace(/\/+$/, "");
546
- if (url.endsWith("/v1")) {
547
- url += "/chat/completions";
548
- } else if (!url.endsWith("/chat/completions")) {
549
- url += "/v1/chat/completions";
550
- }
551
- return url;
552
- }
553
- preparePayload(params) {
554
- return {
555
- ...params.args,
556
- model: params.model
557
- };
558
- }
559
- async getResponse(response) {
560
- return response;
668
+
669
+ // src/providers/cerebras.ts
670
+ var CerebrasConversationalTask = class extends BaseConversationalTask {
671
+ constructor() {
672
+ super("cerebras", "https://api.cerebras.ai");
561
673
  }
562
674
  };
563
- var HFInferenceTextGenerationTask = class extends HFInferenceTask {
564
- async getResponse(response) {
565
- const res = toArray(response);
566
- if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) {
567
- return res?.[0];
568
- }
569
- throw new InferenceOutputError("Expected Array<{generated_text: string}>");
675
+
676
+ // src/providers/cohere.ts
677
+ var CohereConversationalTask = class extends BaseConversationalTask {
678
+ constructor() {
679
+ super("cohere", "https://api.cohere.com");
570
680
  }
571
- };
572
- var HFInferenceAudioClassificationTask = class extends HFInferenceTask {
573
- async getResponse(response) {
574
- if (Array.isArray(response) && response.every(
575
- (x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
576
- )) {
577
- return response;
578
- }
579
- throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
681
+ makeRoute() {
682
+ return "/compatibility/v1/chat/completions";
580
683
  }
581
684
  };
582
- var HFInferenceAutomaticSpeechRecognitionTask = class extends HFInferenceTask {
583
- async getResponse(response) {
584
- return response;
685
+
686
+ // src/lib/isUrl.ts
687
+ function isUrl(modelOrUrl) {
688
+ return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
689
+ }
690
+
691
+ // src/providers/fal-ai.ts
692
+ var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
693
+ var FalAITask = class extends TaskProviderHelper {
694
+ constructor(url) {
695
+ super("fal-ai", url || "https://fal.run");
585
696
  }
586
- };
587
- var HFInferenceAudioToAudioTask = class extends HFInferenceTask {
588
- async getResponse(response) {
589
- if (!Array.isArray(response)) {
590
- throw new InferenceOutputError("Expected Array");
591
- }
592
- if (!response.every((elem) => {
593
- return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
594
- })) {
595
- throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
596
- }
597
- return response;
697
+ preparePayload(params) {
698
+ return params.args;
598
699
  }
599
- };
600
- var HFInferenceDocumentQuestionAnsweringTask = class extends HFInferenceTask {
601
- async getResponse(response) {
602
- if (Array.isArray(response) && response.every(
603
- (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
604
- )) {
605
- return response[0];
606
- }
607
- throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
700
+ makeRoute(params) {
701
+ return `/${params.model}`;
608
702
  }
609
- };
610
- var HFInferenceFeatureExtractionTask = class extends HFInferenceTask {
611
- async getResponse(response) {
612
- const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
613
- if (curDepth > maxDepth)
614
- return false;
615
- if (arr.every((x) => Array.isArray(x))) {
616
- return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
617
- } else {
618
- return arr.every((x) => typeof x === "number");
619
- }
703
+ prepareHeaders(params, binary) {
704
+ const headers = {
705
+ Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}`
620
706
  };
621
- if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) {
622
- return response;
707
+ if (!binary) {
708
+ headers["Content-Type"] = "application/json";
623
709
  }
624
- throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
710
+ return headers;
625
711
  }
626
712
  };
627
- var HFInferenceImageClassificationTask = class extends HFInferenceTask {
628
- async getResponse(response) {
629
- if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
630
- return response;
713
+ function buildLoraPath(modelId, adapterWeightsPath) {
714
+ return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
715
+ }
716
+ var FalAITextToImageTask = class extends FalAITask {
717
+ preparePayload(params) {
718
+ const payload = {
719
+ ...omit(params.args, ["inputs", "parameters"]),
720
+ ...params.args.parameters,
721
+ sync_mode: true,
722
+ prompt: params.args.inputs
723
+ };
724
+ if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
725
+ payload.loras = [
726
+ {
727
+ path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
728
+ scale: 1
729
+ }
730
+ ];
731
+ if (params.mapping.providerId === "fal-ai/lora") {
732
+ payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0";
733
+ }
631
734
  }
632
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
735
+ return payload;
633
736
  }
634
- };
635
- var HFInferenceImageSegmentationTask = class extends HFInferenceTask {
636
- async getResponse(response) {
637
- if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number")) {
638
- return response;
737
+ async getResponse(response, outputType) {
738
+ if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images.length > 0 && "url" in response.images[0] && typeof response.images[0].url === "string") {
739
+ if (outputType === "url") {
740
+ return response.images[0].url;
741
+ }
742
+ const urlResponse = await fetch(response.images[0].url);
743
+ return await urlResponse.blob();
639
744
  }
640
- throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
745
+ throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
641
746
  }
642
747
  };
643
- var HFInferenceImageToTextTask = class extends HFInferenceTask {
644
- async getResponse(response) {
645
- if (typeof response?.generated_text !== "string") {
646
- throw new InferenceOutputError("Expected {generated_text: string}");
647
- }
648
- return response;
748
+ var FalAITextToVideoTask = class extends FalAITask {
749
+ constructor() {
750
+ super("https://queue.fal.run");
649
751
  }
650
- };
651
- var HFInferenceImageToImageTask = class extends HFInferenceTask {
652
- async getResponse(response) {
653
- if (response instanceof Blob) {
654
- return response;
752
+ makeRoute(params) {
753
+ if (params.authMethod !== "provider-key") {
754
+ return `/${params.model}?_subdomain=queue`;
655
755
  }
656
- throw new InferenceOutputError("Expected Blob");
756
+ return `/${params.model}`;
657
757
  }
658
- };
659
- var HFInferenceObjectDetectionTask = class extends HFInferenceTask {
660
- async getResponse(response) {
661
- if (Array.isArray(response) && response.every(
662
- (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
663
- )) {
664
- return response;
665
- }
666
- throw new InferenceOutputError(
667
- "Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
668
- );
758
+ preparePayload(params) {
759
+ return {
760
+ ...omit(params.args, ["inputs", "parameters"]),
761
+ ...params.args.parameters,
762
+ prompt: params.args.inputs
763
+ };
669
764
  }
670
- };
671
- var HFInferenceZeroShotImageClassificationTask = class extends HFInferenceTask {
672
- async getResponse(response) {
673
- if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
674
- return response;
765
+ async getResponse(response, url, headers) {
766
+ if (!url || !headers) {
767
+ throw new InferenceOutputError("URL and headers are required for text-to-video task");
675
768
  }
676
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
677
- }
678
- };
679
- var HFInferenceTextClassificationTask = class extends HFInferenceTask {
680
- async getResponse(response) {
681
- const output = response?.[0];
682
- if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
683
- return output;
769
+ const requestId = response.request_id;
770
+ if (!requestId) {
771
+ throw new InferenceOutputError("No request ID found in the response");
684
772
  }
685
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
686
- }
687
- };
688
- var HFInferenceQuestionAnsweringTask = class extends HFInferenceTask {
689
- async getResponse(response) {
690
- if (Array.isArray(response) ? response.every(
691
- (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
692
- ) : typeof response === "object" && !!response && typeof response.answer === "string" && typeof response.end === "number" && typeof response.score === "number" && typeof response.start === "number") {
693
- return Array.isArray(response) ? response[0] : response;
773
+ let status = response.status;
774
+ const parsedUrl = new URL(url);
775
+ const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
776
+ const modelId = new URL(response.response_url).pathname;
777
+ const queryParams = parsedUrl.search;
778
+ const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
779
+ const resultUrl = `${baseUrl}${modelId}${queryParams}`;
780
+ while (status !== "COMPLETED") {
781
+ await delay(500);
782
+ const statusResponse = await fetch(statusUrl, { headers });
783
+ if (!statusResponse.ok) {
784
+ throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
785
+ }
786
+ try {
787
+ status = (await statusResponse.json()).status;
788
+ } catch (error) {
789
+ throw new InferenceOutputError("Failed to parse status response from fal-ai API");
790
+ }
694
791
  }
695
- throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
696
- }
697
- };
698
- var HFInferenceFillMaskTask = class extends HFInferenceTask {
699
- async getResponse(response) {
700
- if (Array.isArray(response) && response.every(
701
- (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
702
- )) {
703
- return response;
792
+ const resultResponse = await fetch(resultUrl, { headers });
793
+ let result;
794
+ try {
795
+ result = await resultResponse.json();
796
+ } catch (error) {
797
+ throw new InferenceOutputError("Failed to parse result response from fal-ai API");
704
798
  }
705
- throw new InferenceOutputError(
706
- "Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
707
- );
708
- }
709
- };
710
- var HFInferenceZeroShotClassificationTask = class extends HFInferenceTask {
711
- async getResponse(response) {
712
- if (Array.isArray(response) && response.every(
713
- (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
714
- )) {
715
- return response;
799
+ if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
800
+ const urlResponse = await fetch(result.video.url);
801
+ return await urlResponse.blob();
802
+ } else {
803
+ throw new InferenceOutputError(
804
+ "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
805
+ );
716
806
  }
717
- throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
718
807
  }
719
808
  };
720
- var HFInferenceSentenceSimilarityTask = class extends HFInferenceTask {
809
+ var FalAIAutomaticSpeechRecognitionTask = class extends FalAITask {
810
+ prepareHeaders(params, binary) {
811
+ const headers = super.prepareHeaders(params, binary);
812
+ headers["Content-Type"] = "application/json";
813
+ return headers;
814
+ }
721
815
  async getResponse(response) {
722
- if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
723
- return response;
816
+ const res = response;
817
+ if (typeof res?.text !== "string") {
818
+ throw new InferenceOutputError(
819
+ `Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
820
+ );
724
821
  }
725
- throw new InferenceOutputError("Expected Array<number>");
822
+ return { text: res.text };
726
823
  }
727
824
  };
728
- var HFInferenceTableQuestionAnsweringTask = class extends HFInferenceTask {
729
- static validate(elem) {
730
- return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
731
- (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
732
- );
825
+ var FalAITextToSpeechTask = class extends FalAITask {
826
+ preparePayload(params) {
827
+ return {
828
+ ...omit(params.args, ["inputs", "parameters"]),
829
+ ...params.args.parameters,
830
+ text: params.args.inputs
831
+ };
733
832
  }
734
833
  async getResponse(response) {
735
- if (Array.isArray(response) && Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response)) {
736
- return Array.isArray(response) ? response[0] : response;
834
+ const res = response;
835
+ if (typeof res?.audio?.url !== "string") {
836
+ throw new InferenceOutputError(
837
+ `Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
838
+ );
737
839
  }
738
- throw new InferenceOutputError(
739
- "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
740
- );
741
- }
742
- };
743
- var HFInferenceTokenClassificationTask = class extends HFInferenceTask {
744
- async getResponse(response) {
745
- if (Array.isArray(response) && response.every(
746
- (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
747
- )) {
748
- return response;
840
+ try {
841
+ const urlResponse = await fetch(res.audio.url);
842
+ if (!urlResponse.ok) {
843
+ throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`);
844
+ }
845
+ return await urlResponse.blob();
846
+ } catch (error) {
847
+ throw new InferenceOutputError(
848
+ `Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${error instanceof Error ? error.message : String(error)}`
849
+ );
749
850
  }
750
- throw new InferenceOutputError(
751
- "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
752
- );
753
851
  }
754
852
  };
755
- var HFInferenceTranslationTask = class extends HFInferenceTask {
756
- async getResponse(response) {
757
- if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) {
758
- return response?.length === 1 ? response?.[0] : response;
759
- }
760
- throw new InferenceOutputError("Expected Array<{translation_text: string}>");
853
+
854
+ // src/providers/featherless-ai.ts
855
+ var FEATHERLESS_API_BASE_URL = "https://api.featherless.ai";
856
+ var FeatherlessAIConversationalTask = class extends BaseConversationalTask {
857
+ constructor() {
858
+ super("featherless-ai", FEATHERLESS_API_BASE_URL);
761
859
  }
762
860
  };
763
- var HFInferenceSummarizationTask = class extends HFInferenceTask {
764
- async getResponse(response) {
765
- if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) {
766
- return response?.[0];
767
- }
768
- throw new InferenceOutputError("Expected Array<{summary_text: string}>");
861
+ var FeatherlessAITextGenerationTask = class extends BaseTextGenerationTask {
862
+ constructor() {
863
+ super("featherless-ai", FEATHERLESS_API_BASE_URL);
769
864
  }
770
- };
771
- var HFInferenceTextToSpeechTask = class extends HFInferenceTask {
772
- async getResponse(response) {
773
- return response;
865
+ preparePayload(params) {
866
+ return {
867
+ ...params.args,
868
+ ...params.args.parameters,
869
+ model: params.model,
870
+ prompt: params.args.inputs
871
+ };
774
872
  }
775
- };
776
- var HFInferenceTabularClassificationTask = class extends HFInferenceTask {
777
873
  async getResponse(response) {
778
- if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
779
- return response;
874
+ if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
875
+ const completion = response.choices[0];
876
+ return {
877
+ generated_text: completion.text
878
+ };
780
879
  }
781
- throw new InferenceOutputError("Expected Array<number>");
880
+ throw new InferenceOutputError("Expected Featherless AI text generation response format");
782
881
  }
783
882
  };
784
- var HFInferenceVisualQuestionAnsweringTask = class extends HFInferenceTask {
785
- async getResponse(response) {
786
- if (Array.isArray(response) && response.every(
787
- (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
788
- )) {
789
- return response[0];
790
- }
791
- throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
883
+
884
+ // src/providers/fireworks-ai.ts
885
+ var FireworksConversationalTask = class extends BaseConversationalTask {
886
+ constructor() {
887
+ super("fireworks-ai", "https://api.fireworks.ai");
888
+ }
889
+ makeRoute() {
890
+ return "/inference/v1/chat/completions";
792
891
  }
793
892
  };
794
- var HFInferenceTabularRegressionTask = class extends HFInferenceTask {
795
- async getResponse(response) {
796
- if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
797
- return response;
798
- }
799
- throw new InferenceOutputError("Expected Array<number>");
893
+
894
+ // src/providers/groq.ts
895
+ var GROQ_API_BASE_URL = "https://api.groq.com";
896
+ var GroqTextGenerationTask = class extends BaseTextGenerationTask {
897
+ constructor() {
898
+ super("groq", GROQ_API_BASE_URL);
899
+ }
900
+ makeRoute() {
901
+ return "/openai/v1/chat/completions";
800
902
  }
801
903
  };
802
- var HFInferenceTextToAudioTask = class extends HFInferenceTask {
803
- async getResponse(response) {
804
- return response;
904
+ var GroqConversationalTask = class extends BaseConversationalTask {
905
+ constructor() {
906
+ super("groq", GROQ_API_BASE_URL);
907
+ }
908
+ makeRoute() {
909
+ return "/openai/v1/chat/completions";
805
910
  }
806
911
  };
807
912
 
@@ -968,6 +1073,39 @@ var OpenAIConversationalTask = class extends BaseConversationalTask {
968
1073
  }
969
1074
  };
970
1075
 
1076
+ // src/providers/ovhcloud.ts
1077
+ var OVHCLOUD_API_BASE_URL = "https://oai.endpoints.kepler.ai.cloud.ovh.net";
1078
+ var OvhCloudConversationalTask = class extends BaseConversationalTask {
1079
+ constructor() {
1080
+ super("ovhcloud", OVHCLOUD_API_BASE_URL);
1081
+ }
1082
+ };
1083
+ var OvhCloudTextGenerationTask = class extends BaseTextGenerationTask {
1084
+ constructor() {
1085
+ super("ovhcloud", OVHCLOUD_API_BASE_URL);
1086
+ }
1087
+ preparePayload(params) {
1088
+ return {
1089
+ model: params.model,
1090
+ ...omit(params.args, ["inputs", "parameters"]),
1091
+ ...params.args.parameters ? {
1092
+ max_tokens: params.args.parameters.max_new_tokens,
1093
+ ...omit(params.args.parameters, "max_new_tokens")
1094
+ } : void 0,
1095
+ prompt: params.args.inputs
1096
+ };
1097
+ }
1098
+ async getResponse(response) {
1099
+ if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
1100
+ const completion = response.choices[0];
1101
+ return {
1102
+ generated_text: completion.text
1103
+ };
1104
+ }
1105
+ throw new InferenceOutputError("Expected OVHcloud text generation response format");
1106
+ }
1107
+ };
1108
+
971
1109
  // src/providers/replicate.ts
972
1110
  var ReplicateTask = class extends TaskProviderHelper {
973
1111
  constructor(url) {
@@ -1220,6 +1358,10 @@ var PROVIDERS = {
1220
1358
  openai: {
1221
1359
  conversational: new OpenAIConversationalTask()
1222
1360
  },
1361
+ ovhcloud: {
1362
+ conversational: new OvhCloudConversationalTask(),
1363
+ "text-generation": new OvhCloudTextGenerationTask()
1364
+ },
1223
1365
  replicate: {
1224
1366
  "text-to-image": new ReplicateTextToImageTask(),
1225
1367
  "text-to-speech": new ReplicateTextToSpeechTask(),
@@ -1258,81 +1400,13 @@ function getProviderHelper(provider, task) {
1258
1400
 
1259
1401
  // package.json
1260
1402
  var name = "@huggingface/inference";
1261
- var version = "3.9.2";
1262
-
1263
- // src/providers/consts.ts
1264
- var HARDCODED_MODEL_INFERENCE_MAPPING = {
1265
- /**
1266
- * "HF model ID" => "Model ID on Inference Provider's side"
1267
- *
1268
- * Example:
1269
- * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
1270
- */
1271
- "black-forest-labs": {},
1272
- cerebras: {},
1273
- cohere: {},
1274
- "fal-ai": {},
1275
- "featherless-ai": {},
1276
- "fireworks-ai": {},
1277
- groq: {},
1278
- "hf-inference": {},
1279
- hyperbolic: {},
1280
- nebius: {},
1281
- novita: {},
1282
- nscale: {},
1283
- openai: {},
1284
- replicate: {},
1285
- sambanova: {},
1286
- together: {}
1287
- };
1288
-
1289
- // src/lib/getInferenceProviderMapping.ts
1290
- var inferenceProviderMappingCache = /* @__PURE__ */ new Map();
1291
- async function getInferenceProviderMapping(params, options) {
1292
- if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
1293
- return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
1294
- }
1295
- let inferenceProviderMapping;
1296
- if (inferenceProviderMappingCache.has(params.modelId)) {
1297
- inferenceProviderMapping = inferenceProviderMappingCache.get(params.modelId);
1298
- } else {
1299
- const resp = await (options?.fetch ?? fetch)(
1300
- `${HF_HUB_URL}/api/models/${params.modelId}?expand[]=inferenceProviderMapping`,
1301
- {
1302
- headers: params.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${params.accessToken}` } : {}
1303
- }
1304
- );
1305
- if (resp.status === 404) {
1306
- throw new Error(`Model ${params.modelId} does not exist`);
1307
- }
1308
- inferenceProviderMapping = await resp.json().then((json) => json.inferenceProviderMapping).catch(() => null);
1309
- }
1310
- if (!inferenceProviderMapping) {
1311
- throw new Error(`We have not been able to find inference provider information for model ${params.modelId}.`);
1312
- }
1313
- const providerMapping = inferenceProviderMapping[params.provider];
1314
- if (providerMapping) {
1315
- const equivalentTasks = params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task) ? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS : [params.task];
1316
- if (!typedInclude(equivalentTasks, providerMapping.task)) {
1317
- throw new Error(
1318
- `Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
1319
- );
1320
- }
1321
- if (providerMapping.status === "staging") {
1322
- console.warn(
1323
- `Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
1324
- );
1325
- }
1326
- return { ...providerMapping, hfModelId: params.modelId };
1327
- }
1328
- return null;
1329
- }
1403
+ var version = "3.11.0";
1330
1404
 
1331
1405
  // src/lib/makeRequestOptions.ts
1332
1406
  var tasks = null;
1333
1407
  async function makeRequestOptions(args, providerHelper, options) {
1334
- const { provider: maybeProvider, model: maybeModel } = args;
1335
- const provider = maybeProvider ?? "hf-inference";
1408
+ const { model: maybeModel } = args;
1409
+ const provider = providerHelper.provider;
1336
1410
  const { task } = options ?? {};
1337
1411
  if (args.endpointUrl && provider !== "hf-inference") {
1338
1412
  throw new Error(`Cannot use endpointUrl with a third-party provider.`);
@@ -1387,7 +1461,7 @@ async function makeRequestOptions(args, providerHelper, options) {
1387
1461
  }
1388
1462
  function makeRequestOptionsFromResolvedModel(resolvedModel, providerHelper, args, mapping, options) {
1389
1463
  const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
1390
- const provider = maybeProvider ?? "hf-inference";
1464
+ const provider = providerHelper.provider;
1391
1465
  const { includeCredentials, task, signal, billTo } = options ?? {};
1392
1466
  const authMethod = (() => {
1393
1467
  if (providerHelper.clientSideRoutingOnly) {
@@ -1678,7 +1752,8 @@ async function request(args, options) {
1678
1752
  console.warn(
1679
1753
  "The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
1680
1754
  );
1681
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
1755
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
1756
+ const providerHelper = getProviderHelper(provider, options?.task);
1682
1757
  const result = await innerRequest(args, providerHelper, options);
1683
1758
  return result.data;
1684
1759
  }
@@ -1688,7 +1763,8 @@ async function* streamingRequest(args, options) {
1688
1763
  console.warn(
1689
1764
  "The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
1690
1765
  );
1691
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
1766
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
1767
+ const providerHelper = getProviderHelper(provider, options?.task);
1692
1768
  yield* innerStreamingRequest(args, providerHelper, options);
1693
1769
  }
1694
1770
 
@@ -1702,7 +1778,8 @@ function preparePayload(args) {
1702
1778
 
1703
1779
  // src/tasks/audio/audioClassification.ts
1704
1780
  async function audioClassification(args, options) {
1705
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
1781
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
1782
+ const providerHelper = getProviderHelper(provider, "audio-classification");
1706
1783
  const payload = preparePayload(args);
1707
1784
  const { data: res } = await innerRequest(payload, providerHelper, {
1708
1785
  ...options,
@@ -1713,7 +1790,9 @@ async function audioClassification(args, options) {
1713
1790
 
1714
1791
  // src/tasks/audio/audioToAudio.ts
1715
1792
  async function audioToAudio(args, options) {
1716
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
1793
+ const model = "inputs" in args ? args.model : void 0;
1794
+ const provider = await resolveProvider(args.provider, model);
1795
+ const providerHelper = getProviderHelper(provider, "audio-to-audio");
1717
1796
  const payload = preparePayload(args);
1718
1797
  const { data: res } = await innerRequest(payload, providerHelper, {
1719
1798
  ...options,
@@ -1737,7 +1816,8 @@ function base64FromBytes(arr) {
1737
1816
 
1738
1817
  // src/tasks/audio/automaticSpeechRecognition.ts
1739
1818
  async function automaticSpeechRecognition(args, options) {
1740
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
1819
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
1820
+ const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
1741
1821
  const payload = await buildPayload(args);
1742
1822
  const { data: res } = await innerRequest(payload, providerHelper, {
1743
1823
  ...options,
@@ -1777,7 +1857,7 @@ async function buildPayload(args) {
1777
1857
 
1778
1858
  // src/tasks/audio/textToSpeech.ts
1779
1859
  async function textToSpeech(args, options) {
1780
- const provider = args.provider ?? "hf-inference";
1860
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
1781
1861
  const providerHelper = getProviderHelper(provider, "text-to-speech");
1782
1862
  const { data: res } = await innerRequest(args, providerHelper, {
1783
1863
  ...options,
@@ -1793,7 +1873,8 @@ function preparePayload2(args) {
1793
1873
 
1794
1874
  // src/tasks/cv/imageClassification.ts
1795
1875
  async function imageClassification(args, options) {
1796
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
1876
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
1877
+ const providerHelper = getProviderHelper(provider, "image-classification");
1797
1878
  const payload = preparePayload2(args);
1798
1879
  const { data: res } = await innerRequest(payload, providerHelper, {
1799
1880
  ...options,
@@ -1804,7 +1885,8 @@ async function imageClassification(args, options) {
1804
1885
 
1805
1886
  // src/tasks/cv/imageSegmentation.ts
1806
1887
  async function imageSegmentation(args, options) {
1807
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
1888
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
1889
+ const providerHelper = getProviderHelper(provider, "image-segmentation");
1808
1890
  const payload = preparePayload2(args);
1809
1891
  const { data: res } = await innerRequest(payload, providerHelper, {
1810
1892
  ...options,
@@ -1815,7 +1897,8 @@ async function imageSegmentation(args, options) {
1815
1897
 
1816
1898
  // src/tasks/cv/imageToImage.ts
1817
1899
  async function imageToImage(args, options) {
1818
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
1900
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
1901
+ const providerHelper = getProviderHelper(provider, "image-to-image");
1819
1902
  let reqArgs;
1820
1903
  if (!args.parameters) {
1821
1904
  reqArgs = {
@@ -1840,7 +1923,8 @@ async function imageToImage(args, options) {
1840
1923
 
1841
1924
  // src/tasks/cv/imageToText.ts
1842
1925
  async function imageToText(args, options) {
1843
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
1926
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
1927
+ const providerHelper = getProviderHelper(provider, "image-to-text");
1844
1928
  const payload = preparePayload2(args);
1845
1929
  const { data: res } = await innerRequest(payload, providerHelper, {
1846
1930
  ...options,
@@ -1851,7 +1935,8 @@ async function imageToText(args, options) {
1851
1935
 
1852
1936
  // src/tasks/cv/objectDetection.ts
1853
1937
  async function objectDetection(args, options) {
1854
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
1938
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
1939
+ const providerHelper = getProviderHelper(provider, "object-detection");
1855
1940
  const payload = preparePayload2(args);
1856
1941
  const { data: res } = await innerRequest(payload, providerHelper, {
1857
1942
  ...options,
@@ -1862,7 +1947,7 @@ async function objectDetection(args, options) {
1862
1947
 
1863
1948
  // src/tasks/cv/textToImage.ts
1864
1949
  async function textToImage(args, options) {
1865
- const provider = args.provider ?? "hf-inference";
1950
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
1866
1951
  const providerHelper = getProviderHelper(provider, "text-to-image");
1867
1952
  const { data: res } = await innerRequest(args, providerHelper, {
1868
1953
  ...options,
@@ -1874,7 +1959,7 @@ async function textToImage(args, options) {
1874
1959
 
1875
1960
  // src/tasks/cv/textToVideo.ts
1876
1961
  async function textToVideo(args, options) {
1877
- const provider = args.provider ?? "hf-inference";
1962
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
1878
1963
  const providerHelper = getProviderHelper(provider, "text-to-video");
1879
1964
  const { data: response } = await innerRequest(
1880
1965
  args,
@@ -1911,7 +1996,8 @@ async function preparePayload3(args) {
1911
1996
  }
1912
1997
  }
1913
1998
  async function zeroShotImageClassification(args, options) {
1914
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
1999
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2000
+ const providerHelper = getProviderHelper(provider, "zero-shot-image-classification");
1915
2001
  const payload = await preparePayload3(args);
1916
2002
  const { data: res } = await innerRequest(payload, providerHelper, {
1917
2003
  ...options,
@@ -1922,7 +2008,8 @@ async function zeroShotImageClassification(args, options) {
1922
2008
 
1923
2009
  // src/tasks/nlp/chatCompletion.ts
1924
2010
  async function chatCompletion(args, options) {
1925
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
2011
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2012
+ const providerHelper = getProviderHelper(provider, "conversational");
1926
2013
  const { data: response } = await innerRequest(args, providerHelper, {
1927
2014
  ...options,
1928
2015
  task: "conversational"
@@ -1932,7 +2019,8 @@ async function chatCompletion(args, options) {
1932
2019
 
1933
2020
  // src/tasks/nlp/chatCompletionStream.ts
1934
2021
  async function* chatCompletionStream(args, options) {
1935
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
2022
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2023
+ const providerHelper = getProviderHelper(provider, "conversational");
1936
2024
  yield* innerStreamingRequest(args, providerHelper, {
1937
2025
  ...options,
1938
2026
  task: "conversational"
@@ -1941,7 +2029,8 @@ async function* chatCompletionStream(args, options) {
1941
2029
 
1942
2030
  // src/tasks/nlp/featureExtraction.ts
1943
2031
  async function featureExtraction(args, options) {
1944
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "feature-extraction");
2032
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2033
+ const providerHelper = getProviderHelper(provider, "feature-extraction");
1945
2034
  const { data: res } = await innerRequest(args, providerHelper, {
1946
2035
  ...options,
1947
2036
  task: "feature-extraction"
@@ -1951,7 +2040,8 @@ async function featureExtraction(args, options) {
1951
2040
 
1952
2041
  // src/tasks/nlp/fillMask.ts
1953
2042
  async function fillMask(args, options) {
1954
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "fill-mask");
2043
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2044
+ const providerHelper = getProviderHelper(provider, "fill-mask");
1955
2045
  const { data: res } = await innerRequest(args, providerHelper, {
1956
2046
  ...options,
1957
2047
  task: "fill-mask"
@@ -1961,7 +2051,8 @@ async function fillMask(args, options) {
1961
2051
 
1962
2052
  // src/tasks/nlp/questionAnswering.ts
1963
2053
  async function questionAnswering(args, options) {
1964
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering");
2054
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2055
+ const providerHelper = getProviderHelper(provider, "question-answering");
1965
2056
  const { data: res } = await innerRequest(
1966
2057
  args,
1967
2058
  providerHelper,
@@ -1975,7 +2066,8 @@ async function questionAnswering(args, options) {
1975
2066
 
1976
2067
  // src/tasks/nlp/sentenceSimilarity.ts
1977
2068
  async function sentenceSimilarity(args, options) {
1978
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "sentence-similarity");
2069
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2070
+ const providerHelper = getProviderHelper(provider, "sentence-similarity");
1979
2071
  const { data: res } = await innerRequest(args, providerHelper, {
1980
2072
  ...options,
1981
2073
  task: "sentence-similarity"
@@ -1985,7 +2077,8 @@ async function sentenceSimilarity(args, options) {
1985
2077
 
1986
2078
  // src/tasks/nlp/summarization.ts
1987
2079
  async function summarization(args, options) {
1988
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization");
2080
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2081
+ const providerHelper = getProviderHelper(provider, "summarization");
1989
2082
  const { data: res } = await innerRequest(args, providerHelper, {
1990
2083
  ...options,
1991
2084
  task: "summarization"
@@ -1995,7 +2088,8 @@ async function summarization(args, options) {
1995
2088
 
1996
2089
  // src/tasks/nlp/tableQuestionAnswering.ts
1997
2090
  async function tableQuestionAnswering(args, options) {
1998
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
2091
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2092
+ const providerHelper = getProviderHelper(provider, "table-question-answering");
1999
2093
  const { data: res } = await innerRequest(
2000
2094
  args,
2001
2095
  providerHelper,
@@ -2009,7 +2103,8 @@ async function tableQuestionAnswering(args, options) {
2009
2103
 
2010
2104
  // src/tasks/nlp/textClassification.ts
2011
2105
  async function textClassification(args, options) {
2012
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
2106
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2107
+ const providerHelper = getProviderHelper(provider, "text-classification");
2013
2108
  const { data: res } = await innerRequest(args, providerHelper, {
2014
2109
  ...options,
2015
2110
  task: "text-classification"
@@ -2019,7 +2114,8 @@ async function textClassification(args, options) {
2019
2114
 
2020
2115
  // src/tasks/nlp/textGeneration.ts
2021
2116
  async function textGeneration(args, options) {
2022
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
2117
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2118
+ const providerHelper = getProviderHelper(provider, "text-generation");
2023
2119
  const { data: response } = await innerRequest(args, providerHelper, {
2024
2120
  ...options,
2025
2121
  task: "text-generation"
@@ -2029,7 +2125,8 @@ async function textGeneration(args, options) {
2029
2125
 
2030
2126
  // src/tasks/nlp/textGenerationStream.ts
2031
2127
  async function* textGenerationStream(args, options) {
2032
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
2128
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2129
+ const providerHelper = getProviderHelper(provider, "text-generation");
2033
2130
  yield* innerStreamingRequest(args, providerHelper, {
2034
2131
  ...options,
2035
2132
  task: "text-generation"
@@ -2038,7 +2135,8 @@ async function* textGenerationStream(args, options) {
2038
2135
 
2039
2136
  // src/tasks/nlp/tokenClassification.ts
2040
2137
  async function tokenClassification(args, options) {
2041
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "token-classification");
2138
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2139
+ const providerHelper = getProviderHelper(provider, "token-classification");
2042
2140
  const { data: res } = await innerRequest(
2043
2141
  args,
2044
2142
  providerHelper,
@@ -2052,7 +2150,8 @@ async function tokenClassification(args, options) {
2052
2150
 
2053
2151
  // src/tasks/nlp/translation.ts
2054
2152
  async function translation(args, options) {
2055
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation");
2153
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2154
+ const providerHelper = getProviderHelper(provider, "translation");
2056
2155
  const { data: res } = await innerRequest(args, providerHelper, {
2057
2156
  ...options,
2058
2157
  task: "translation"
@@ -2062,7 +2161,8 @@ async function translation(args, options) {
2062
2161
 
2063
2162
  // src/tasks/nlp/zeroShotClassification.ts
2064
2163
  async function zeroShotClassification(args, options) {
2065
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification");
2164
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2165
+ const providerHelper = getProviderHelper(provider, "zero-shot-classification");
2066
2166
  const { data: res } = await innerRequest(
2067
2167
  args,
2068
2168
  providerHelper,
@@ -2076,7 +2176,8 @@ async function zeroShotClassification(args, options) {
2076
2176
 
2077
2177
  // src/tasks/multimodal/documentQuestionAnswering.ts
2078
2178
  async function documentQuestionAnswering(args, options) {
2079
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "document-question-answering");
2179
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2180
+ const providerHelper = getProviderHelper(provider, "document-question-answering");
2080
2181
  const reqArgs = {
2081
2182
  ...args,
2082
2183
  inputs: {
@@ -2098,7 +2199,8 @@ async function documentQuestionAnswering(args, options) {
2098
2199
 
2099
2200
  // src/tasks/multimodal/visualQuestionAnswering.ts
2100
2201
  async function visualQuestionAnswering(args, options) {
2101
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "visual-question-answering");
2202
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2203
+ const providerHelper = getProviderHelper(provider, "visual-question-answering");
2102
2204
  const reqArgs = {
2103
2205
  ...args,
2104
2206
  inputs: {
@@ -2116,7 +2218,8 @@ async function visualQuestionAnswering(args, options) {
2116
2218
 
2117
2219
  // src/tasks/tabular/tabularClassification.ts
2118
2220
  async function tabularClassification(args, options) {
2119
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification");
2221
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2222
+ const providerHelper = getProviderHelper(provider, "tabular-classification");
2120
2223
  const { data: res } = await innerRequest(args, providerHelper, {
2121
2224
  ...options,
2122
2225
  task: "tabular-classification"
@@ -2126,7 +2229,8 @@ async function tabularClassification(args, options) {
2126
2229
 
2127
2230
  // src/tasks/tabular/tabularRegression.ts
2128
2231
  async function tabularRegression(args, options) {
2129
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression");
2232
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2233
+ const providerHelper = getProviderHelper(provider, "tabular-regression");
2130
2234
  const { data: res } = await innerRequest(args, providerHelper, {
2131
2235
  ...options,
2132
2236
  task: "tabular-regression"
@@ -2134,6 +2238,11 @@ async function tabularRegression(args, options) {
2134
2238
  return providerHelper.getResponse(res);
2135
2239
  }
2136
2240
 
2241
+ // src/utils/typedEntries.ts
2242
+ function typedEntries(obj) {
2243
+ return Object.entries(obj);
2244
+ }
2245
+
2137
2246
  // src/InferenceClient.ts
2138
2247
  var InferenceClient = class {
2139
2248
  accessToken;
@@ -2141,40 +2250,36 @@ var InferenceClient = class {
2141
2250
  constructor(accessToken = "", defaultOptions = {}) {
2142
2251
  this.accessToken = accessToken;
2143
2252
  this.defaultOptions = defaultOptions;
2144
- for (const [name2, fn] of Object.entries(tasks_exports)) {
2253
+ for (const [name2, fn] of typedEntries(tasks_exports)) {
2145
2254
  Object.defineProperty(this, name2, {
2146
2255
  enumerable: false,
2147
2256
  value: (params, options) => (
2148
2257
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
2149
- fn({ ...params, accessToken }, { ...defaultOptions, ...options })
2258
+ fn(
2259
+ /// ^ The cast of fn to any is necessary, otherwise TS can't compile because the generated union type is too complex
2260
+ { endpointUrl: defaultOptions.endpointUrl, accessToken, ...params },
2261
+ {
2262
+ ...omit(defaultOptions, ["endpointUrl"]),
2263
+ ...options
2264
+ }
2265
+ )
2150
2266
  )
2151
2267
  });
2152
2268
  }
2153
2269
  }
2154
2270
  /**
2155
- * Returns copy of InferenceClient tied to a specified endpoint.
2271
+ * Returns a new instance of InferenceClient tied to a specified endpoint.
2272
+ *
2273
+ * For backward compatibility mostly.
2156
2274
  */
2157
2275
  endpoint(endpointUrl) {
2158
- return new InferenceClientEndpoint(endpointUrl, this.accessToken, this.defaultOptions);
2159
- }
2160
- };
2161
- var InferenceClientEndpoint = class {
2162
- constructor(endpointUrl, accessToken = "", defaultOptions = {}) {
2163
- accessToken;
2164
- defaultOptions;
2165
- for (const [name2, fn] of Object.entries(tasks_exports)) {
2166
- Object.defineProperty(this, name2, {
2167
- enumerable: false,
2168
- value: (params, options) => (
2169
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
2170
- fn({ ...params, accessToken, endpointUrl }, { ...defaultOptions, ...options })
2171
- )
2172
- });
2173
- }
2276
+ return new InferenceClient(this.accessToken, { ...this.defaultOptions, endpointUrl });
2174
2277
  }
2175
2278
  };
2176
2279
  var HfInference = class extends InferenceClient {
2177
2280
  };
2281
+ var InferenceClientEndpoint = class extends InferenceClient {
2282
+ };
2178
2283
 
2179
2284
  // src/types.ts
2180
2285
  var INFERENCE_PROVIDERS = [
@@ -2191,10 +2296,12 @@ var INFERENCE_PROVIDERS = [
2191
2296
  "novita",
2192
2297
  "nscale",
2193
2298
  "openai",
2299
+ "ovhcloud",
2194
2300
  "replicate",
2195
2301
  "sambanova",
2196
2302
  "together"
2197
2303
  ];
2304
+ var PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"];
2198
2305
 
2199
2306
  // src/snippets/index.ts
2200
2307
  var snippets_exports = {};
@@ -2218,6 +2325,7 @@ var templates = {
2218
2325
  "basicImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "image/jpeg",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
2219
2326
  "textToAudio": '{% if model.library_name == "transformers" %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n // Returns a byte object of the Audio wavform. Use it directly!\n});\n{% else %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});\n{% endif %} ',
2220
2327
  "textToImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\n\nquery({ {{ providerInputs.asTsString }} }).then((response) => {\n // Use image\n});',
2328
+ "textToSpeech": '{% if model.library_name == "transformers" %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\nquery({ text: {{ inputs.asObj.inputs }} }).then((response) => {\n // Returns a byte object of the Audio wavform. Use it directly!\n});\n{% else %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ text: {{ inputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});\n{% endif %} ',
2221
2329
  "zeroShotClassification": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({\n inputs: {{ providerInputs.asObj.inputs }},\n parameters: { candidate_labels: ["refund", "legal", "faq"] }\n}).then((response) => {\n console.log(JSON.stringify(response));\n});'
2222
2330
  },
2223
2331
  "huggingface.js": {
@@ -2239,11 +2347,23 @@ const image = await client.textToImage({
2239
2347
  billTo: "{{ billTo }}",
2240
2348
  }{% endif %});
2241
2349
  /// Use the generated image (it's a Blob)`,
2350
+ "textToSpeech": `import { InferenceClient } from "@huggingface/inference";
2351
+
2352
+ const client = new InferenceClient("{{ accessToken }}");
2353
+
2354
+ const audio = await client.textToSpeech({
2355
+ provider: "{{ provider }}",
2356
+ model: "{{ model.id }}",
2357
+ inputs: {{ inputs.asObj.inputs }},
2358
+ }{% if billTo %}, {
2359
+ billTo: "{{ billTo }}",
2360
+ }{% endif %});
2361
+ // Use the generated audio (it's a Blob)`,
2242
2362
  "textToVideo": `import { InferenceClient } from "@huggingface/inference";
2243
2363
 
2244
2364
  const client = new InferenceClient("{{ accessToken }}");
2245
2365
 
2246
- const image = await client.textToVideo({
2366
+ const video = await client.textToVideo({
2247
2367
  provider: "{{ provider }}",
2248
2368
  model: "{{ model.id }}",
2249
2369
  inputs: {{ inputs.asObj.inputs }},
@@ -2259,7 +2379,7 @@ const image = await client.textToVideo({
2259
2379
  },
2260
2380
  "python": {
2261
2381
  "fal_client": {
2262
- "textToImage": '{% if provider == "fal-ai" %}\nimport fal_client\n\nresult = fal_client.subscribe(\n "{{ providerModelId }}",\n arguments={\n "prompt": {{ inputs.asObj.inputs }},\n },\n)\nprint(result)\n{% endif %} '
2382
+ "textToImage": '{% if provider == "fal-ai" %}\nimport fal_client\n\n{% if providerInputs.asObj.loras is defined and providerInputs.asObj.loras != none %}\nresult = fal_client.subscribe(\n "{{ providerModelId }}",\n arguments={\n "prompt": {{ inputs.asObj.inputs }},\n "loras":{{ providerInputs.asObj.loras | tojson }},\n },\n)\n{% else %}\nresult = fal_client.subscribe(\n "{{ providerModelId }}",\n arguments={\n "prompt": {{ inputs.asObj.inputs }},\n },\n)\n{% endif %} \nprint(result)\n{% endif %} '
2263
2383
  },
2264
2384
  "huggingface_hub": {
2265
2385
  "basic": 'result = client.{{ methodName }}(\n inputs={{ inputs.asObj.inputs }},\n model="{{ model.id }}",\n)',
@@ -2271,6 +2391,7 @@ const image = await client.textToVideo({
2271
2391
  "imageToImage": '# output is a PIL.Image object\nimage = client.image_to_image(\n "{{ inputs.asObj.inputs }}",\n prompt="{{ inputs.asObj.parameters.prompt }}",\n model="{{ model.id }}",\n) ',
2272
2392
  "importInferenceClient": 'from huggingface_hub import InferenceClient\n\nclient = InferenceClient(\n provider="{{ provider }}",\n api_key="{{ accessToken }}",\n{% if billTo %}\n bill_to="{{ billTo }}",\n{% endif %}\n)',
2273
2393
  "textToImage": '# output is a PIL.Image object\nimage = client.text_to_image(\n {{ inputs.asObj.inputs }},\n model="{{ model.id }}",\n) ',
2394
+ "textToSpeech": '# audio is returned as bytes\naudio = client.text_to_speech(\n {{ inputs.asObj.inputs }},\n model="{{ model.id }}",\n) \n',
2274
2395
  "textToVideo": 'video = client.text_to_video(\n {{ inputs.asObj.inputs }},\n model="{{ model.id }}",\n) '
2275
2396
  },
2276
2397
  "openai": {
@@ -2287,8 +2408,9 @@ const image = await client.textToVideo({
2287
2408
  "imageToImage": 'def query(payload):\n with open(payload["inputs"], "rb") as f:\n img = f.read()\n payload["inputs"] = base64.b64encode(img).decode("utf-8")\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nimage_bytes = query({\n{{ providerInputs.asJsonString }}\n})\n\n# You can access the image with PIL.Image for example\nimport io\nfrom PIL import Image\nimage = Image.open(io.BytesIO(image_bytes)) ',
2288
2409
  "importRequests": '{% if importBase64 %}\nimport base64\n{% endif %}\n{% if importJson %}\nimport json\n{% endif %}\nimport requests\n\nAPI_URL = "{{ fullUrl }}"\nheaders = {\n "Authorization": "{{ authorizationHeader }}",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}"\n{% endif %}\n}',
2289
2410
  "tabular": 'def query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nresponse = query({\n "inputs": {\n "data": {{ providerInputs.asObj.inputs }}\n },\n}) ',
2290
- "textToAudio": '{% if model.library_name == "transformers" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\naudio_bytes = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio_bytes)\n{% else %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\naudio, sampling_rate = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio, rate=sampling_rate)\n{% endif %} ',
2411
+ "textToAudio": '{% if model.library_name == "transformers" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\naudio_bytes = query({\n "inputs": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio_bytes)\n{% else %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\naudio, sampling_rate = query({\n "inputs": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio, rate=sampling_rate)\n{% endif %} ',
2291
2412
  "textToImage": '{% if provider == "hf-inference" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nimage_bytes = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n})\n\n# You can access the image with PIL.Image for example\nimport io\nfrom PIL import Image\nimage = Image.open(io.BytesIO(image_bytes))\n{% endif %}',
2413
+ "textToSpeech": '{% if model.library_name == "transformers" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\naudio_bytes = query({\n "text": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio_bytes)\n{% else %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\naudio, sampling_rate = query({\n "text": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio, rate=sampling_rate)\n{% endif %} ',
2292
2414
  "zeroShotClassification": 'def query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\noutput = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n "parameters": {"candidate_labels": ["refund", "legal", "faq"]},\n}) ',
2293
2415
  "zeroShotImageClassification": 'def query(data):\n with open(data["image_path"], "rb") as f:\n img = f.read()\n payload={\n "parameters": data["parameters"],\n "inputs": base64.b64encode(img).decode("utf-8")\n }\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\noutput = query({\n "image_path": {{ providerInputs.asObj.inputs }},\n "parameters": {"candidate_labels": ["cat", "dog", "llama"]},\n}) '
2294
2416
  }
@@ -2391,6 +2513,7 @@ var HF_JS_METHODS = {
2391
2513
  "text-generation": "textGeneration",
2392
2514
  "text2text-generation": "textGeneration",
2393
2515
  "token-classification": "tokenClassification",
2516
+ "text-to-speech": "textToSpeech",
2394
2517
  translation: "translation"
2395
2518
  };
2396
2519
  var snippetGenerator = (templateName, inputPreparationFn) => {
@@ -2510,7 +2633,7 @@ var prepareConversationalInput = (model, opts) => {
2510
2633
  return {
2511
2634
  messages: opts?.messages ?? getModelInputSnippet(model),
2512
2635
  ...opts?.temperature ? { temperature: opts?.temperature } : void 0,
2513
- max_tokens: opts?.max_tokens ?? 512,
2636
+ ...opts?.max_tokens ? { max_tokens: opts?.max_tokens } : void 0,
2514
2637
  ...opts?.top_p ? { top_p: opts?.top_p } : void 0
2515
2638
  };
2516
2639
  };
@@ -2537,7 +2660,7 @@ var snippets = {
2537
2660
  "text-generation": snippetGenerator("basic"),
2538
2661
  "text-to-audio": snippetGenerator("textToAudio"),
2539
2662
  "text-to-image": snippetGenerator("textToImage"),
2540
- "text-to-speech": snippetGenerator("textToAudio"),
2663
+ "text-to-speech": snippetGenerator("textToSpeech"),
2541
2664
  "text-to-video": snippetGenerator("textToVideo"),
2542
2665
  "text2text-generation": snippetGenerator("basic"),
2543
2666
  "token-classification": snippetGenerator("basic"),
@@ -2603,6 +2726,7 @@ export {
2603
2726
  InferenceClient,
2604
2727
  InferenceClientEndpoint,
2605
2728
  InferenceOutputError,
2729
+ PROVIDERS_OR_POLICIES,
2606
2730
  audioClassification,
2607
2731
  audioToAudio,
2608
2732
  automaticSpeechRecognition,