@huggingface/inference 3.3.6 → 3.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81) hide show
  1. package/dist/index.cjs +315 -174
  2. package/dist/index.js +315 -174
  3. package/dist/src/lib/getProviderModelId.d.ts +1 -1
  4. package/dist/src/lib/getProviderModelId.d.ts.map +1 -1
  5. package/dist/src/lib/makeRequestOptions.d.ts +2 -2
  6. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  7. package/dist/src/providers/black-forest-labs.d.ts +2 -1
  8. package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
  9. package/dist/src/providers/fal-ai.d.ts +2 -1
  10. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  11. package/dist/src/providers/fireworks-ai.d.ts +2 -1
  12. package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
  13. package/dist/src/providers/hf-inference.d.ts +3 -0
  14. package/dist/src/providers/hf-inference.d.ts.map +1 -0
  15. package/dist/src/providers/hyperbolic.d.ts +2 -1
  16. package/dist/src/providers/hyperbolic.d.ts.map +1 -1
  17. package/dist/src/providers/nebius.d.ts +2 -1
  18. package/dist/src/providers/nebius.d.ts.map +1 -1
  19. package/dist/src/providers/novita.d.ts +2 -1
  20. package/dist/src/providers/novita.d.ts.map +1 -1
  21. package/dist/src/providers/replicate.d.ts +3 -1
  22. package/dist/src/providers/replicate.d.ts.map +1 -1
  23. package/dist/src/providers/sambanova.d.ts +2 -1
  24. package/dist/src/providers/sambanova.d.ts.map +1 -1
  25. package/dist/src/providers/together.d.ts +2 -1
  26. package/dist/src/providers/together.d.ts.map +1 -1
  27. package/dist/src/tasks/custom/request.d.ts +2 -4
  28. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  29. package/dist/src/tasks/custom/streamingRequest.d.ts +2 -4
  30. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  31. package/dist/src/tasks/nlp/featureExtraction.d.ts +2 -9
  32. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  33. package/dist/src/types.d.ts +24 -3
  34. package/dist/src/types.d.ts.map +1 -1
  35. package/package.json +2 -2
  36. package/src/lib/getProviderModelId.ts +4 -4
  37. package/src/lib/makeRequestOptions.ts +72 -186
  38. package/src/providers/black-forest-labs.ts +26 -2
  39. package/src/providers/consts.ts +1 -1
  40. package/src/providers/fal-ai.ts +24 -2
  41. package/src/providers/fireworks-ai.ts +28 -2
  42. package/src/providers/hf-inference.ts +43 -0
  43. package/src/providers/hyperbolic.ts +28 -2
  44. package/src/providers/nebius.ts +34 -2
  45. package/src/providers/novita.ts +31 -2
  46. package/src/providers/replicate.ts +30 -2
  47. package/src/providers/sambanova.ts +28 -2
  48. package/src/providers/together.ts +34 -2
  49. package/src/tasks/audio/audioClassification.ts +1 -1
  50. package/src/tasks/audio/audioToAudio.ts +1 -1
  51. package/src/tasks/audio/automaticSpeechRecognition.ts +1 -1
  52. package/src/tasks/audio/textToSpeech.ts +1 -1
  53. package/src/tasks/custom/request.ts +2 -4
  54. package/src/tasks/custom/streamingRequest.ts +2 -4
  55. package/src/tasks/cv/imageClassification.ts +1 -1
  56. package/src/tasks/cv/imageSegmentation.ts +1 -1
  57. package/src/tasks/cv/imageToImage.ts +1 -1
  58. package/src/tasks/cv/imageToText.ts +1 -1
  59. package/src/tasks/cv/objectDetection.ts +1 -1
  60. package/src/tasks/cv/textToImage.ts +1 -1
  61. package/src/tasks/cv/textToVideo.ts +1 -1
  62. package/src/tasks/cv/zeroShotImageClassification.ts +1 -1
  63. package/src/tasks/multimodal/documentQuestionAnswering.ts +1 -1
  64. package/src/tasks/multimodal/visualQuestionAnswering.ts +1 -1
  65. package/src/tasks/nlp/chatCompletion.ts +1 -1
  66. package/src/tasks/nlp/chatCompletionStream.ts +1 -1
  67. package/src/tasks/nlp/featureExtraction.ts +3 -10
  68. package/src/tasks/nlp/fillMask.ts +1 -1
  69. package/src/tasks/nlp/questionAnswering.ts +1 -1
  70. package/src/tasks/nlp/sentenceSimilarity.ts +1 -1
  71. package/src/tasks/nlp/summarization.ts +1 -1
  72. package/src/tasks/nlp/tableQuestionAnswering.ts +1 -1
  73. package/src/tasks/nlp/textClassification.ts +1 -1
  74. package/src/tasks/nlp/textGeneration.ts +3 -3
  75. package/src/tasks/nlp/textGenerationStream.ts +1 -1
  76. package/src/tasks/nlp/tokenClassification.ts +1 -1
  77. package/src/tasks/nlp/translation.ts +1 -1
  78. package/src/tasks/nlp/zeroShotClassification.ts +1 -1
  79. package/src/tasks/tabular/tabularClassification.ts +1 -1
  80. package/src/tasks/tabular/tabularRegression.ts +1 -1
  81. package/src/types.ts +28 -2
