@huggingface/inference 3.3.6 → 3.4.0

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their public registries.
Files changed (86)
  1. package/README.md +2 -0
  2. package/dist/index.cjs +339 -174
  3. package/dist/index.js +339 -174
  4. package/dist/src/lib/getProviderModelId.d.ts +1 -1
  5. package/dist/src/lib/getProviderModelId.d.ts.map +1 -1
  6. package/dist/src/lib/makeRequestOptions.d.ts +2 -2
  7. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  8. package/dist/src/providers/black-forest-labs.d.ts +2 -1
  9. package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
  10. package/dist/src/providers/cohere.d.ts +19 -0
  11. package/dist/src/providers/cohere.d.ts.map +1 -0
  12. package/dist/src/providers/consts.d.ts.map +1 -1
  13. package/dist/src/providers/fal-ai.d.ts +2 -1
  14. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  15. package/dist/src/providers/fireworks-ai.d.ts +2 -1
  16. package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
  17. package/dist/src/providers/hf-inference.d.ts +3 -0
  18. package/dist/src/providers/hf-inference.d.ts.map +1 -0
  19. package/dist/src/providers/hyperbolic.d.ts +2 -1
  20. package/dist/src/providers/hyperbolic.d.ts.map +1 -1
  21. package/dist/src/providers/nebius.d.ts +2 -1
  22. package/dist/src/providers/nebius.d.ts.map +1 -1
  23. package/dist/src/providers/novita.d.ts +2 -1
  24. package/dist/src/providers/novita.d.ts.map +1 -1
  25. package/dist/src/providers/replicate.d.ts +3 -1
  26. package/dist/src/providers/replicate.d.ts.map +1 -1
  27. package/dist/src/providers/sambanova.d.ts +2 -1
  28. package/dist/src/providers/sambanova.d.ts.map +1 -1
  29. package/dist/src/providers/together.d.ts +2 -1
  30. package/dist/src/providers/together.d.ts.map +1 -1
  31. package/dist/src/tasks/custom/request.d.ts +2 -4
  32. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  33. package/dist/src/tasks/custom/streamingRequest.d.ts +2 -4
  34. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  35. package/dist/src/tasks/nlp/featureExtraction.d.ts +2 -9
  36. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  37. package/dist/src/types.d.ts +25 -4
  38. package/dist/src/types.d.ts.map +1 -1
  39. package/package.json +2 -2
  40. package/src/lib/getProviderModelId.ts +4 -4
  41. package/src/lib/makeRequestOptions.ts +74 -186
  42. package/src/providers/black-forest-labs.ts +26 -2
  43. package/src/providers/cohere.ts +42 -0
  44. package/src/providers/consts.ts +2 -1
  45. package/src/providers/fal-ai.ts +24 -2
  46. package/src/providers/fireworks-ai.ts +28 -2
  47. package/src/providers/hf-inference.ts +43 -0
  48. package/src/providers/hyperbolic.ts +28 -2
  49. package/src/providers/nebius.ts +34 -2
  50. package/src/providers/novita.ts +31 -2
  51. package/src/providers/replicate.ts +30 -2
  52. package/src/providers/sambanova.ts +28 -2
  53. package/src/providers/together.ts +34 -2
  54. package/src/tasks/audio/audioClassification.ts +1 -1
  55. package/src/tasks/audio/audioToAudio.ts +1 -1
  56. package/src/tasks/audio/automaticSpeechRecognition.ts +1 -1
  57. package/src/tasks/audio/textToSpeech.ts +1 -1
  58. package/src/tasks/custom/request.ts +2 -4
  59. package/src/tasks/custom/streamingRequest.ts +2 -4
  60. package/src/tasks/cv/imageClassification.ts +1 -1
  61. package/src/tasks/cv/imageSegmentation.ts +1 -1
  62. package/src/tasks/cv/imageToImage.ts +1 -1
  63. package/src/tasks/cv/imageToText.ts +1 -1
  64. package/src/tasks/cv/objectDetection.ts +1 -1
  65. package/src/tasks/cv/textToImage.ts +1 -1
  66. package/src/tasks/cv/textToVideo.ts +1 -1
  67. package/src/tasks/cv/zeroShotImageClassification.ts +1 -1
  68. package/src/tasks/multimodal/documentQuestionAnswering.ts +1 -1
  69. package/src/tasks/multimodal/visualQuestionAnswering.ts +1 -1
  70. package/src/tasks/nlp/chatCompletion.ts +1 -1
  71. package/src/tasks/nlp/chatCompletionStream.ts +1 -1
  72. package/src/tasks/nlp/featureExtraction.ts +3 -10
  73. package/src/tasks/nlp/fillMask.ts +1 -1
  74. package/src/tasks/nlp/questionAnswering.ts +1 -1
  75. package/src/tasks/nlp/sentenceSimilarity.ts +1 -1
  76. package/src/tasks/nlp/summarization.ts +1 -1
  77. package/src/tasks/nlp/tableQuestionAnswering.ts +1 -1
  78. package/src/tasks/nlp/textClassification.ts +1 -1
  79. package/src/tasks/nlp/textGeneration.ts +3 -3
  80. package/src/tasks/nlp/textGenerationStream.ts +1 -1
  81. package/src/tasks/nlp/tokenClassification.ts +1 -1
  82. package/src/tasks/nlp/translation.ts +1 -1
  83. package/src/tasks/nlp/zeroShotClassification.ts +1 -1
  84. package/src/tasks/tabular/tabularClassification.ts +1 -1
  85. package/src/tasks/tabular/tabularRegression.ts +1 -1
  86. package/src/types.ts +29 -2
