@huggingface/inference 3.3.6 → 3.4.0

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. package/README.md +2 -0
  2. package/dist/index.cjs +339 -174
  3. package/dist/index.js +339 -174
  4. package/dist/src/lib/getProviderModelId.d.ts +1 -1
  5. package/dist/src/lib/getProviderModelId.d.ts.map +1 -1
  6. package/dist/src/lib/makeRequestOptions.d.ts +2 -2
  7. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  8. package/dist/src/providers/black-forest-labs.d.ts +2 -1
  9. package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
  10. package/dist/src/providers/cohere.d.ts +19 -0
  11. package/dist/src/providers/cohere.d.ts.map +1 -0
  12. package/dist/src/providers/consts.d.ts.map +1 -1
  13. package/dist/src/providers/fal-ai.d.ts +2 -1
  14. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  15. package/dist/src/providers/fireworks-ai.d.ts +2 -1
  16. package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
  17. package/dist/src/providers/hf-inference.d.ts +3 -0
  18. package/dist/src/providers/hf-inference.d.ts.map +1 -0
  19. package/dist/src/providers/hyperbolic.d.ts +2 -1
  20. package/dist/src/providers/hyperbolic.d.ts.map +1 -1
  21. package/dist/src/providers/nebius.d.ts +2 -1
  22. package/dist/src/providers/nebius.d.ts.map +1 -1
  23. package/dist/src/providers/novita.d.ts +2 -1
  24. package/dist/src/providers/novita.d.ts.map +1 -1
  25. package/dist/src/providers/replicate.d.ts +3 -1
  26. package/dist/src/providers/replicate.d.ts.map +1 -1
  27. package/dist/src/providers/sambanova.d.ts +2 -1
  28. package/dist/src/providers/sambanova.d.ts.map +1 -1
  29. package/dist/src/providers/together.d.ts +2 -1
  30. package/dist/src/providers/together.d.ts.map +1 -1
  31. package/dist/src/tasks/custom/request.d.ts +2 -4
  32. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  33. package/dist/src/tasks/custom/streamingRequest.d.ts +2 -4
  34. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  35. package/dist/src/tasks/nlp/featureExtraction.d.ts +2 -9
  36. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  37. package/dist/src/types.d.ts +25 -4
  38. package/dist/src/types.d.ts.map +1 -1
  39. package/package.json +2 -2
  40. package/src/lib/getProviderModelId.ts +4 -4
  41. package/src/lib/makeRequestOptions.ts +74 -186
  42. package/src/providers/black-forest-labs.ts +26 -2
  43. package/src/providers/cohere.ts +42 -0
  44. package/src/providers/consts.ts +2 -1
  45. package/src/providers/fal-ai.ts +24 -2
  46. package/src/providers/fireworks-ai.ts +28 -2
  47. package/src/providers/hf-inference.ts +43 -0
  48. package/src/providers/hyperbolic.ts +28 -2
  49. package/src/providers/nebius.ts +34 -2
  50. package/src/providers/novita.ts +31 -2
  51. package/src/providers/replicate.ts +30 -2
  52. package/src/providers/sambanova.ts +28 -2
  53. package/src/providers/together.ts +34 -2
  54. package/src/tasks/audio/audioClassification.ts +1 -1
  55. package/src/tasks/audio/audioToAudio.ts +1 -1
  56. package/src/tasks/audio/automaticSpeechRecognition.ts +1 -1
  57. package/src/tasks/audio/textToSpeech.ts +1 -1
  58. package/src/tasks/custom/request.ts +2 -4
  59. package/src/tasks/custom/streamingRequest.ts +2 -4
  60. package/src/tasks/cv/imageClassification.ts +1 -1
  61. package/src/tasks/cv/imageSegmentation.ts +1 -1
  62. package/src/tasks/cv/imageToImage.ts +1 -1
  63. package/src/tasks/cv/imageToText.ts +1 -1
  64. package/src/tasks/cv/objectDetection.ts +1 -1
  65. package/src/tasks/cv/textToImage.ts +1 -1
  66. package/src/tasks/cv/textToVideo.ts +1 -1
  67. package/src/tasks/cv/zeroShotImageClassification.ts +1 -1
  68. package/src/tasks/multimodal/documentQuestionAnswering.ts +1 -1
  69. package/src/tasks/multimodal/visualQuestionAnswering.ts +1 -1
  70. package/src/tasks/nlp/chatCompletion.ts +1 -1
  71. package/src/tasks/nlp/chatCompletionStream.ts +1 -1
  72. package/src/tasks/nlp/featureExtraction.ts +3 -10
  73. package/src/tasks/nlp/fillMask.ts +1 -1
  74. package/src/tasks/nlp/questionAnswering.ts +1 -1
  75. package/src/tasks/nlp/sentenceSimilarity.ts +1 -1
  76. package/src/tasks/nlp/summarization.ts +1 -1
  77. package/src/tasks/nlp/tableQuestionAnswering.ts +1 -1
  78. package/src/tasks/nlp/textClassification.ts +1 -1
  79. package/src/tasks/nlp/textGeneration.ts +3 -3
  80. package/src/tasks/nlp/textGenerationStream.ts +1 -1
  81. package/src/tasks/nlp/tokenClassification.ts +1 -1
  82. package/src/tasks/nlp/translation.ts +1 -1
  83. package/src/tasks/nlp/zeroShotClassification.ts +1 -1
  84. package/src/tasks/tabular/tabularClassification.ts +1 -1
  85. package/src/tasks/tabular/tabularRegression.ts +1 -1
  86. package/src/types.ts +29 -2