package/dist/index.cjs CHANGED
@@ -100,32 +100,256 @@ __export(tasks_exports, {
100
100
  var HF_HUB_URL = "https://huggingface.co";
101
101
  var HF_ROUTER_URL = "https://router.huggingface.co";
102
102
 
103
+ // src/providers/black-forest-labs.ts
104
+ var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
105
+ var makeBody = (params) => {
106
+ return params.args;
107
+ };
108
+ var makeHeaders = (params) => {
109
+ if (params.authMethod === "provider-key") {
110
+ return { "X-Key": `${params.accessToken}` };
111
+ } else {
112
+ return { Authorization: `Bearer ${params.accessToken}` };
113
+ }
114
+ };
115
+ var makeUrl = (params) => {
116
+ return `${params.baseUrl}/${params.model}`;
117
+ };
118
+ var BLACK_FOREST_LABS_CONFIG = {
119
+ baseUrl: BLACK_FOREST_LABS_AI_API_BASE_URL,
120
+ makeBody,
121
+ makeHeaders,
122
+ makeUrl
123
+ };
124
+
103
125
  // src/providers/fal-ai.ts
104
126
  var FAL_AI_API_BASE_URL = "https://fal.run";
127
+ var makeBody2 = (params) => {
128
+ return params.args;
129
+ };
130
+ var makeHeaders2 = (params) => {
131
+ return {
132
+ Authorization: params.authMethod === "provider-key" ? `Key ${params.accessToken}` : `Bearer ${params.accessToken}`
133
+ };
134
+ };
135
+ var makeUrl2 = (params) => {
136
+ return `${params.baseUrl}/${params.model}`;
137
+ };
138
+ var FAL_AI_CONFIG = {
139
+ baseUrl: FAL_AI_API_BASE_URL,
140
+ makeBody: makeBody2,
141
+ makeHeaders: makeHeaders2,
142
+ makeUrl: makeUrl2
143
+ };
144
+
145
+ // src/providers/fireworks-ai.ts
146
+ var FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
147
+ var makeBody3 = (params) => {
148
+ return {
149
+ ...params.args,
150
+ ...params.chatCompletion ? { model: params.model } : void 0
151
+ };
152
+ };
153
+ var makeHeaders3 = (params) => {
154
+ return { Authorization: `Bearer ${params.accessToken}` };
155
+ };
156
+ var makeUrl3 = (params) => {
157
+ if (params.task === "text-generation" && params.chatCompletion) {
158
+ return `${params.baseUrl}/v1/chat/completions`;
159
+ }
160
+ return params.baseUrl;
161
+ };
162
+ var FIREWORKS_AI_CONFIG = {
163
+ baseUrl: FIREWORKS_AI_API_BASE_URL,
164
+ makeBody: makeBody3,
165
+ makeHeaders: makeHeaders3,
166
+ makeUrl: makeUrl3
167
+ };
168
+
169
+ // src/providers/hf-inference.ts
170
+ var makeBody4 = (params) => {
171
+ return {
172
+ ...params.args,
173
+ ...params.chatCompletion ? { model: params.model } : void 0
174
+ };
175
+ };
176
+ var makeHeaders4 = (params) => {
177
+ return { Authorization: `Bearer ${params.accessToken}` };
178
+ };
179
+ var makeUrl4 = (params) => {
180
+ if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
181
+ return `${params.baseUrl}/pipeline/${params.task}/${params.model}`;
182
+ }
183
+ if (params.task === "text-generation" && params.chatCompletion) {
184
+ return `${params.baseUrl}/models/${params.model}/v1/chat/completions`;
185
+ }
186
+ return `${params.baseUrl}/models/${params.model}`;
187
+ };
188
+ var HF_INFERENCE_CONFIG = {
189
+ baseUrl: `${HF_ROUTER_URL}/hf-inference`,
190
+ makeBody: makeBody4,
191
+ makeHeaders: makeHeaders4,
192
+ makeUrl: makeUrl4
193
+ };
194
+
195
+ // src/providers/hyperbolic.ts
196
+ var HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
197
+ var makeBody5 = (params) => {
198
+ return {
199
+ ...params.args,
200
+ ...params.task === "text-to-image" ? { model_name: params.model } : { model: params.model }
201
+ };
202
+ };
203
+ var makeHeaders5 = (params) => {
204
+ return { Authorization: `Bearer ${params.accessToken}` };
205
+ };
206
+ var makeUrl5 = (params) => {
207
+ if (params.task === "text-to-image") {
208
+ return `${params.baseUrl}/v1/images/generations`;
209
+ }
210
+ return `${params.baseUrl}/v1/chat/completions`;
211
+ };
212
+ var HYPERBOLIC_CONFIG = {
213
+ baseUrl: HYPERBOLIC_API_BASE_URL,
214
+ makeBody: makeBody5,
215
+ makeHeaders: makeHeaders5,
216
+ makeUrl: makeUrl5
217
+ };
105
218
 
106
219
  // src/providers/nebius.ts
107
220
  var NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
221
+ var makeBody6 = (params) => {
222
+ return {
223
+ ...params.args,
224
+ model: params.model
225
+ };
226
+ };
227
+ var makeHeaders6 = (params) => {
228
+ return { Authorization: `Bearer ${params.accessToken}` };
229
+ };
230
+ var makeUrl6 = (params) => {
231
+ if (params.task === "text-to-image") {
232
+ return `${params.baseUrl}/v1/images/generations`;
233
+ }
234
+ if (params.task === "text-generation") {
235
+ if (params.chatCompletion) {
236
+ return `${params.baseUrl}/v1/chat/completions`;
237
+ }
238
+ return `${params.baseUrl}/v1/completions`;
239
+ }
240
+ return params.baseUrl;
241
+ };
242
+ var NEBIUS_CONFIG = {
243
+ baseUrl: NEBIUS_API_BASE_URL,
244
+ makeBody: makeBody6,
245
+ makeHeaders: makeHeaders6,
246
+ makeUrl: makeUrl6
247
+ };
248
+
249
+ // src/providers/novita.ts
250
+ var NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
251
+ var makeBody7 = (params) => {
252
+ return {
253
+ ...params.args,
254
+ ...params.chatCompletion ? { model: params.model } : void 0
255
+ };
256
+ };
257
+ var makeHeaders7 = (params) => {
258
+ return { Authorization: `Bearer ${params.accessToken}` };
259
+ };
260
+ var makeUrl7 = (params) => {
261
+ if (params.task === "text-generation") {
262
+ if (params.chatCompletion) {
263
+ return `${params.baseUrl}/chat/completions`;
264
+ }
265
+ return `${params.baseUrl}/completions`;
266
+ }
267
+ return params.baseUrl;
268
+ };
269
+ var NOVITA_CONFIG = {
270
+ baseUrl: NOVITA_API_BASE_URL,
271
+ makeBody: makeBody7,
272
+ makeHeaders: makeHeaders7,
273
+ makeUrl: makeUrl7
274
+ };
108
275
 
109
276
  // src/providers/replicate.ts
110
277
  var REPLICATE_API_BASE_URL = "https://api.replicate.com";
278
+ var makeBody8 = (params) => {
279
+ return {
280
+ input: params.args,
281
+ version: params.model.includes(":") ? params.model.split(":")[1] : void 0
282
+ };
283
+ };
284
+ var makeHeaders8 = (params) => {
285
+ return { Authorization: `Bearer ${params.accessToken}` };
286
+ };
287
+ var makeUrl8 = (params) => {
288
+ if (params.model.includes(":")) {
289
+ return `${params.baseUrl}/v1/predictions`;
290
+ }
291
+ return `${params.baseUrl}/v1/models/${params.model}/predictions`;
292
+ };
293
+ var REPLICATE_CONFIG = {
294
+ baseUrl: REPLICATE_API_BASE_URL,
295
+ makeBody: makeBody8,
296
+ makeHeaders: makeHeaders8,
297
+ makeUrl: makeUrl8
298
+ };
111
299
 
112
300
  // src/providers/sambanova.ts
113
301
  var SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
302
+ var makeBody9 = (params) => {
303
+ return {
304
+ ...params.args,
305
+ ...params.chatCompletion ? { model: params.model } : void 0
306
+ };
307
+ };
308
+ var makeHeaders9 = (params) => {
309
+ return { Authorization: `Bearer ${params.accessToken}` };
310
+ };
311
+ var makeUrl9 = (params) => {
312
+ if (params.task === "text-generation" && params.chatCompletion) {
313
+ return `${params.baseUrl}/v1/chat/completions`;
314
+ }
315
+ return params.baseUrl;
316
+ };
317
+ var SAMBANOVA_CONFIG = {
318
+ baseUrl: SAMBANOVA_API_BASE_URL,
319
+ makeBody: makeBody9,
320
+ makeHeaders: makeHeaders9,
321
+ makeUrl: makeUrl9
322
+ };
114
323
 
115
324
  // src/providers/together.ts
116
325
  var TOGETHER_API_BASE_URL = "https://api.together.xyz";
117
-
118
- // src/providers/novita.ts
119
- var NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
120
-
121
- // src/providers/fireworks-ai.ts
122
- var FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
123
-
124
- // src/providers/hyperbolic.ts
125
- var HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
126
-
127
- // src/providers/black-forest-labs.ts
128
- var BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
326
+ var makeBody10 = (params) => {
327
+ return {
328
+ ...params.args,
329
+ model: params.model
330
+ };
331
+ };
332
+ var makeHeaders10 = (params) => {
333
+ return { Authorization: `Bearer ${params.accessToken}` };
334
+ };
335
+ var makeUrl10 = (params) => {
336
+ if (params.task === "text-to-image") {
337
+ return `${params.baseUrl}/v1/images/generations`;
338
+ }
339
+ if (params.task === "text-generation") {
340
+ if (params.chatCompletion) {
341
+ return `${params.baseUrl}/v1/chat/completions`;
342
+ }
343
+ return `${params.baseUrl}/v1/completions`;
344
+ }
345
+ return params.baseUrl;
346
+ };
347
+ var TOGETHER_CONFIG = {
348
+ baseUrl: TOGETHER_API_BASE_URL,
349
+ makeBody: makeBody10,
350
+ makeHeaders: makeHeaders10,
351
+ makeUrl: makeUrl10
352
+ };
129
353
 
130
354
  // src/lib/isUrl.ts
131
355
  function isUrl(modelOrUrl) {
@@ -134,7 +358,7 @@ function isUrl(modelOrUrl) {
134
358
 
135
359
  // package.json
136
360
  var name = "@huggingface/inference";
137
- var version = "3.3.6";
361
+ var version = "3.3.7";
138
362
 
139
363
  // src/providers/consts.ts
140
364
  var HARDCODED_MODEL_ID_MAPPING = {
@@ -150,10 +374,10 @@ var HARDCODED_MODEL_ID_MAPPING = {
150
374
  "hf-inference": {},
151
375
  hyperbolic: {},
152
376
  nebius: {},
377
+ novita: {},
153
378
  replicate: {},
154
379
  sambanova: {},
155
- together: {},
156
- novita: {}
380
+ together: {}
157
381
  };
158
382
 
159
383
  // src/lib/getProviderModelId.ts
@@ -162,10 +386,10 @@ async function getProviderModelId(params, args, options = {}) {
162
386
  if (params.provider === "hf-inference") {
163
387
  return params.model;
164
388
  }
165
- if (!options.taskHint) {
166
- throw new Error("taskHint must be specified when using a third-party provider");
389
+ if (!options.task) {
390
+ throw new Error("task must be specified when using a third-party provider");
167
391
  }
168
- const task = options.taskHint === "text-generation" && options.chatCompletion ? "conversational" : options.taskHint;
392
+ const task = options.task === "text-generation" && options.chatCompletion ? "conversational" : options.task;
169
393
  if (HARDCODED_MODEL_ID_MAPPING[params.provider]?.[params.model]) {
170
394
  return HARDCODED_MODEL_ID_MAPPING[params.provider][params.model];
171
395
  }
@@ -203,165 +427,82 @@ async function getProviderModelId(params, args, options = {}) {
203
427
  // src/lib/makeRequestOptions.ts
204
428
  var HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
205
429
  var tasks = null;
430
+ var providerConfigs = {
431
+ "black-forest-labs": BLACK_FOREST_LABS_CONFIG,
432
+ "fal-ai": FAL_AI_CONFIG,
433
+ "fireworks-ai": FIREWORKS_AI_CONFIG,
434
+ "hf-inference": HF_INFERENCE_CONFIG,
435
+ hyperbolic: HYPERBOLIC_CONFIG,
436
+ nebius: NEBIUS_CONFIG,
437
+ novita: NOVITA_CONFIG,
438
+ replicate: REPLICATE_CONFIG,
439
+ sambanova: SAMBANOVA_CONFIG,
440
+ together: TOGETHER_CONFIG
441
+ };
206
442
  async function makeRequestOptions(args, options) {
207
443
  const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
208
- let otherArgs = remainingArgs;
209
444
  const provider = maybeProvider ?? "hf-inference";
210
- const { includeCredentials, taskHint, chatCompletion: chatCompletion2 } = options ?? {};
445
+ const providerConfig = providerConfigs[provider];
446
+ const { includeCredentials, task, chatCompletion: chatCompletion2, signal } = options ?? {};
211
447
  if (endpointUrl && provider !== "hf-inference") {
212
448
  throw new Error(`Cannot use endpointUrl with a third-party provider.`);
213
449
  }
214
450
  if (maybeModel && isUrl(maybeModel)) {
215
451
  throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
216
452
  }
217
- if (!maybeModel && !taskHint) {
453
+ if (!maybeModel && !task) {
218
454
  throw new Error("No model provided, and no task has been specified.");
219
455
  }
220
- const hfModel = maybeModel ?? await loadDefaultModel(taskHint);
456
+ if (!providerConfig) {
457
+ throw new Error(`No provider config found for provider ${provider}`);
458
+ }
459
+ const hfModel = maybeModel ?? await loadDefaultModel(task);
221
460
  const model = await getProviderModelId({ model: hfModel, provider }, args, {
222
- taskHint,
461
+ task,
223
462
  chatCompletion: chatCompletion2,
224
463
  fetch: options?.fetch
225
464
  });
226
465
  const authMethod = accessToken ? accessToken.startsWith("hf_") ? "hf-token" : "provider-key" : includeCredentials === "include" ? "credentials-include" : "none";
227
- const url = endpointUrl ? chatCompletion2 ? endpointUrl + `/v1/chat/completions` : endpointUrl : makeUrl({
228
- authMethod,
229
- chatCompletion: chatCompletion2 ?? false,
466
+ const url = endpointUrl ? chatCompletion2 ? endpointUrl + `/v1/chat/completions` : endpointUrl : providerConfig.makeUrl({
467
+ baseUrl: authMethod !== "provider-key" ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider) : providerConfig.baseUrl,
230
468
  model,
231
- provider: provider ?? "hf-inference",
232
- taskHint
469
+ chatCompletion: chatCompletion2,
470
+ task
233
471
  });
234
- const headers = {};
235
- if (accessToken) {
236
- if (provider === "fal-ai" && authMethod === "provider-key") {
237
- headers["Authorization"] = `Key ${accessToken}`;
238
- } else if (provider === "black-forest-labs" && authMethod === "provider-key") {
239
- headers["X-Key"] = accessToken;
240
- } else {
241
- headers["Authorization"] = `Bearer ${accessToken}`;
242
- }
243
- }
244
- const ownUserAgent = `${name}/${version}`;
245
- headers["User-Agent"] = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : void 0].filter((x) => x !== void 0).join(" ");
246
472
  const binary = "data" in args && !!args.data;
473
+ const headers = providerConfig.makeHeaders({
474
+ accessToken,
475
+ authMethod
476
+ });
247
477
  if (!binary) {
248
478
  headers["Content-Type"] = "application/json";
249
479
  }
250
- if (provider === "replicate") {
251
- headers["Prefer"] = "wait";
252
- }
480
+ const ownUserAgent = `${name}/${version}`;
481
+ const userAgent = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : void 0].filter((x) => x !== void 0).join(" ");
482
+ headers["User-Agent"] = userAgent;
483
+ const body = binary ? args.data : JSON.stringify(
484
+ providerConfig.makeBody({
485
+ args: remainingArgs,
486
+ model,
487
+ task,
488
+ chatCompletion: chatCompletion2
489
+ })
490
+ );
253
491
  let credentials;
254
492
  if (typeof includeCredentials === "string") {
255
493
  credentials = includeCredentials;
256
494
  } else if (includeCredentials === true) {
257
495
  credentials = "include";
258
496
  }
259
- if (provider === "replicate") {
260
- const version2 = model.includes(":") ? model.split(":")[1] : void 0;
261
- otherArgs = { input: otherArgs, version: version2 };
262
- }
263
497
  const info = {
264
498
  headers,
265
499
  method: "POST",
266
- body: binary ? args.data : JSON.stringify({
267
- ...otherArgs,
268
- ...taskHint === "text-to-image" && provider === "hyperbolic" ? { model_name: model } : chatCompletion2 || provider === "together" || provider === "nebius" || provider === "hyperbolic" ? { model } : void 0
269
- }),
500
+ body,
270
501
  ...credentials ? { credentials } : void 0,
271
- signal: options?.signal
502
+ signal
272
503
  };
273
504
  return { url, info };
274
505
  }
275
- function makeUrl(params) {
276
- if (params.authMethod === "none" && params.provider !== "hf-inference") {
277
- throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
278
- }
279
- const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
280
- switch (params.provider) {
281
- case "black-forest-labs": {
282
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : BLACKFORESTLABS_AI_API_BASE_URL;
283
- return `${baseUrl}/${params.model}`;
284
- }
285
- case "fal-ai": {
286
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FAL_AI_API_BASE_URL;
287
- return `${baseUrl}/${params.model}`;
288
- }
289
- case "nebius": {
290
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : NEBIUS_API_BASE_URL;
291
- if (params.taskHint === "text-to-image") {
292
- return `${baseUrl}/v1/images/generations`;
293
- }
294
- if (params.taskHint === "text-generation") {
295
- if (params.chatCompletion) {
296
- return `${baseUrl}/v1/chat/completions`;
297
- }
298
- return `${baseUrl}/v1/completions`;
299
- }
300
- return baseUrl;
301
- }
302
- case "replicate": {
303
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : REPLICATE_API_BASE_URL;
304
- if (params.model.includes(":")) {
305
- return `${baseUrl}/v1/predictions`;
306
- }
307
- return `${baseUrl}/v1/models/${params.model}/predictions`;
308
- }
309
- case "sambanova": {
310
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : SAMBANOVA_API_BASE_URL;
311
- if (params.taskHint === "text-generation" && params.chatCompletion) {
312
- return `${baseUrl}/v1/chat/completions`;
313
- }
314
- return baseUrl;
315
- }
316
- case "together": {
317
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : TOGETHER_API_BASE_URL;
318
- if (params.taskHint === "text-to-image") {
319
- return `${baseUrl}/v1/images/generations`;
320
- }
321
- if (params.taskHint === "text-generation") {
322
- if (params.chatCompletion) {
323
- return `${baseUrl}/v1/chat/completions`;
324
- }
325
- return `${baseUrl}/v1/completions`;
326
- }
327
- return baseUrl;
328
- }
329
- case "fireworks-ai": {
330
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FIREWORKS_AI_API_BASE_URL;
331
- if (params.taskHint === "text-generation" && params.chatCompletion) {
332
- return `${baseUrl}/v1/chat/completions`;
333
- }
334
- return baseUrl;
335
- }
336
- case "hyperbolic": {
337
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : HYPERBOLIC_API_BASE_URL;
338
- if (params.taskHint === "text-to-image") {
339
- return `${baseUrl}/v1/images/generations`;
340
- }
341
- return `${baseUrl}/v1/chat/completions`;
342
- }
343
- case "novita": {
344
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : NOVITA_API_BASE_URL;
345
- if (params.taskHint === "text-generation") {
346
- if (params.chatCompletion) {
347
- return `${baseUrl}/chat/completions`;
348
- }
349
- return `${baseUrl}/completions`;
350
- }
351
- return baseUrl;
352
- }
353
- default: {
354
- const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
355
- if (params.taskHint && ["feature-extraction", "sentence-similarity"].includes(params.taskHint)) {
356
- return `${baseUrl}/pipeline/${params.taskHint}/${params.model}`;
357
- }
358
- if (params.taskHint === "text-generation" && params.chatCompletion) {
359
- return `${baseUrl}/models/${params.model}/v1/chat/completions`;
360
- }
361
- return `${baseUrl}/models/${params.model}`;
362
- }
363
- }
364
- }
365
506
  async function loadDefaultModel(task) {
366
507
  if (!tasks) {
367
508
  tasks = await loadTaskInfo();
@@ -628,7 +769,7 @@ async function audioClassification(args, options) {
628
769
  const payload = preparePayload(args);
629
770
  const res = await request(payload, {
630
771
  ...options,
631
- taskHint: "audio-classification"
772
+ task: "audio-classification"
632
773
  });
633
774
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
634
775
  if (!isValidOutput) {
@@ -655,7 +796,7 @@ async function automaticSpeechRecognition(args, options) {
655
796
  const payload = await buildPayload(args);
656
797
  const res = await request(payload, {
657
798
  ...options,
658
- taskHint: "automatic-speech-recognition"
799
+ task: "automatic-speech-recognition"
659
800
  });
660
801
  const isValidOutput = typeof res?.text === "string";
661
802
  if (!isValidOutput) {
@@ -699,7 +840,7 @@ async function textToSpeech(args, options) {
699
840
  } : args;
700
841
  const res = await request(payload, {
701
842
  ...options,
702
- taskHint: "text-to-speech"
843
+ task: "text-to-speech"
703
844
  });
704
845
  if (res instanceof Blob) {
705
846
  return res;
@@ -725,7 +866,7 @@ async function audioToAudio(args, options) {
725
866
  const payload = preparePayload(args);
726
867
  const res = await request(payload, {
727
868
  ...options,
728
- taskHint: "audio-to-audio"
869
+ task: "audio-to-audio"
729
870
  });
730
871
  return validateOutput(res);
731
872
  }
@@ -751,7 +892,7 @@ async function imageClassification(args, options) {
751
892
  const payload = preparePayload2(args);
752
893
  const res = await request(payload, {
753
894
  ...options,
754
- taskHint: "image-classification"
895
+ task: "image-classification"
755
896
  });
756
897
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
757
898
  if (!isValidOutput) {
@@ -765,7 +906,7 @@ async function imageSegmentation(args, options) {
765
906
  const payload = preparePayload2(args);
766
907
  const res = await request(payload, {
767
908
  ...options,
768
- taskHint: "image-segmentation"
909
+ task: "image-segmentation"
769
910
  });
770
911
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
771
912
  if (!isValidOutput) {
@@ -779,7 +920,7 @@ async function imageToText(args, options) {
779
920
  const payload = preparePayload2(args);
780
921
  const res = (await request(payload, {
781
922
  ...options,
782
- taskHint: "image-to-text"
923
+ task: "image-to-text"
783
924
  }))?.[0];
784
925
  if (typeof res?.generated_text !== "string") {
785
926
  throw new InferenceOutputError("Expected {generated_text: string}");
@@ -792,7 +933,7 @@ async function objectDetection(args, options) {
792
933
  const payload = preparePayload2(args);
793
934
  const res = await request(payload, {
794
935
  ...options,
795
- taskHint: "object-detection"
936
+ task: "object-detection"
796
937
  });
797
938
  const isValidOutput = Array.isArray(res) && res.every(
798
939
  (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
@@ -836,7 +977,7 @@ async function textToImage(args, options) {
836
977
  };
837
978
  const res = await request(payload, {
838
979
  ...options,
839
- taskHint: "text-to-image"
980
+ task: "text-to-image"
840
981
  });
841
982
  if (res && typeof res === "object") {
842
983
  if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") {
@@ -925,7 +1066,7 @@ async function imageToImage(args, options) {
925
1066
  }
926
1067
  const res = await request(reqArgs, {
927
1068
  ...options,
928
- taskHint: "image-to-image"
1069
+ task: "image-to-image"
929
1070
  });
930
1071
  const isValidOutput = res && res instanceof Blob;
931
1072
  if (!isValidOutput) {
@@ -960,7 +1101,7 @@ async function zeroShotImageClassification(args, options) {
960
1101
  const payload = await preparePayload3(args);
961
1102
  const res = await request(payload, {
962
1103
  ...options,
963
- taskHint: "zero-shot-image-classification"
1104
+ task: "zero-shot-image-classification"
964
1105
  });
965
1106
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
966
1107
  if (!isValidOutput) {
@@ -980,7 +1121,7 @@ async function textToVideo(args, options) {
980
1121
  const payload = args.provider === "fal-ai" || args.provider === "replicate" ? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs } : args;
981
1122
  const res = await request(payload, {
982
1123
  ...options,
983
- taskHint: "text-to-video"
1124
+ task: "text-to-video"
984
1125
  });
985
1126
  if (args.provider === "fal-ai") {
986
1127
  const isValidOutput = typeof res === "object" && !!res && "video" in res && typeof res.video === "object" && !!res.video && "url" in res.video && typeof res.video.url === "string" && isUrl(res.video.url);
@@ -1003,7 +1144,7 @@ async function textToVideo(args, options) {
1003
1144
  async function featureExtraction(args, options) {
1004
1145
  const res = await request(args, {
1005
1146
  ...options,
1006
- taskHint: "feature-extraction"
1147
+ task: "feature-extraction"
1007
1148
  });
1008
1149
  let isValidOutput = true;
1009
1150
  const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
@@ -1026,7 +1167,7 @@ async function featureExtraction(args, options) {
1026
1167
  async function fillMask(args, options) {
1027
1168
  const res = await request(args, {
1028
1169
  ...options,
1029
- taskHint: "fill-mask"
1170
+ task: "fill-mask"
1030
1171
  });
1031
1172
  const isValidOutput = Array.isArray(res) && res.every(
1032
1173
  (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
@@ -1043,7 +1184,7 @@ async function fillMask(args, options) {
1043
1184
  async function questionAnswering(args, options) {
1044
1185
  const res = await request(args, {
1045
1186
  ...options,
1046
- taskHint: "question-answering"
1187
+ task: "question-answering"
1047
1188
  });
1048
1189
  const isValidOutput = Array.isArray(res) ? res.every(
1049
1190
  (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
@@ -1058,7 +1199,7 @@ async function questionAnswering(args, options) {
1058
1199
  async function sentenceSimilarity(args, options) {
1059
1200
  const res = await request(prepareInput(args), {
1060
1201
  ...options,
1061
- taskHint: "sentence-similarity"
1202
+ task: "sentence-similarity"
1062
1203
  });
1063
1204
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
1064
1205
  if (!isValidOutput) {
@@ -1078,7 +1219,7 @@ function prepareInput(args) {
1078
1219
  async function summarization(args, options) {
1079
1220
  const res = await request(args, {
1080
1221
  ...options,
1081
- taskHint: "summarization"
1222
+ task: "summarization"
1082
1223
  });
1083
1224
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string");
1084
1225
  if (!isValidOutput) {
@@ -1091,7 +1232,7 @@ async function summarization(args, options) {
1091
1232
  async function tableQuestionAnswering(args, options) {
1092
1233
  const res = await request(args, {
1093
1234
  ...options,
1094
- taskHint: "table-question-answering"
1235
+ task: "table-question-answering"
1095
1236
  });
1096
1237
  const isValidOutput = Array.isArray(res) ? res.every((elem) => validate(elem)) : validate(res);
1097
1238
  if (!isValidOutput) {
@@ -1111,7 +1252,7 @@ function validate(elem) {
1111
1252
  async function textClassification(args, options) {
1112
1253
  const res = (await request(args, {
1113
1254
  ...options,
1114
- taskHint: "text-classification"
1255
+ task: "text-classification"
1115
1256
  }))?.[0];
1116
1257
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number");
1117
1258
  if (!isValidOutput) {
@@ -1134,7 +1275,7 @@ async function textGeneration(args, options) {
1134
1275
  args.prompt = args.inputs;
1135
1276
  const raw = await request(args, {
1136
1277
  ...options,
1137
- taskHint: "text-generation"
1278
+ task: "text-generation"
1138
1279
  });
1139
1280
  const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
1140
1281
  if (!isValidOutput) {
@@ -1155,7 +1296,7 @@ async function textGeneration(args, options) {
1155
1296
  };
1156
1297
  const raw = await request(payload, {
1157
1298
  ...options,
1158
- taskHint: "text-generation"
1299
+ task: "text-generation"
1159
1300
  });
1160
1301
  const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
1161
1302
  if (!isValidOutput) {
@@ -1169,7 +1310,7 @@ async function textGeneration(args, options) {
1169
1310
  const res = toArray(
1170
1311
  await request(args, {
1171
1312
  ...options,
1172
- taskHint: "text-generation"
1313
+ task: "text-generation"
1173
1314
  })
1174
1315
  );
1175
1316
  const isValidOutput = Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
@@ -1184,7 +1325,7 @@ async function textGeneration(args, options) {
1184
1325
  async function* textGenerationStream(args, options) {
1185
1326
  yield* streamingRequest(args, {
1186
1327
  ...options,
1187
- taskHint: "text-generation"
1328
+ task: "text-generation"
1188
1329
  });
1189
1330
  }
1190
1331
 
@@ -1193,7 +1334,7 @@ async function tokenClassification(args, options) {
1193
1334
  const res = toArray(
1194
1335
  await request(args, {
1195
1336
  ...options,
1196
- taskHint: "token-classification"
1337
+ task: "token-classification"
1197
1338
  })
1198
1339
  );
1199
1340
  const isValidOutput = Array.isArray(res) && res.every(
@@ -1211,7 +1352,7 @@ async function tokenClassification(args, options) {
1211
1352
  async function translation(args, options) {
1212
1353
  const res = await request(args, {
1213
1354
  ...options,
1214
- taskHint: "translation"
1355
+ task: "translation"
1215
1356
  });
1216
1357
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string");
1217
1358
  if (!isValidOutput) {
@@ -1225,7 +1366,7 @@ async function zeroShotClassification(args, options) {
1225
1366
  const res = toArray(
1226
1367
  await request(args, {
1227
1368
  ...options,
1228
- taskHint: "zero-shot-classification"
1369
+ task: "zero-shot-classification"
1229
1370
  })
1230
1371
  );
1231
1372
  const isValidOutput = Array.isArray(res) && res.every(
@@ -1241,7 +1382,7 @@ async function zeroShotClassification(args, options) {
1241
1382
  async function chatCompletion(args, options) {
1242
1383
  const res = await request(args, {
1243
1384
  ...options,
1244
- taskHint: "text-generation",
1385
+ task: "text-generation",
1245
1386
  chatCompletion: true
1246
1387
  });
1247
1388
  const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
@@ -1256,7 +1397,7 @@ async function chatCompletion(args, options) {
1256
1397
  async function* chatCompletionStream(args, options) {
1257
1398
  yield* streamingRequest(args, {
1258
1399
  ...options,
1259
- taskHint: "text-generation",
1400
+ task: "text-generation",
1260
1401
  chatCompletion: true
1261
1402
  });
1262
1403
  }
@@ -1274,7 +1415,7 @@ async function documentQuestionAnswering(args, options) {
1274
1415
  const res = toArray(
1275
1416
  await request(reqArgs, {
1276
1417
  ...options,
1277
- taskHint: "document-question-answering"
1418
+ task: "document-question-answering"
1278
1419
  })
1279
1420
  );
1280
1421
  const isValidOutput = Array.isArray(res) && res.every(
@@ -1298,7 +1439,7 @@ async function visualQuestionAnswering(args, options) {
1298
1439
  };
1299
1440
  const res = await request(reqArgs, {
1300
1441
  ...options,
1301
- taskHint: "visual-question-answering"
1442
+ task: "visual-question-answering"
1302
1443
  });
1303
1444
  const isValidOutput = Array.isArray(res) && res.every(
1304
1445
  (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
@@ -1313,7 +1454,7 @@ async function visualQuestionAnswering(args, options) {
1313
1454
  async function tabularRegression(args, options) {
1314
1455
  const res = await request(args, {
1315
1456
  ...options,
1316
- taskHint: "tabular-regression"
1457
+ task: "tabular-regression"
1317
1458
  });
1318
1459
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
1319
1460
  if (!isValidOutput) {
@@ -1326,7 +1467,7 @@ async function tabularRegression(args, options) {
1326
1467
  async function tabularClassification(args, options) {
1327
1468
  const res = await request(args, {
1328
1469
  ...options,
1329
- taskHint: "tabular-classification"
1470
+ task: "tabular-classification"
1330
1471
  });
1331
1472
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
1332
1473
  if (!isValidOutput) {