package/dist/index.cjs CHANGED
@@ -100,32 +100,277 @@ __export(tasks_exports, {
100
100
  var HF_HUB_URL = "https://huggingface.co";
101
101
  var HF_ROUTER_URL = "https://router.huggingface.co";
102
102
 
103
+ // src/providers/black-forest-labs.ts
104
+ var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
105
+ var makeBody = (params) => {
106
+ return params.args;
107
+ };
108
+ var makeHeaders = (params) => {
109
+ if (params.authMethod === "provider-key") {
110
+ return { "X-Key": `${params.accessToken}` };
111
+ } else {
112
+ return { Authorization: `Bearer ${params.accessToken}` };
113
+ }
114
+ };
115
+ var makeUrl = (params) => {
116
+ return `${params.baseUrl}/${params.model}`;
117
+ };
118
+ var BLACK_FOREST_LABS_CONFIG = {
119
+ baseUrl: BLACK_FOREST_LABS_AI_API_BASE_URL,
120
+ makeBody,
121
+ makeHeaders,
122
+ makeUrl
123
+ };
124
+
125
+ // src/providers/cohere.ts
126
+ var COHERE_API_BASE_URL = "https://api.cohere.com";
127
+ var makeBody2 = (params) => {
128
+ return {
129
+ ...params.args,
130
+ model: params.model
131
+ };
132
+ };
133
+ var makeHeaders2 = (params) => {
134
+ return { Authorization: `Bearer ${params.accessToken}` };
135
+ };
136
+ var makeUrl2 = (params) => {
137
+ return `${params.baseUrl}/compatibility/v1/chat/completions`;
138
+ };
139
+ var COHERE_CONFIG = {
140
+ baseUrl: COHERE_API_BASE_URL,
141
+ makeBody: makeBody2,
142
+ makeHeaders: makeHeaders2,
143
+ makeUrl: makeUrl2
144
+ };
145
+
103
146
  // src/providers/fal-ai.ts
104
147
  var FAL_AI_API_BASE_URL = "https://fal.run";
148
+ var makeBody3 = (params) => {
149
+ return params.args;
150
+ };
151
+ var makeHeaders3 = (params) => {
152
+ return {
153
+ Authorization: params.authMethod === "provider-key" ? `Key ${params.accessToken}` : `Bearer ${params.accessToken}`
154
+ };
155
+ };
156
+ var makeUrl3 = (params) => {
157
+ return `${params.baseUrl}/${params.model}`;
158
+ };
159
+ var FAL_AI_CONFIG = {
160
+ baseUrl: FAL_AI_API_BASE_URL,
161
+ makeBody: makeBody3,
162
+ makeHeaders: makeHeaders3,
163
+ makeUrl: makeUrl3
164
+ };
165
+
166
+ // src/providers/fireworks-ai.ts
167
+ var FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
168
+ var makeBody4 = (params) => {
169
+ return {
170
+ ...params.args,
171
+ ...params.chatCompletion ? { model: params.model } : void 0
172
+ };
173
+ };
174
+ var makeHeaders4 = (params) => {
175
+ return { Authorization: `Bearer ${params.accessToken}` };
176
+ };
177
+ var makeUrl4 = (params) => {
178
+ if (params.task === "text-generation" && params.chatCompletion) {
179
+ return `${params.baseUrl}/v1/chat/completions`;
180
+ }
181
+ return params.baseUrl;
182
+ };
183
+ var FIREWORKS_AI_CONFIG = {
184
+ baseUrl: FIREWORKS_AI_API_BASE_URL,
185
+ makeBody: makeBody4,
186
+ makeHeaders: makeHeaders4,
187
+ makeUrl: makeUrl4
188
+ };
189
+
190
+ // src/providers/hf-inference.ts
191
+ var makeBody5 = (params) => {
192
+ return {
193
+ ...params.args,
194
+ ...params.chatCompletion ? { model: params.model } : void 0
195
+ };
196
+ };
197
+ var makeHeaders5 = (params) => {
198
+ return { Authorization: `Bearer ${params.accessToken}` };
199
+ };
200
+ var makeUrl5 = (params) => {
201
+ if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
202
+ return `${params.baseUrl}/pipeline/${params.task}/${params.model}`;
203
+ }
204
+ if (params.task === "text-generation" && params.chatCompletion) {
205
+ return `${params.baseUrl}/models/${params.model}/v1/chat/completions`;
206
+ }
207
+ return `${params.baseUrl}/models/${params.model}`;
208
+ };
209
+ var HF_INFERENCE_CONFIG = {
210
+ baseUrl: `${HF_ROUTER_URL}/hf-inference`,
211
+ makeBody: makeBody5,
212
+ makeHeaders: makeHeaders5,
213
+ makeUrl: makeUrl5
214
+ };
215
+
216
+ // src/providers/hyperbolic.ts
217
+ var HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
218
+ var makeBody6 = (params) => {
219
+ return {
220
+ ...params.args,
221
+ ...params.task === "text-to-image" ? { model_name: params.model } : { model: params.model }
222
+ };
223
+ };
224
+ var makeHeaders6 = (params) => {
225
+ return { Authorization: `Bearer ${params.accessToken}` };
226
+ };
227
+ var makeUrl6 = (params) => {
228
+ if (params.task === "text-to-image") {
229
+ return `${params.baseUrl}/v1/images/generations`;
230
+ }
231
+ return `${params.baseUrl}/v1/chat/completions`;
232
+ };
233
+ var HYPERBOLIC_CONFIG = {
234
+ baseUrl: HYPERBOLIC_API_BASE_URL,
235
+ makeBody: makeBody6,
236
+ makeHeaders: makeHeaders6,
237
+ makeUrl: makeUrl6
238
+ };
105
239
 
106
240
  // src/providers/nebius.ts
107
241
  var NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
242
+ var makeBody7 = (params) => {
243
+ return {
244
+ ...params.args,
245
+ model: params.model
246
+ };
247
+ };
248
+ var makeHeaders7 = (params) => {
249
+ return { Authorization: `Bearer ${params.accessToken}` };
250
+ };
251
+ var makeUrl7 = (params) => {
252
+ if (params.task === "text-to-image") {
253
+ return `${params.baseUrl}/v1/images/generations`;
254
+ }
255
+ if (params.task === "text-generation") {
256
+ if (params.chatCompletion) {
257
+ return `${params.baseUrl}/v1/chat/completions`;
258
+ }
259
+ return `${params.baseUrl}/v1/completions`;
260
+ }
261
+ return params.baseUrl;
262
+ };
263
+ var NEBIUS_CONFIG = {
264
+ baseUrl: NEBIUS_API_BASE_URL,
265
+ makeBody: makeBody7,
266
+ makeHeaders: makeHeaders7,
267
+ makeUrl: makeUrl7
268
+ };
269
+
270
+ // src/providers/novita.ts
271
+ var NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
272
+ var makeBody8 = (params) => {
273
+ return {
274
+ ...params.args,
275
+ ...params.chatCompletion ? { model: params.model } : void 0
276
+ };
277
+ };
278
+ var makeHeaders8 = (params) => {
279
+ return { Authorization: `Bearer ${params.accessToken}` };
280
+ };
281
+ var makeUrl8 = (params) => {
282
+ if (params.task === "text-generation") {
283
+ if (params.chatCompletion) {
284
+ return `${params.baseUrl}/chat/completions`;
285
+ }
286
+ return `${params.baseUrl}/completions`;
287
+ }
288
+ return params.baseUrl;
289
+ };
290
+ var NOVITA_CONFIG = {
291
+ baseUrl: NOVITA_API_BASE_URL,
292
+ makeBody: makeBody8,
293
+ makeHeaders: makeHeaders8,
294
+ makeUrl: makeUrl8
295
+ };
108
296
 
109
297
  // src/providers/replicate.ts
110
298
  var REPLICATE_API_BASE_URL = "https://api.replicate.com";
299
+ var makeBody9 = (params) => {
300
+ return {
301
+ input: params.args,
302
+ version: params.model.includes(":") ? params.model.split(":")[1] : void 0
303
+ };
304
+ };
305
+ var makeHeaders9 = (params) => {
306
+ return { Authorization: `Bearer ${params.accessToken}` };
307
+ };
308
+ var makeUrl9 = (params) => {
309
+ if (params.model.includes(":")) {
310
+ return `${params.baseUrl}/v1/predictions`;
311
+ }
312
+ return `${params.baseUrl}/v1/models/${params.model}/predictions`;
313
+ };
314
+ var REPLICATE_CONFIG = {
315
+ baseUrl: REPLICATE_API_BASE_URL,
316
+ makeBody: makeBody9,
317
+ makeHeaders: makeHeaders9,
318
+ makeUrl: makeUrl9
319
+ };
111
320
 
112
321
  // src/providers/sambanova.ts
113
322
  var SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
323
+ var makeBody10 = (params) => {
324
+ return {
325
+ ...params.args,
326
+ ...params.chatCompletion ? { model: params.model } : void 0
327
+ };
328
+ };
329
+ var makeHeaders10 = (params) => {
330
+ return { Authorization: `Bearer ${params.accessToken}` };
331
+ };
332
+ var makeUrl10 = (params) => {
333
+ if (params.task === "text-generation" && params.chatCompletion) {
334
+ return `${params.baseUrl}/v1/chat/completions`;
335
+ }
336
+ return params.baseUrl;
337
+ };
338
+ var SAMBANOVA_CONFIG = {
339
+ baseUrl: SAMBANOVA_API_BASE_URL,
340
+ makeBody: makeBody10,
341
+ makeHeaders: makeHeaders10,
342
+ makeUrl: makeUrl10
343
+ };
114
344
 
115
345
  // src/providers/together.ts
116
346
  var TOGETHER_API_BASE_URL = "https://api.together.xyz";
117
-
118
- // src/providers/novita.ts
119
- var NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
120
-
121
- // src/providers/fireworks-ai.ts
122
- var FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
123
-
124
- // src/providers/hyperbolic.ts
125
- var HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
126
-
127
- // src/providers/black-forest-labs.ts
128
- var BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
347
+ var makeBody11 = (params) => {
348
+ return {
349
+ ...params.args,
350
+ model: params.model
351
+ };
352
+ };
353
+ var makeHeaders11 = (params) => {
354
+ return { Authorization: `Bearer ${params.accessToken}` };
355
+ };
356
+ var makeUrl11 = (params) => {
357
+ if (params.task === "text-to-image") {
358
+ return `${params.baseUrl}/v1/images/generations`;
359
+ }
360
+ if (params.task === "text-generation") {
361
+ if (params.chatCompletion) {
362
+ return `${params.baseUrl}/v1/chat/completions`;
363
+ }
364
+ return `${params.baseUrl}/v1/completions`;
365
+ }
366
+ return params.baseUrl;
367
+ };
368
+ var TOGETHER_CONFIG = {
369
+ baseUrl: TOGETHER_API_BASE_URL,
370
+ makeBody: makeBody11,
371
+ makeHeaders: makeHeaders11,
372
+ makeUrl: makeUrl11
373
+ };
129
374
 
130
375
  // src/lib/isUrl.ts
131
376
  function isUrl(modelOrUrl) {
@@ -134,7 +379,7 @@ function isUrl(modelOrUrl) {
134
379
 
135
380
  // package.json
136
381
  var name = "@huggingface/inference";
137
- var version = "3.3.6";
382
+ var version = "3.4.0";
138
383
 
139
384
  // src/providers/consts.ts
140
385
  var HARDCODED_MODEL_ID_MAPPING = {
@@ -145,15 +390,16 @@ var HARDCODED_MODEL_ID_MAPPING = {
145
390
  * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
146
391
  */
147
392
  "black-forest-labs": {},
393
+ cohere: {},
148
394
  "fal-ai": {},
149
395
  "fireworks-ai": {},
150
396
  "hf-inference": {},
151
397
  hyperbolic: {},
152
398
  nebius: {},
399
+ novita: {},
153
400
  replicate: {},
154
401
  sambanova: {},
155
- together: {},
156
- novita: {}
402
+ together: {}
157
403
  };
158
404
 
159
405
  // src/lib/getProviderModelId.ts
@@ -162,10 +408,10 @@ async function getProviderModelId(params, args, options = {}) {
162
408
  if (params.provider === "hf-inference") {
163
409
  return params.model;
164
410
  }
165
- if (!options.taskHint) {
166
- throw new Error("taskHint must be specified when using a third-party provider");
411
+ if (!options.task) {
412
+ throw new Error("task must be specified when using a third-party provider");
167
413
  }
168
- const task = options.taskHint === "text-generation" && options.chatCompletion ? "conversational" : options.taskHint;
414
+ const task = options.task === "text-generation" && options.chatCompletion ? "conversational" : options.task;
169
415
  if (HARDCODED_MODEL_ID_MAPPING[params.provider]?.[params.model]) {
170
416
  return HARDCODED_MODEL_ID_MAPPING[params.provider][params.model];
171
417
  }
@@ -203,165 +449,83 @@ async function getProviderModelId(params, args, options = {}) {
203
449
  // src/lib/makeRequestOptions.ts
204
450
  var HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
205
451
  var tasks = null;
452
+ var providerConfigs = {
453
+ "black-forest-labs": BLACK_FOREST_LABS_CONFIG,
454
+ cohere: COHERE_CONFIG,
455
+ "fal-ai": FAL_AI_CONFIG,
456
+ "fireworks-ai": FIREWORKS_AI_CONFIG,
457
+ "hf-inference": HF_INFERENCE_CONFIG,
458
+ hyperbolic: HYPERBOLIC_CONFIG,
459
+ nebius: NEBIUS_CONFIG,
460
+ novita: NOVITA_CONFIG,
461
+ replicate: REPLICATE_CONFIG,
462
+ sambanova: SAMBANOVA_CONFIG,
463
+ together: TOGETHER_CONFIG
464
+ };
206
465
  async function makeRequestOptions(args, options) {
207
466
  const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
208
- let otherArgs = remainingArgs;
209
467
  const provider = maybeProvider ?? "hf-inference";
210
- const { includeCredentials, taskHint, chatCompletion: chatCompletion2 } = options ?? {};
468
+ const providerConfig = providerConfigs[provider];
469
+ const { includeCredentials, task, chatCompletion: chatCompletion2, signal } = options ?? {};
211
470
  if (endpointUrl && provider !== "hf-inference") {
212
471
  throw new Error(`Cannot use endpointUrl with a third-party provider.`);
213
472
  }
214
473
  if (maybeModel && isUrl(maybeModel)) {
215
474
  throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
216
475
  }
217
- if (!maybeModel && !taskHint) {
476
+ if (!maybeModel && !task) {
218
477
  throw new Error("No model provided, and no task has been specified.");
219
478
  }
220
- const hfModel = maybeModel ?? await loadDefaultModel(taskHint);
479
+ if (!providerConfig) {
480
+ throw new Error(`No provider config found for provider ${provider}`);
481
+ }
482
+ const hfModel = maybeModel ?? await loadDefaultModel(task);
221
483
  const model = await getProviderModelId({ model: hfModel, provider }, args, {
222
- taskHint,
484
+ task,
223
485
  chatCompletion: chatCompletion2,
224
486
  fetch: options?.fetch
225
487
  });
226
488
  const authMethod = accessToken ? accessToken.startsWith("hf_") ? "hf-token" : "provider-key" : includeCredentials === "include" ? "credentials-include" : "none";
227
- const url = endpointUrl ? chatCompletion2 ? endpointUrl + `/v1/chat/completions` : endpointUrl : makeUrl({
228
- authMethod,
229
- chatCompletion: chatCompletion2 ?? false,
489
+ const url = endpointUrl ? chatCompletion2 ? endpointUrl + `/v1/chat/completions` : endpointUrl : providerConfig.makeUrl({
490
+ baseUrl: authMethod !== "provider-key" ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider) : providerConfig.baseUrl,
230
491
  model,
231
- provider: provider ?? "hf-inference",
232
- taskHint
492
+ chatCompletion: chatCompletion2,
493
+ task
233
494
  });
234
- const headers = {};
235
- if (accessToken) {
236
- if (provider === "fal-ai" && authMethod === "provider-key") {
237
- headers["Authorization"] = `Key ${accessToken}`;
238
- } else if (provider === "black-forest-labs" && authMethod === "provider-key") {
239
- headers["X-Key"] = accessToken;
240
- } else {
241
- headers["Authorization"] = `Bearer ${accessToken}`;
242
- }
243
- }
244
- const ownUserAgent = `${name}/${version}`;
245
- headers["User-Agent"] = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : void 0].filter((x) => x !== void 0).join(" ");
246
495
  const binary = "data" in args && !!args.data;
496
+ const headers = providerConfig.makeHeaders({
497
+ accessToken,
498
+ authMethod
499
+ });
247
500
  if (!binary) {
248
501
  headers["Content-Type"] = "application/json";
249
502
  }
250
- if (provider === "replicate") {
251
- headers["Prefer"] = "wait";
252
- }
503
+ const ownUserAgent = `${name}/${version}`;
504
+ const userAgent = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : void 0].filter((x) => x !== void 0).join(" ");
505
+ headers["User-Agent"] = userAgent;
506
+ const body = binary ? args.data : JSON.stringify(
507
+ providerConfig.makeBody({
508
+ args: remainingArgs,
509
+ model,
510
+ task,
511
+ chatCompletion: chatCompletion2
512
+ })
513
+ );
253
514
  let credentials;
254
515
  if (typeof includeCredentials === "string") {
255
516
  credentials = includeCredentials;
256
517
  } else if (includeCredentials === true) {
257
518
  credentials = "include";
258
519
  }
259
- if (provider === "replicate") {
260
- const version2 = model.includes(":") ? model.split(":")[1] : void 0;
261
- otherArgs = { input: otherArgs, version: version2 };
262
- }
263
520
  const info = {
264
521
  headers,
265
522
  method: "POST",
266
- body: binary ? args.data : JSON.stringify({
267
- ...otherArgs,
268
- ...taskHint === "text-to-image" && provider === "hyperbolic" ? { model_name: model } : chatCompletion2 || provider === "together" || provider === "nebius" || provider === "hyperbolic" ? { model } : void 0
269
- }),
523
+ body,
270
524
  ...credentials ? { credentials } : void 0,
271
- signal: options?.signal
525
+ signal
272
526
  };
273
527
  return { url, info };
274
528
  }
275
- function makeUrl(params) {
276
- if (params.authMethod === "none" && params.provider !== "hf-inference") {
277
- throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
278
- }
279
- const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
280
- switch (params.provider) {
281
- case "black-forest-labs": {
282
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : BLACKFORESTLABS_AI_API_BASE_URL;
283
- return `${baseUrl}/${params.model}`;
284
- }
285
- case "fal-ai": {
286
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FAL_AI_API_BASE_URL;
287
- return `${baseUrl}/${params.model}`;
288
- }
289
- case "nebius": {
290
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : NEBIUS_API_BASE_URL;
291
- if (params.taskHint === "text-to-image") {
292
- return `${baseUrl}/v1/images/generations`;
293
- }
294
- if (params.taskHint === "text-generation") {
295
- if (params.chatCompletion) {
296
- return `${baseUrl}/v1/chat/completions`;
297
- }
298
- return `${baseUrl}/v1/completions`;
299
- }
300
- return baseUrl;
301
- }
302
- case "replicate": {
303
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : REPLICATE_API_BASE_URL;
304
- if (params.model.includes(":")) {
305
- return `${baseUrl}/v1/predictions`;
306
- }
307
- return `${baseUrl}/v1/models/${params.model}/predictions`;
308
- }
309
- case "sambanova": {
310
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : SAMBANOVA_API_BASE_URL;
311
- if (params.taskHint === "text-generation" && params.chatCompletion) {
312
- return `${baseUrl}/v1/chat/completions`;
313
- }
314
- return baseUrl;
315
- }
316
- case "together": {
317
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : TOGETHER_API_BASE_URL;
318
- if (params.taskHint === "text-to-image") {
319
- return `${baseUrl}/v1/images/generations`;
320
- }
321
- if (params.taskHint === "text-generation") {
322
- if (params.chatCompletion) {
323
- return `${baseUrl}/v1/chat/completions`;
324
- }
325
- return `${baseUrl}/v1/completions`;
326
- }
327
- return baseUrl;
328
- }
329
- case "fireworks-ai": {
330
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FIREWORKS_AI_API_BASE_URL;
331
- if (params.taskHint === "text-generation" && params.chatCompletion) {
332
- return `${baseUrl}/v1/chat/completions`;
333
- }
334
- return baseUrl;
335
- }
336
- case "hyperbolic": {
337
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : HYPERBOLIC_API_BASE_URL;
338
- if (params.taskHint === "text-to-image") {
339
- return `${baseUrl}/v1/images/generations`;
340
- }
341
- return `${baseUrl}/v1/chat/completions`;
342
- }
343
- case "novita": {
344
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : NOVITA_API_BASE_URL;
345
- if (params.taskHint === "text-generation") {
346
- if (params.chatCompletion) {
347
- return `${baseUrl}/chat/completions`;
348
- }
349
- return `${baseUrl}/completions`;
350
- }
351
- return baseUrl;
352
- }
353
- default: {
354
- const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
355
- if (params.taskHint && ["feature-extraction", "sentence-similarity"].includes(params.taskHint)) {
356
- return `${baseUrl}/pipeline/${params.taskHint}/${params.model}`;
357
- }
358
- if (params.taskHint === "text-generation" && params.chatCompletion) {
359
- return `${baseUrl}/models/${params.model}/v1/chat/completions`;
360
- }
361
- return `${baseUrl}/models/${params.model}`;
362
- }
363
- }
364
- }
365
529
  async function loadDefaultModel(task) {
366
530
  if (!tasks) {
367
531
  tasks = await loadTaskInfo();
@@ -628,7 +792,7 @@ async function audioClassification(args, options) {
628
792
  const payload = preparePayload(args);
629
793
  const res = await request(payload, {
630
794
  ...options,
631
- taskHint: "audio-classification"
795
+ task: "audio-classification"
632
796
  });
633
797
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
634
798
  if (!isValidOutput) {
@@ -655,7 +819,7 @@ async function automaticSpeechRecognition(args, options) {
655
819
  const payload = await buildPayload(args);
656
820
  const res = await request(payload, {
657
821
  ...options,
658
- taskHint: "automatic-speech-recognition"
822
+ task: "automatic-speech-recognition"
659
823
  });
660
824
  const isValidOutput = typeof res?.text === "string";
661
825
  if (!isValidOutput) {
@@ -699,7 +863,7 @@ async function textToSpeech(args, options) {
699
863
  } : args;
700
864
  const res = await request(payload, {
701
865
  ...options,
702
- taskHint: "text-to-speech"
866
+ task: "text-to-speech"
703
867
  });
704
868
  if (res instanceof Blob) {
705
869
  return res;
@@ -725,7 +889,7 @@ async function audioToAudio(args, options) {
725
889
  const payload = preparePayload(args);
726
890
  const res = await request(payload, {
727
891
  ...options,
728
- taskHint: "audio-to-audio"
892
+ task: "audio-to-audio"
729
893
  });
730
894
  return validateOutput(res);
731
895
  }
@@ -751,7 +915,7 @@ async function imageClassification(args, options) {
751
915
  const payload = preparePayload2(args);
752
916
  const res = await request(payload, {
753
917
  ...options,
754
- taskHint: "image-classification"
918
+ task: "image-classification"
755
919
  });
756
920
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
757
921
  if (!isValidOutput) {
@@ -765,7 +929,7 @@ async function imageSegmentation(args, options) {
765
929
  const payload = preparePayload2(args);
766
930
  const res = await request(payload, {
767
931
  ...options,
768
- taskHint: "image-segmentation"
932
+ task: "image-segmentation"
769
933
  });
770
934
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
771
935
  if (!isValidOutput) {
@@ -779,7 +943,7 @@ async function imageToText(args, options) {
779
943
  const payload = preparePayload2(args);
780
944
  const res = (await request(payload, {
781
945
  ...options,
782
- taskHint: "image-to-text"
946
+ task: "image-to-text"
783
947
  }))?.[0];
784
948
  if (typeof res?.generated_text !== "string") {
785
949
  throw new InferenceOutputError("Expected {generated_text: string}");
@@ -792,7 +956,7 @@ async function objectDetection(args, options) {
792
956
  const payload = preparePayload2(args);
793
957
  const res = await request(payload, {
794
958
  ...options,
795
- taskHint: "object-detection"
959
+ task: "object-detection"
796
960
  });
797
961
  const isValidOutput = Array.isArray(res) && res.every(
798
962
  (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
@@ -836,7 +1000,7 @@ async function textToImage(args, options) {
836
1000
  };
837
1001
  const res = await request(payload, {
838
1002
  ...options,
839
- taskHint: "text-to-image"
1003
+ task: "text-to-image"
840
1004
  });
841
1005
  if (res && typeof res === "object") {
842
1006
  if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") {
@@ -925,7 +1089,7 @@ async function imageToImage(args, options) {
925
1089
  }
926
1090
  const res = await request(reqArgs, {
927
1091
  ...options,
928
- taskHint: "image-to-image"
1092
+ task: "image-to-image"
929
1093
  });
930
1094
  const isValidOutput = res && res instanceof Blob;
931
1095
  if (!isValidOutput) {
@@ -960,7 +1124,7 @@ async function zeroShotImageClassification(args, options) {
960
1124
  const payload = await preparePayload3(args);
961
1125
  const res = await request(payload, {
962
1126
  ...options,
963
- taskHint: "zero-shot-image-classification"
1127
+ task: "zero-shot-image-classification"
964
1128
  });
965
1129
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
966
1130
  if (!isValidOutput) {
@@ -980,7 +1144,7 @@ async function textToVideo(args, options) {
980
1144
  const payload = args.provider === "fal-ai" || args.provider === "replicate" ? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs } : args;
981
1145
  const res = await request(payload, {
982
1146
  ...options,
983
- taskHint: "text-to-video"
1147
+ task: "text-to-video"
984
1148
  });
985
1149
  if (args.provider === "fal-ai") {
986
1150
  const isValidOutput = typeof res === "object" && !!res && "video" in res && typeof res.video === "object" && !!res.video && "url" in res.video && typeof res.video.url === "string" && isUrl(res.video.url);
@@ -1003,7 +1167,7 @@ async function textToVideo(args, options) {
1003
1167
  async function featureExtraction(args, options) {
1004
1168
  const res = await request(args, {
1005
1169
  ...options,
1006
- taskHint: "feature-extraction"
1170
+ task: "feature-extraction"
1007
1171
  });
1008
1172
  let isValidOutput = true;
1009
1173
  const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
@@ -1026,7 +1190,7 @@ async function featureExtraction(args, options) {
1026
1190
  async function fillMask(args, options) {
1027
1191
  const res = await request(args, {
1028
1192
  ...options,
1029
- taskHint: "fill-mask"
1193
+ task: "fill-mask"
1030
1194
  });
1031
1195
  const isValidOutput = Array.isArray(res) && res.every(
1032
1196
  (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
@@ -1043,7 +1207,7 @@ async function fillMask(args, options) {
1043
1207
  async function questionAnswering(args, options) {
1044
1208
  const res = await request(args, {
1045
1209
  ...options,
1046
- taskHint: "question-answering"
1210
+ task: "question-answering"
1047
1211
  });
1048
1212
  const isValidOutput = Array.isArray(res) ? res.every(
1049
1213
  (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
@@ -1058,7 +1222,7 @@ async function questionAnswering(args, options) {
1058
1222
  async function sentenceSimilarity(args, options) {
1059
1223
  const res = await request(prepareInput(args), {
1060
1224
  ...options,
1061
- taskHint: "sentence-similarity"
1225
+ task: "sentence-similarity"
1062
1226
  });
1063
1227
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
1064
1228
  if (!isValidOutput) {
@@ -1078,7 +1242,7 @@ function prepareInput(args) {
1078
1242
  async function summarization(args, options) {
1079
1243
  const res = await request(args, {
1080
1244
  ...options,
1081
- taskHint: "summarization"
1245
+ task: "summarization"
1082
1246
  });
1083
1247
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string");
1084
1248
  if (!isValidOutput) {
@@ -1091,7 +1255,7 @@ async function summarization(args, options) {
1091
1255
  async function tableQuestionAnswering(args, options) {
1092
1256
  const res = await request(args, {
1093
1257
  ...options,
1094
- taskHint: "table-question-answering"
1258
+ task: "table-question-answering"
1095
1259
  });
1096
1260
  const isValidOutput = Array.isArray(res) ? res.every((elem) => validate(elem)) : validate(res);
1097
1261
  if (!isValidOutput) {
@@ -1111,7 +1275,7 @@ function validate(elem) {
1111
1275
  async function textClassification(args, options) {
1112
1276
  const res = (await request(args, {
1113
1277
  ...options,
1114
- taskHint: "text-classification"
1278
+ task: "text-classification"
1115
1279
  }))?.[0];
1116
1280
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number");
1117
1281
  if (!isValidOutput) {
@@ -1134,7 +1298,7 @@ async function textGeneration(args, options) {
1134
1298
  args.prompt = args.inputs;
1135
1299
  const raw = await request(args, {
1136
1300
  ...options,
1137
- taskHint: "text-generation"
1301
+ task: "text-generation"
1138
1302
  });
1139
1303
  const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
1140
1304
  if (!isValidOutput) {
@@ -1155,7 +1319,7 @@ async function textGeneration(args, options) {
1155
1319
  };
1156
1320
  const raw = await request(payload, {
1157
1321
  ...options,
1158
- taskHint: "text-generation"
1322
+ task: "text-generation"
1159
1323
  });
1160
1324
  const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
1161
1325
  if (!isValidOutput) {
@@ -1169,7 +1333,7 @@ async function textGeneration(args, options) {
1169
1333
  const res = toArray(
1170
1334
  await request(args, {
1171
1335
  ...options,
1172
- taskHint: "text-generation"
1336
+ task: "text-generation"
1173
1337
  })
1174
1338
  );
1175
1339
  const isValidOutput = Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
@@ -1184,7 +1348,7 @@ async function textGeneration(args, options) {
1184
1348
  async function* textGenerationStream(args, options) {
1185
1349
  yield* streamingRequest(args, {
1186
1350
  ...options,
1187
- taskHint: "text-generation"
1351
+ task: "text-generation"
1188
1352
  });
1189
1353
  }
1190
1354
 
@@ -1193,7 +1357,7 @@ async function tokenClassification(args, options) {
1193
1357
  const res = toArray(
1194
1358
  await request(args, {
1195
1359
  ...options,
1196
- taskHint: "token-classification"
1360
+ task: "token-classification"
1197
1361
  })
1198
1362
  );
1199
1363
  const isValidOutput = Array.isArray(res) && res.every(
@@ -1211,7 +1375,7 @@ async function tokenClassification(args, options) {
1211
1375
  async function translation(args, options) {
1212
1376
  const res = await request(args, {
1213
1377
  ...options,
1214
- taskHint: "translation"
1378
+ task: "translation"
1215
1379
  });
1216
1380
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string");
1217
1381
  if (!isValidOutput) {
@@ -1225,7 +1389,7 @@ async function zeroShotClassification(args, options) {
1225
1389
  const res = toArray(
1226
1390
  await request(args, {
1227
1391
  ...options,
1228
- taskHint: "zero-shot-classification"
1392
+ task: "zero-shot-classification"
1229
1393
  })
1230
1394
  );
1231
1395
  const isValidOutput = Array.isArray(res) && res.every(
@@ -1241,7 +1405,7 @@ async function zeroShotClassification(args, options) {
1241
1405
  async function chatCompletion(args, options) {
1242
1406
  const res = await request(args, {
1243
1407
  ...options,
1244
- taskHint: "text-generation",
1408
+ task: "text-generation",
1245
1409
  chatCompletion: true
1246
1410
  });
1247
1411
  const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
@@ -1256,7 +1420,7 @@ async function chatCompletion(args, options) {
1256
1420
  async function* chatCompletionStream(args, options) {
1257
1421
  yield* streamingRequest(args, {
1258
1422
  ...options,
1259
- taskHint: "text-generation",
1423
+ task: "text-generation",
1260
1424
  chatCompletion: true
1261
1425
  });
1262
1426
  }
@@ -1274,7 +1438,7 @@ async function documentQuestionAnswering(args, options) {
1274
1438
  const res = toArray(
1275
1439
  await request(reqArgs, {
1276
1440
  ...options,
1277
- taskHint: "document-question-answering"
1441
+ task: "document-question-answering"
1278
1442
  })
1279
1443
  );
1280
1444
  const isValidOutput = Array.isArray(res) && res.every(
@@ -1298,7 +1462,7 @@ async function visualQuestionAnswering(args, options) {
1298
1462
  };
1299
1463
  const res = await request(reqArgs, {
1300
1464
  ...options,
1301
- taskHint: "visual-question-answering"
1465
+ task: "visual-question-answering"
1302
1466
  });
1303
1467
  const isValidOutput = Array.isArray(res) && res.every(
1304
1468
  (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
@@ -1313,7 +1477,7 @@ async function visualQuestionAnswering(args, options) {
1313
1477
  async function tabularRegression(args, options) {
1314
1478
  const res = await request(args, {
1315
1479
  ...options,
1316
- taskHint: "tabular-regression"
1480
+ task: "tabular-regression"
1317
1481
  });
1318
1482
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
1319
1483
  if (!isValidOutput) {
@@ -1326,7 +1490,7 @@ async function tabularRegression(args, options) {
1326
1490
  async function tabularClassification(args, options) {
1327
1491
  const res = await request(args, {
1328
1492
  ...options,
1329
- taskHint: "tabular-classification"
1493
+ task: "tabular-classification"
1330
1494
  });
1331
1495
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
1332
1496
  if (!isValidOutput) {
@@ -1378,6 +1542,7 @@ var HfInferenceEndpoint = class {
1378
1542
  // src/types.ts
1379
1543
  var INFERENCE_PROVIDERS = [
1380
1544
  "black-forest-labs",
1545
+ "cohere",
1381
1546
  "fal-ai",
1382
1547
  "fireworks-ai",
1383
1548
  "hf-inference",