package/dist/index.js CHANGED
@@ -45,32 +45,277 @@ __export(tasks_exports, {
45
45
  var HF_HUB_URL = "https://huggingface.co";
46
46
  var HF_ROUTER_URL = "https://router.huggingface.co";
47
47
 
48
+ // src/providers/black-forest-labs.ts
49
+ var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
50
+ var makeBody = (params) => {
51
+ return params.args;
52
+ };
53
+ var makeHeaders = (params) => {
54
+ if (params.authMethod === "provider-key") {
55
+ return { "X-Key": `${params.accessToken}` };
56
+ } else {
57
+ return { Authorization: `Bearer ${params.accessToken}` };
58
+ }
59
+ };
60
+ var makeUrl = (params) => {
61
+ return `${params.baseUrl}/${params.model}`;
62
+ };
63
+ var BLACK_FOREST_LABS_CONFIG = {
64
+ baseUrl: BLACK_FOREST_LABS_AI_API_BASE_URL,
65
+ makeBody,
66
+ makeHeaders,
67
+ makeUrl
68
+ };
69
+
70
+ // src/providers/cohere.ts
71
+ var COHERE_API_BASE_URL = "https://api.cohere.com";
72
+ var makeBody2 = (params) => {
73
+ return {
74
+ ...params.args,
75
+ model: params.model
76
+ };
77
+ };
78
+ var makeHeaders2 = (params) => {
79
+ return { Authorization: `Bearer ${params.accessToken}` };
80
+ };
81
+ var makeUrl2 = (params) => {
82
+ return `${params.baseUrl}/compatibility/v1/chat/completions`;
83
+ };
84
+ var COHERE_CONFIG = {
85
+ baseUrl: COHERE_API_BASE_URL,
86
+ makeBody: makeBody2,
87
+ makeHeaders: makeHeaders2,
88
+ makeUrl: makeUrl2
89
+ };
90
+
48
91
  // src/providers/fal-ai.ts
49
92
  var FAL_AI_API_BASE_URL = "https://fal.run";
93
+ var makeBody3 = (params) => {
94
+ return params.args;
95
+ };
96
+ var makeHeaders3 = (params) => {
97
+ return {
98
+ Authorization: params.authMethod === "provider-key" ? `Key ${params.accessToken}` : `Bearer ${params.accessToken}`
99
+ };
100
+ };
101
+ var makeUrl3 = (params) => {
102
+ return `${params.baseUrl}/${params.model}`;
103
+ };
104
+ var FAL_AI_CONFIG = {
105
+ baseUrl: FAL_AI_API_BASE_URL,
106
+ makeBody: makeBody3,
107
+ makeHeaders: makeHeaders3,
108
+ makeUrl: makeUrl3
109
+ };
110
+
111
+ // src/providers/fireworks-ai.ts
112
+ var FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
113
+ var makeBody4 = (params) => {
114
+ return {
115
+ ...params.args,
116
+ ...params.chatCompletion ? { model: params.model } : void 0
117
+ };
118
+ };
119
+ var makeHeaders4 = (params) => {
120
+ return { Authorization: `Bearer ${params.accessToken}` };
121
+ };
122
+ var makeUrl4 = (params) => {
123
+ if (params.task === "text-generation" && params.chatCompletion) {
124
+ return `${params.baseUrl}/v1/chat/completions`;
125
+ }
126
+ return params.baseUrl;
127
+ };
128
+ var FIREWORKS_AI_CONFIG = {
129
+ baseUrl: FIREWORKS_AI_API_BASE_URL,
130
+ makeBody: makeBody4,
131
+ makeHeaders: makeHeaders4,
132
+ makeUrl: makeUrl4
133
+ };
134
+
135
+ // src/providers/hf-inference.ts
136
+ var makeBody5 = (params) => {
137
+ return {
138
+ ...params.args,
139
+ ...params.chatCompletion ? { model: params.model } : void 0
140
+ };
141
+ };
142
+ var makeHeaders5 = (params) => {
143
+ return { Authorization: `Bearer ${params.accessToken}` };
144
+ };
145
+ var makeUrl5 = (params) => {
146
+ if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
147
+ return `${params.baseUrl}/pipeline/${params.task}/${params.model}`;
148
+ }
149
+ if (params.task === "text-generation" && params.chatCompletion) {
150
+ return `${params.baseUrl}/models/${params.model}/v1/chat/completions`;
151
+ }
152
+ return `${params.baseUrl}/models/${params.model}`;
153
+ };
154
+ var HF_INFERENCE_CONFIG = {
155
+ baseUrl: `${HF_ROUTER_URL}/hf-inference`,
156
+ makeBody: makeBody5,
157
+ makeHeaders: makeHeaders5,
158
+ makeUrl: makeUrl5
159
+ };
160
+
161
+ // src/providers/hyperbolic.ts
162
+ var HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
163
+ var makeBody6 = (params) => {
164
+ return {
165
+ ...params.args,
166
+ ...params.task === "text-to-image" ? { model_name: params.model } : { model: params.model }
167
+ };
168
+ };
169
+ var makeHeaders6 = (params) => {
170
+ return { Authorization: `Bearer ${params.accessToken}` };
171
+ };
172
+ var makeUrl6 = (params) => {
173
+ if (params.task === "text-to-image") {
174
+ return `${params.baseUrl}/v1/images/generations`;
175
+ }
176
+ return `${params.baseUrl}/v1/chat/completions`;
177
+ };
178
+ var HYPERBOLIC_CONFIG = {
179
+ baseUrl: HYPERBOLIC_API_BASE_URL,
180
+ makeBody: makeBody6,
181
+ makeHeaders: makeHeaders6,
182
+ makeUrl: makeUrl6
183
+ };
50
184
 
51
185
  // src/providers/nebius.ts
52
186
  var NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
187
+ var makeBody7 = (params) => {
188
+ return {
189
+ ...params.args,
190
+ model: params.model
191
+ };
192
+ };
193
+ var makeHeaders7 = (params) => {
194
+ return { Authorization: `Bearer ${params.accessToken}` };
195
+ };
196
+ var makeUrl7 = (params) => {
197
+ if (params.task === "text-to-image") {
198
+ return `${params.baseUrl}/v1/images/generations`;
199
+ }
200
+ if (params.task === "text-generation") {
201
+ if (params.chatCompletion) {
202
+ return `${params.baseUrl}/v1/chat/completions`;
203
+ }
204
+ return `${params.baseUrl}/v1/completions`;
205
+ }
206
+ return params.baseUrl;
207
+ };
208
+ var NEBIUS_CONFIG = {
209
+ baseUrl: NEBIUS_API_BASE_URL,
210
+ makeBody: makeBody7,
211
+ makeHeaders: makeHeaders7,
212
+ makeUrl: makeUrl7
213
+ };
214
+
215
+ // src/providers/novita.ts
216
+ var NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
217
+ var makeBody8 = (params) => {
218
+ return {
219
+ ...params.args,
220
+ ...params.chatCompletion ? { model: params.model } : void 0
221
+ };
222
+ };
223
+ var makeHeaders8 = (params) => {
224
+ return { Authorization: `Bearer ${params.accessToken}` };
225
+ };
226
+ var makeUrl8 = (params) => {
227
+ if (params.task === "text-generation") {
228
+ if (params.chatCompletion) {
229
+ return `${params.baseUrl}/chat/completions`;
230
+ }
231
+ return `${params.baseUrl}/completions`;
232
+ }
233
+ return params.baseUrl;
234
+ };
235
+ var NOVITA_CONFIG = {
236
+ baseUrl: NOVITA_API_BASE_URL,
237
+ makeBody: makeBody8,
238
+ makeHeaders: makeHeaders8,
239
+ makeUrl: makeUrl8
240
+ };
53
241
 
54
242
  // src/providers/replicate.ts
55
243
  var REPLICATE_API_BASE_URL = "https://api.replicate.com";
244
+ var makeBody9 = (params) => {
245
+ return {
246
+ input: params.args,
247
+ version: params.model.includes(":") ? params.model.split(":")[1] : void 0
248
+ };
249
+ };
250
+ var makeHeaders9 = (params) => {
251
+ return { Authorization: `Bearer ${params.accessToken}` };
252
+ };
253
+ var makeUrl9 = (params) => {
254
+ if (params.model.includes(":")) {
255
+ return `${params.baseUrl}/v1/predictions`;
256
+ }
257
+ return `${params.baseUrl}/v1/models/${params.model}/predictions`;
258
+ };
259
+ var REPLICATE_CONFIG = {
260
+ baseUrl: REPLICATE_API_BASE_URL,
261
+ makeBody: makeBody9,
262
+ makeHeaders: makeHeaders9,
263
+ makeUrl: makeUrl9
264
+ };
56
265
 
57
266
  // src/providers/sambanova.ts
58
267
  var SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
268
+ var makeBody10 = (params) => {
269
+ return {
270
+ ...params.args,
271
+ ...params.chatCompletion ? { model: params.model } : void 0
272
+ };
273
+ };
274
+ var makeHeaders10 = (params) => {
275
+ return { Authorization: `Bearer ${params.accessToken}` };
276
+ };
277
+ var makeUrl10 = (params) => {
278
+ if (params.task === "text-generation" && params.chatCompletion) {
279
+ return `${params.baseUrl}/v1/chat/completions`;
280
+ }
281
+ return params.baseUrl;
282
+ };
283
+ var SAMBANOVA_CONFIG = {
284
+ baseUrl: SAMBANOVA_API_BASE_URL,
285
+ makeBody: makeBody10,
286
+ makeHeaders: makeHeaders10,
287
+ makeUrl: makeUrl10
288
+ };
59
289
 
60
290
  // src/providers/together.ts
61
291
  var TOGETHER_API_BASE_URL = "https://api.together.xyz";
62
-
63
- // src/providers/novita.ts
64
- var NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
65
-
66
- // src/providers/fireworks-ai.ts
67
- var FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
68
-
69
- // src/providers/hyperbolic.ts
70
- var HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
71
-
72
- // src/providers/black-forest-labs.ts
73
- var BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
292
+ var makeBody11 = (params) => {
293
+ return {
294
+ ...params.args,
295
+ model: params.model
296
+ };
297
+ };
298
+ var makeHeaders11 = (params) => {
299
+ return { Authorization: `Bearer ${params.accessToken}` };
300
+ };
301
+ var makeUrl11 = (params) => {
302
+ if (params.task === "text-to-image") {
303
+ return `${params.baseUrl}/v1/images/generations`;
304
+ }
305
+ if (params.task === "text-generation") {
306
+ if (params.chatCompletion) {
307
+ return `${params.baseUrl}/v1/chat/completions`;
308
+ }
309
+ return `${params.baseUrl}/v1/completions`;
310
+ }
311
+ return params.baseUrl;
312
+ };
313
+ var TOGETHER_CONFIG = {
314
+ baseUrl: TOGETHER_API_BASE_URL,
315
+ makeBody: makeBody11,
316
+ makeHeaders: makeHeaders11,
317
+ makeUrl: makeUrl11
318
+ };
74
319
 
75
320
  // src/lib/isUrl.ts
76
321
  function isUrl(modelOrUrl) {
@@ -79,7 +324,7 @@ function isUrl(modelOrUrl) {
79
324
 
80
325
  // package.json
81
326
  var name = "@huggingface/inference";
82
- var version = "3.3.6";
327
+ var version = "3.4.0";
83
328
 
84
329
  // src/providers/consts.ts
85
330
  var HARDCODED_MODEL_ID_MAPPING = {
@@ -90,15 +335,16 @@ var HARDCODED_MODEL_ID_MAPPING = {
90
335
  * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
91
336
  */
92
337
  "black-forest-labs": {},
338
+ cohere: {},
93
339
  "fal-ai": {},
94
340
  "fireworks-ai": {},
95
341
  "hf-inference": {},
96
342
  hyperbolic: {},
97
343
  nebius: {},
344
+ novita: {},
98
345
  replicate: {},
99
346
  sambanova: {},
100
- together: {},
101
- novita: {}
347
+ together: {}
102
348
  };
103
349
 
104
350
  // src/lib/getProviderModelId.ts
@@ -107,10 +353,10 @@ async function getProviderModelId(params, args, options = {}) {
107
353
  if (params.provider === "hf-inference") {
108
354
  return params.model;
109
355
  }
110
- if (!options.taskHint) {
111
- throw new Error("taskHint must be specified when using a third-party provider");
356
+ if (!options.task) {
357
+ throw new Error("task must be specified when using a third-party provider");
112
358
  }
113
- const task = options.taskHint === "text-generation" && options.chatCompletion ? "conversational" : options.taskHint;
359
+ const task = options.task === "text-generation" && options.chatCompletion ? "conversational" : options.task;
114
360
  if (HARDCODED_MODEL_ID_MAPPING[params.provider]?.[params.model]) {
115
361
  return HARDCODED_MODEL_ID_MAPPING[params.provider][params.model];
116
362
  }
@@ -148,165 +394,83 @@ async function getProviderModelId(params, args, options = {}) {
148
394
  // src/lib/makeRequestOptions.ts
149
395
  var HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
150
396
  var tasks = null;
397
+ var providerConfigs = {
398
+ "black-forest-labs": BLACK_FOREST_LABS_CONFIG,
399
+ cohere: COHERE_CONFIG,
400
+ "fal-ai": FAL_AI_CONFIG,
401
+ "fireworks-ai": FIREWORKS_AI_CONFIG,
402
+ "hf-inference": HF_INFERENCE_CONFIG,
403
+ hyperbolic: HYPERBOLIC_CONFIG,
404
+ nebius: NEBIUS_CONFIG,
405
+ novita: NOVITA_CONFIG,
406
+ replicate: REPLICATE_CONFIG,
407
+ sambanova: SAMBANOVA_CONFIG,
408
+ together: TOGETHER_CONFIG
409
+ };
151
410
  async function makeRequestOptions(args, options) {
152
411
  const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
153
- let otherArgs = remainingArgs;
154
412
  const provider = maybeProvider ?? "hf-inference";
155
- const { includeCredentials, taskHint, chatCompletion: chatCompletion2 } = options ?? {};
413
+ const providerConfig = providerConfigs[provider];
414
+ const { includeCredentials, task, chatCompletion: chatCompletion2, signal } = options ?? {};
156
415
  if (endpointUrl && provider !== "hf-inference") {
157
416
  throw new Error(`Cannot use endpointUrl with a third-party provider.`);
158
417
  }
159
418
  if (maybeModel && isUrl(maybeModel)) {
160
419
  throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
161
420
  }
162
- if (!maybeModel && !taskHint) {
421
+ if (!maybeModel && !task) {
163
422
  throw new Error("No model provided, and no task has been specified.");
164
423
  }
165
- const hfModel = maybeModel ?? await loadDefaultModel(taskHint);
424
+ if (!providerConfig) {
425
+ throw new Error(`No provider config found for provider ${provider}`);
426
+ }
427
+ const hfModel = maybeModel ?? await loadDefaultModel(task);
166
428
  const model = await getProviderModelId({ model: hfModel, provider }, args, {
167
- taskHint,
429
+ task,
168
430
  chatCompletion: chatCompletion2,
169
431
  fetch: options?.fetch
170
432
  });
171
433
  const authMethod = accessToken ? accessToken.startsWith("hf_") ? "hf-token" : "provider-key" : includeCredentials === "include" ? "credentials-include" : "none";
172
- const url = endpointUrl ? chatCompletion2 ? endpointUrl + `/v1/chat/completions` : endpointUrl : makeUrl({
173
- authMethod,
174
- chatCompletion: chatCompletion2 ?? false,
434
+ const url = endpointUrl ? chatCompletion2 ? endpointUrl + `/v1/chat/completions` : endpointUrl : providerConfig.makeUrl({
435
+ baseUrl: authMethod !== "provider-key" ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider) : providerConfig.baseUrl,
175
436
  model,
176
- provider: provider ?? "hf-inference",
177
- taskHint
437
+ chatCompletion: chatCompletion2,
438
+ task
178
439
  });
179
- const headers = {};
180
- if (accessToken) {
181
- if (provider === "fal-ai" && authMethod === "provider-key") {
182
- headers["Authorization"] = `Key ${accessToken}`;
183
- } else if (provider === "black-forest-labs" && authMethod === "provider-key") {
184
- headers["X-Key"] = accessToken;
185
- } else {
186
- headers["Authorization"] = `Bearer ${accessToken}`;
187
- }
188
- }
189
- const ownUserAgent = `${name}/${version}`;
190
- headers["User-Agent"] = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : void 0].filter((x) => x !== void 0).join(" ");
191
440
  const binary = "data" in args && !!args.data;
441
+ const headers = providerConfig.makeHeaders({
442
+ accessToken,
443
+ authMethod
444
+ });
192
445
  if (!binary) {
193
446
  headers["Content-Type"] = "application/json";
194
447
  }
195
- if (provider === "replicate") {
196
- headers["Prefer"] = "wait";
197
- }
448
+ const ownUserAgent = `${name}/${version}`;
449
+ const userAgent = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : void 0].filter((x) => x !== void 0).join(" ");
450
+ headers["User-Agent"] = userAgent;
451
+ const body = binary ? args.data : JSON.stringify(
452
+ providerConfig.makeBody({
453
+ args: remainingArgs,
454
+ model,
455
+ task,
456
+ chatCompletion: chatCompletion2
457
+ })
458
+ );
198
459
  let credentials;
199
460
  if (typeof includeCredentials === "string") {
200
461
  credentials = includeCredentials;
201
462
  } else if (includeCredentials === true) {
202
463
  credentials = "include";
203
464
  }
204
- if (provider === "replicate") {
205
- const version2 = model.includes(":") ? model.split(":")[1] : void 0;
206
- otherArgs = { input: otherArgs, version: version2 };
207
- }
208
465
  const info = {
209
466
  headers,
210
467
  method: "POST",
211
- body: binary ? args.data : JSON.stringify({
212
- ...otherArgs,
213
- ...taskHint === "text-to-image" && provider === "hyperbolic" ? { model_name: model } : chatCompletion2 || provider === "together" || provider === "nebius" || provider === "hyperbolic" ? { model } : void 0
214
- }),
468
+ body,
215
469
  ...credentials ? { credentials } : void 0,
216
- signal: options?.signal
470
+ signal
217
471
  };
218
472
  return { url, info };
219
473
  }
220
- function makeUrl(params) {
221
- if (params.authMethod === "none" && params.provider !== "hf-inference") {
222
- throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
223
- }
224
- const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
225
- switch (params.provider) {
226
- case "black-forest-labs": {
227
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : BLACKFORESTLABS_AI_API_BASE_URL;
228
- return `${baseUrl}/${params.model}`;
229
- }
230
- case "fal-ai": {
231
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FAL_AI_API_BASE_URL;
232
- return `${baseUrl}/${params.model}`;
233
- }
234
- case "nebius": {
235
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : NEBIUS_API_BASE_URL;
236
- if (params.taskHint === "text-to-image") {
237
- return `${baseUrl}/v1/images/generations`;
238
- }
239
- if (params.taskHint === "text-generation") {
240
- if (params.chatCompletion) {
241
- return `${baseUrl}/v1/chat/completions`;
242
- }
243
- return `${baseUrl}/v1/completions`;
244
- }
245
- return baseUrl;
246
- }
247
- case "replicate": {
248
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : REPLICATE_API_BASE_URL;
249
- if (params.model.includes(":")) {
250
- return `${baseUrl}/v1/predictions`;
251
- }
252
- return `${baseUrl}/v1/models/${params.model}/predictions`;
253
- }
254
- case "sambanova": {
255
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : SAMBANOVA_API_BASE_URL;
256
- if (params.taskHint === "text-generation" && params.chatCompletion) {
257
- return `${baseUrl}/v1/chat/completions`;
258
- }
259
- return baseUrl;
260
- }
261
- case "together": {
262
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : TOGETHER_API_BASE_URL;
263
- if (params.taskHint === "text-to-image") {
264
- return `${baseUrl}/v1/images/generations`;
265
- }
266
- if (params.taskHint === "text-generation") {
267
- if (params.chatCompletion) {
268
- return `${baseUrl}/v1/chat/completions`;
269
- }
270
- return `${baseUrl}/v1/completions`;
271
- }
272
- return baseUrl;
273
- }
274
- case "fireworks-ai": {
275
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FIREWORKS_AI_API_BASE_URL;
276
- if (params.taskHint === "text-generation" && params.chatCompletion) {
277
- return `${baseUrl}/v1/chat/completions`;
278
- }
279
- return baseUrl;
280
- }
281
- case "hyperbolic": {
282
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : HYPERBOLIC_API_BASE_URL;
283
- if (params.taskHint === "text-to-image") {
284
- return `${baseUrl}/v1/images/generations`;
285
- }
286
- return `${baseUrl}/v1/chat/completions`;
287
- }
288
- case "novita": {
289
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : NOVITA_API_BASE_URL;
290
- if (params.taskHint === "text-generation") {
291
- if (params.chatCompletion) {
292
- return `${baseUrl}/chat/completions`;
293
- }
294
- return `${baseUrl}/completions`;
295
- }
296
- return baseUrl;
297
- }
298
- default: {
299
- const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
300
- if (params.taskHint && ["feature-extraction", "sentence-similarity"].includes(params.taskHint)) {
301
- return `${baseUrl}/pipeline/${params.taskHint}/${params.model}`;
302
- }
303
- if (params.taskHint === "text-generation" && params.chatCompletion) {
304
- return `${baseUrl}/models/${params.model}/v1/chat/completions`;
305
- }
306
- return `${baseUrl}/models/${params.model}`;
307
- }
308
- }
309
- }
310
474
  async function loadDefaultModel(task) {
311
475
  if (!tasks) {
312
476
  tasks = await loadTaskInfo();
@@ -573,7 +737,7 @@ async function audioClassification(args, options) {
573
737
  const payload = preparePayload(args);
574
738
  const res = await request(payload, {
575
739
  ...options,
576
- taskHint: "audio-classification"
740
+ task: "audio-classification"
577
741
  });
578
742
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
579
743
  if (!isValidOutput) {
@@ -600,7 +764,7 @@ async function automaticSpeechRecognition(args, options) {
600
764
  const payload = await buildPayload(args);
601
765
  const res = await request(payload, {
602
766
  ...options,
603
- taskHint: "automatic-speech-recognition"
767
+ task: "automatic-speech-recognition"
604
768
  });
605
769
  const isValidOutput = typeof res?.text === "string";
606
770
  if (!isValidOutput) {
@@ -644,7 +808,7 @@ async function textToSpeech(args, options) {
644
808
  } : args;
645
809
  const res = await request(payload, {
646
810
  ...options,
647
- taskHint: "text-to-speech"
811
+ task: "text-to-speech"
648
812
  });
649
813
  if (res instanceof Blob) {
650
814
  return res;
@@ -670,7 +834,7 @@ async function audioToAudio(args, options) {
670
834
  const payload = preparePayload(args);
671
835
  const res = await request(payload, {
672
836
  ...options,
673
- taskHint: "audio-to-audio"
837
+ task: "audio-to-audio"
674
838
  });
675
839
  return validateOutput(res);
676
840
  }
@@ -696,7 +860,7 @@ async function imageClassification(args, options) {
696
860
  const payload = preparePayload2(args);
697
861
  const res = await request(payload, {
698
862
  ...options,
699
- taskHint: "image-classification"
863
+ task: "image-classification"
700
864
  });
701
865
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
702
866
  if (!isValidOutput) {
@@ -710,7 +874,7 @@ async function imageSegmentation(args, options) {
710
874
  const payload = preparePayload2(args);
711
875
  const res = await request(payload, {
712
876
  ...options,
713
- taskHint: "image-segmentation"
877
+ task: "image-segmentation"
714
878
  });
715
879
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
716
880
  if (!isValidOutput) {
@@ -724,7 +888,7 @@ async function imageToText(args, options) {
724
888
  const payload = preparePayload2(args);
725
889
  const res = (await request(payload, {
726
890
  ...options,
727
- taskHint: "image-to-text"
891
+ task: "image-to-text"
728
892
  }))?.[0];
729
893
  if (typeof res?.generated_text !== "string") {
730
894
  throw new InferenceOutputError("Expected {generated_text: string}");
@@ -737,7 +901,7 @@ async function objectDetection(args, options) {
737
901
  const payload = preparePayload2(args);
738
902
  const res = await request(payload, {
739
903
  ...options,
740
- taskHint: "object-detection"
904
+ task: "object-detection"
741
905
  });
742
906
  const isValidOutput = Array.isArray(res) && res.every(
743
907
  (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
@@ -781,7 +945,7 @@ async function textToImage(args, options) {
781
945
  };
782
946
  const res = await request(payload, {
783
947
  ...options,
784
- taskHint: "text-to-image"
948
+ task: "text-to-image"
785
949
  });
786
950
  if (res && typeof res === "object") {
787
951
  if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") {
@@ -870,7 +1034,7 @@ async function imageToImage(args, options) {
870
1034
  }
871
1035
  const res = await request(reqArgs, {
872
1036
  ...options,
873
- taskHint: "image-to-image"
1037
+ task: "image-to-image"
874
1038
  });
875
1039
  const isValidOutput = res && res instanceof Blob;
876
1040
  if (!isValidOutput) {
@@ -905,7 +1069,7 @@ async function zeroShotImageClassification(args, options) {
905
1069
  const payload = await preparePayload3(args);
906
1070
  const res = await request(payload, {
907
1071
  ...options,
908
- taskHint: "zero-shot-image-classification"
1072
+ task: "zero-shot-image-classification"
909
1073
  });
910
1074
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
911
1075
  if (!isValidOutput) {
@@ -925,7 +1089,7 @@ async function textToVideo(args, options) {
925
1089
  const payload = args.provider === "fal-ai" || args.provider === "replicate" ? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs } : args;
926
1090
  const res = await request(payload, {
927
1091
  ...options,
928
- taskHint: "text-to-video"
1092
+ task: "text-to-video"
929
1093
  });
930
1094
  if (args.provider === "fal-ai") {
931
1095
  const isValidOutput = typeof res === "object" && !!res && "video" in res && typeof res.video === "object" && !!res.video && "url" in res.video && typeof res.video.url === "string" && isUrl(res.video.url);
@@ -948,7 +1112,7 @@ async function textToVideo(args, options) {
948
1112
  async function featureExtraction(args, options) {
949
1113
  const res = await request(args, {
950
1114
  ...options,
951
- taskHint: "feature-extraction"
1115
+ task: "feature-extraction"
952
1116
  });
953
1117
  let isValidOutput = true;
954
1118
  const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
@@ -971,7 +1135,7 @@ async function featureExtraction(args, options) {
971
1135
  async function fillMask(args, options) {
972
1136
  const res = await request(args, {
973
1137
  ...options,
974
- taskHint: "fill-mask"
1138
+ task: "fill-mask"
975
1139
  });
976
1140
  const isValidOutput = Array.isArray(res) && res.every(
977
1141
  (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
@@ -988,7 +1152,7 @@ async function fillMask(args, options) {
988
1152
  async function questionAnswering(args, options) {
989
1153
  const res = await request(args, {
990
1154
  ...options,
991
- taskHint: "question-answering"
1155
+ task: "question-answering"
992
1156
  });
993
1157
  const isValidOutput = Array.isArray(res) ? res.every(
994
1158
  (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
@@ -1003,7 +1167,7 @@ async function questionAnswering(args, options) {
1003
1167
  async function sentenceSimilarity(args, options) {
1004
1168
  const res = await request(prepareInput(args), {
1005
1169
  ...options,
1006
- taskHint: "sentence-similarity"
1170
+ task: "sentence-similarity"
1007
1171
  });
1008
1172
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
1009
1173
  if (!isValidOutput) {
@@ -1023,7 +1187,7 @@ function prepareInput(args) {
1023
1187
  async function summarization(args, options) {
1024
1188
  const res = await request(args, {
1025
1189
  ...options,
1026
- taskHint: "summarization"
1190
+ task: "summarization"
1027
1191
  });
1028
1192
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string");
1029
1193
  if (!isValidOutput) {
@@ -1036,7 +1200,7 @@ async function summarization(args, options) {
1036
1200
  async function tableQuestionAnswering(args, options) {
1037
1201
  const res = await request(args, {
1038
1202
  ...options,
1039
- taskHint: "table-question-answering"
1203
+ task: "table-question-answering"
1040
1204
  });
1041
1205
  const isValidOutput = Array.isArray(res) ? res.every((elem) => validate(elem)) : validate(res);
1042
1206
  if (!isValidOutput) {
@@ -1056,7 +1220,7 @@ function validate(elem) {
1056
1220
  async function textClassification(args, options) {
1057
1221
  const res = (await request(args, {
1058
1222
  ...options,
1059
- taskHint: "text-classification"
1223
+ task: "text-classification"
1060
1224
  }))?.[0];
1061
1225
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number");
1062
1226
  if (!isValidOutput) {
@@ -1079,7 +1243,7 @@ async function textGeneration(args, options) {
1079
1243
  args.prompt = args.inputs;
1080
1244
  const raw = await request(args, {
1081
1245
  ...options,
1082
- taskHint: "text-generation"
1246
+ task: "text-generation"
1083
1247
  });
1084
1248
  const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
1085
1249
  if (!isValidOutput) {
@@ -1100,7 +1264,7 @@ async function textGeneration(args, options) {
1100
1264
  };
1101
1265
  const raw = await request(payload, {
1102
1266
  ...options,
1103
- taskHint: "text-generation"
1267
+ task: "text-generation"
1104
1268
  });
1105
1269
  const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
1106
1270
  if (!isValidOutput) {
@@ -1114,7 +1278,7 @@ async function textGeneration(args, options) {
1114
1278
  const res = toArray(
1115
1279
  await request(args, {
1116
1280
  ...options,
1117
- taskHint: "text-generation"
1281
+ task: "text-generation"
1118
1282
  })
1119
1283
  );
1120
1284
  const isValidOutput = Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
@@ -1129,7 +1293,7 @@ async function textGeneration(args, options) {
1129
1293
  async function* textGenerationStream(args, options) {
1130
1294
  yield* streamingRequest(args, {
1131
1295
  ...options,
1132
- taskHint: "text-generation"
1296
+ task: "text-generation"
1133
1297
  });
1134
1298
  }
1135
1299
 
@@ -1138,7 +1302,7 @@ async function tokenClassification(args, options) {
1138
1302
  const res = toArray(
1139
1303
  await request(args, {
1140
1304
  ...options,
1141
- taskHint: "token-classification"
1305
+ task: "token-classification"
1142
1306
  })
1143
1307
  );
1144
1308
  const isValidOutput = Array.isArray(res) && res.every(
@@ -1156,7 +1320,7 @@ async function tokenClassification(args, options) {
1156
1320
  async function translation(args, options) {
1157
1321
  const res = await request(args, {
1158
1322
  ...options,
1159
- taskHint: "translation"
1323
+ task: "translation"
1160
1324
  });
1161
1325
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string");
1162
1326
  if (!isValidOutput) {
@@ -1170,7 +1334,7 @@ async function zeroShotClassification(args, options) {
1170
1334
  const res = toArray(
1171
1335
  await request(args, {
1172
1336
  ...options,
1173
- taskHint: "zero-shot-classification"
1337
+ task: "zero-shot-classification"
1174
1338
  })
1175
1339
  );
1176
1340
  const isValidOutput = Array.isArray(res) && res.every(
@@ -1186,7 +1350,7 @@ async function zeroShotClassification(args, options) {
1186
1350
  async function chatCompletion(args, options) {
1187
1351
  const res = await request(args, {
1188
1352
  ...options,
1189
- taskHint: "text-generation",
1353
+ task: "text-generation",
1190
1354
  chatCompletion: true
1191
1355
  });
1192
1356
  const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
@@ -1201,7 +1365,7 @@ async function chatCompletion(args, options) {
1201
1365
  async function* chatCompletionStream(args, options) {
1202
1366
  yield* streamingRequest(args, {
1203
1367
  ...options,
1204
- taskHint: "text-generation",
1368
+ task: "text-generation",
1205
1369
  chatCompletion: true
1206
1370
  });
1207
1371
  }
@@ -1219,7 +1383,7 @@ async function documentQuestionAnswering(args, options) {
1219
1383
  const res = toArray(
1220
1384
  await request(reqArgs, {
1221
1385
  ...options,
1222
- taskHint: "document-question-answering"
1386
+ task: "document-question-answering"
1223
1387
  })
1224
1388
  );
1225
1389
  const isValidOutput = Array.isArray(res) && res.every(
@@ -1243,7 +1407,7 @@ async function visualQuestionAnswering(args, options) {
1243
1407
  };
1244
1408
  const res = await request(reqArgs, {
1245
1409
  ...options,
1246
- taskHint: "visual-question-answering"
1410
+ task: "visual-question-answering"
1247
1411
  });
1248
1412
  const isValidOutput = Array.isArray(res) && res.every(
1249
1413
  (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
@@ -1258,7 +1422,7 @@ async function visualQuestionAnswering(args, options) {
1258
1422
  async function tabularRegression(args, options) {
1259
1423
  const res = await request(args, {
1260
1424
  ...options,
1261
- taskHint: "tabular-regression"
1425
+ task: "tabular-regression"
1262
1426
  });
1263
1427
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
1264
1428
  if (!isValidOutput) {
@@ -1271,7 +1435,7 @@ async function tabularRegression(args, options) {
1271
1435
  async function tabularClassification(args, options) {
1272
1436
  const res = await request(args, {
1273
1437
  ...options,
1274
- taskHint: "tabular-classification"
1438
+ task: "tabular-classification"
1275
1439
  });
1276
1440
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
1277
1441
  if (!isValidOutput) {
@@ -1323,6 +1487,7 @@ var HfInferenceEndpoint = class {
1323
1487
  // src/types.ts
1324
1488
  var INFERENCE_PROVIDERS = [
1325
1489
  "black-forest-labs",
1490
+ "cohere",
1326
1491
  "fal-ai",
1327
1492
  "fireworks-ai",
1328
1493
  "hf-inference",