@huggingface/inference 3.3.6 → 3.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81) hide show
  1. package/dist/index.cjs +315 -174
  2. package/dist/index.js +315 -174
  3. package/dist/src/lib/getProviderModelId.d.ts +1 -1
  4. package/dist/src/lib/getProviderModelId.d.ts.map +1 -1
  5. package/dist/src/lib/makeRequestOptions.d.ts +2 -2
  6. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  7. package/dist/src/providers/black-forest-labs.d.ts +2 -1
  8. package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
  9. package/dist/src/providers/fal-ai.d.ts +2 -1
  10. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  11. package/dist/src/providers/fireworks-ai.d.ts +2 -1
  12. package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
  13. package/dist/src/providers/hf-inference.d.ts +3 -0
  14. package/dist/src/providers/hf-inference.d.ts.map +1 -0
  15. package/dist/src/providers/hyperbolic.d.ts +2 -1
  16. package/dist/src/providers/hyperbolic.d.ts.map +1 -1
  17. package/dist/src/providers/nebius.d.ts +2 -1
  18. package/dist/src/providers/nebius.d.ts.map +1 -1
  19. package/dist/src/providers/novita.d.ts +2 -1
  20. package/dist/src/providers/novita.d.ts.map +1 -1
  21. package/dist/src/providers/replicate.d.ts +3 -1
  22. package/dist/src/providers/replicate.d.ts.map +1 -1
  23. package/dist/src/providers/sambanova.d.ts +2 -1
  24. package/dist/src/providers/sambanova.d.ts.map +1 -1
  25. package/dist/src/providers/together.d.ts +2 -1
  26. package/dist/src/providers/together.d.ts.map +1 -1
  27. package/dist/src/tasks/custom/request.d.ts +2 -4
  28. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  29. package/dist/src/tasks/custom/streamingRequest.d.ts +2 -4
  30. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  31. package/dist/src/tasks/nlp/featureExtraction.d.ts +2 -9
  32. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  33. package/dist/src/types.d.ts +24 -3
  34. package/dist/src/types.d.ts.map +1 -1
  35. package/package.json +2 -2
  36. package/src/lib/getProviderModelId.ts +4 -4
  37. package/src/lib/makeRequestOptions.ts +72 -186
  38. package/src/providers/black-forest-labs.ts +26 -2
  39. package/src/providers/consts.ts +1 -1
  40. package/src/providers/fal-ai.ts +24 -2
  41. package/src/providers/fireworks-ai.ts +28 -2
  42. package/src/providers/hf-inference.ts +43 -0
  43. package/src/providers/hyperbolic.ts +28 -2
  44. package/src/providers/nebius.ts +34 -2
  45. package/src/providers/novita.ts +31 -2
  46. package/src/providers/replicate.ts +30 -2
  47. package/src/providers/sambanova.ts +28 -2
  48. package/src/providers/together.ts +34 -2
  49. package/src/tasks/audio/audioClassification.ts +1 -1
  50. package/src/tasks/audio/audioToAudio.ts +1 -1
  51. package/src/tasks/audio/automaticSpeechRecognition.ts +1 -1
  52. package/src/tasks/audio/textToSpeech.ts +1 -1
  53. package/src/tasks/custom/request.ts +2 -4
  54. package/src/tasks/custom/streamingRequest.ts +2 -4
  55. package/src/tasks/cv/imageClassification.ts +1 -1
  56. package/src/tasks/cv/imageSegmentation.ts +1 -1
  57. package/src/tasks/cv/imageToImage.ts +1 -1
  58. package/src/tasks/cv/imageToText.ts +1 -1
  59. package/src/tasks/cv/objectDetection.ts +1 -1
  60. package/src/tasks/cv/textToImage.ts +1 -1
  61. package/src/tasks/cv/textToVideo.ts +1 -1
  62. package/src/tasks/cv/zeroShotImageClassification.ts +1 -1
  63. package/src/tasks/multimodal/documentQuestionAnswering.ts +1 -1
  64. package/src/tasks/multimodal/visualQuestionAnswering.ts +1 -1
  65. package/src/tasks/nlp/chatCompletion.ts +1 -1
  66. package/src/tasks/nlp/chatCompletionStream.ts +1 -1
  67. package/src/tasks/nlp/featureExtraction.ts +3 -10
  68. package/src/tasks/nlp/fillMask.ts +1 -1
  69. package/src/tasks/nlp/questionAnswering.ts +1 -1
  70. package/src/tasks/nlp/sentenceSimilarity.ts +1 -1
  71. package/src/tasks/nlp/summarization.ts +1 -1
  72. package/src/tasks/nlp/tableQuestionAnswering.ts +1 -1
  73. package/src/tasks/nlp/textClassification.ts +1 -1
  74. package/src/tasks/nlp/textGeneration.ts +3 -3
  75. package/src/tasks/nlp/textGenerationStream.ts +1 -1
  76. package/src/tasks/nlp/tokenClassification.ts +1 -1
  77. package/src/tasks/nlp/translation.ts +1 -1
  78. package/src/tasks/nlp/zeroShotClassification.ts +1 -1
  79. package/src/tasks/tabular/tabularClassification.ts +1 -1
  80. package/src/tasks/tabular/tabularRegression.ts +1 -1
  81. package/src/types.ts +28 -2
package/dist/index.js CHANGED
@@ -45,32 +45,256 @@ __export(tasks_exports, {
45
45
  var HF_HUB_URL = "https://huggingface.co";
46
46
  var HF_ROUTER_URL = "https://router.huggingface.co";
47
47
 
48
+ // src/providers/black-forest-labs.ts
49
+ var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
50
+ var makeBody = (params) => {
51
+ return params.args;
52
+ };
53
+ var makeHeaders = (params) => {
54
+ if (params.authMethod === "provider-key") {
55
+ return { "X-Key": `${params.accessToken}` };
56
+ } else {
57
+ return { Authorization: `Bearer ${params.accessToken}` };
58
+ }
59
+ };
60
+ var makeUrl = (params) => {
61
+ return `${params.baseUrl}/${params.model}`;
62
+ };
63
+ var BLACK_FOREST_LABS_CONFIG = {
64
+ baseUrl: BLACK_FOREST_LABS_AI_API_BASE_URL,
65
+ makeBody,
66
+ makeHeaders,
67
+ makeUrl
68
+ };
69
+
48
70
  // src/providers/fal-ai.ts
49
71
  var FAL_AI_API_BASE_URL = "https://fal.run";
72
+ var makeBody2 = (params) => {
73
+ return params.args;
74
+ };
75
+ var makeHeaders2 = (params) => {
76
+ return {
77
+ Authorization: params.authMethod === "provider-key" ? `Key ${params.accessToken}` : `Bearer ${params.accessToken}`
78
+ };
79
+ };
80
+ var makeUrl2 = (params) => {
81
+ return `${params.baseUrl}/${params.model}`;
82
+ };
83
+ var FAL_AI_CONFIG = {
84
+ baseUrl: FAL_AI_API_BASE_URL,
85
+ makeBody: makeBody2,
86
+ makeHeaders: makeHeaders2,
87
+ makeUrl: makeUrl2
88
+ };
89
+
90
+ // src/providers/fireworks-ai.ts
91
+ var FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
92
+ var makeBody3 = (params) => {
93
+ return {
94
+ ...params.args,
95
+ ...params.chatCompletion ? { model: params.model } : void 0
96
+ };
97
+ };
98
+ var makeHeaders3 = (params) => {
99
+ return { Authorization: `Bearer ${params.accessToken}` };
100
+ };
101
+ var makeUrl3 = (params) => {
102
+ if (params.task === "text-generation" && params.chatCompletion) {
103
+ return `${params.baseUrl}/v1/chat/completions`;
104
+ }
105
+ return params.baseUrl;
106
+ };
107
+ var FIREWORKS_AI_CONFIG = {
108
+ baseUrl: FIREWORKS_AI_API_BASE_URL,
109
+ makeBody: makeBody3,
110
+ makeHeaders: makeHeaders3,
111
+ makeUrl: makeUrl3
112
+ };
113
+
114
+ // src/providers/hf-inference.ts
115
+ var makeBody4 = (params) => {
116
+ return {
117
+ ...params.args,
118
+ ...params.chatCompletion ? { model: params.model } : void 0
119
+ };
120
+ };
121
+ var makeHeaders4 = (params) => {
122
+ return { Authorization: `Bearer ${params.accessToken}` };
123
+ };
124
+ var makeUrl4 = (params) => {
125
+ if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
126
+ return `${params.baseUrl}/pipeline/${params.task}/${params.model}`;
127
+ }
128
+ if (params.task === "text-generation" && params.chatCompletion) {
129
+ return `${params.baseUrl}/models/${params.model}/v1/chat/completions`;
130
+ }
131
+ return `${params.baseUrl}/models/${params.model}`;
132
+ };
133
+ var HF_INFERENCE_CONFIG = {
134
+ baseUrl: `${HF_ROUTER_URL}/hf-inference`,
135
+ makeBody: makeBody4,
136
+ makeHeaders: makeHeaders4,
137
+ makeUrl: makeUrl4
138
+ };
139
+
140
+ // src/providers/hyperbolic.ts
141
+ var HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
142
+ var makeBody5 = (params) => {
143
+ return {
144
+ ...params.args,
145
+ ...params.task === "text-to-image" ? { model_name: params.model } : { model: params.model }
146
+ };
147
+ };
148
+ var makeHeaders5 = (params) => {
149
+ return { Authorization: `Bearer ${params.accessToken}` };
150
+ };
151
+ var makeUrl5 = (params) => {
152
+ if (params.task === "text-to-image") {
153
+ return `${params.baseUrl}/v1/images/generations`;
154
+ }
155
+ return `${params.baseUrl}/v1/chat/completions`;
156
+ };
157
+ var HYPERBOLIC_CONFIG = {
158
+ baseUrl: HYPERBOLIC_API_BASE_URL,
159
+ makeBody: makeBody5,
160
+ makeHeaders: makeHeaders5,
161
+ makeUrl: makeUrl5
162
+ };
50
163
 
51
164
  // src/providers/nebius.ts
52
165
  var NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
166
+ var makeBody6 = (params) => {
167
+ return {
168
+ ...params.args,
169
+ model: params.model
170
+ };
171
+ };
172
+ var makeHeaders6 = (params) => {
173
+ return { Authorization: `Bearer ${params.accessToken}` };
174
+ };
175
+ var makeUrl6 = (params) => {
176
+ if (params.task === "text-to-image") {
177
+ return `${params.baseUrl}/v1/images/generations`;
178
+ }
179
+ if (params.task === "text-generation") {
180
+ if (params.chatCompletion) {
181
+ return `${params.baseUrl}/v1/chat/completions`;
182
+ }
183
+ return `${params.baseUrl}/v1/completions`;
184
+ }
185
+ return params.baseUrl;
186
+ };
187
+ var NEBIUS_CONFIG = {
188
+ baseUrl: NEBIUS_API_BASE_URL,
189
+ makeBody: makeBody6,
190
+ makeHeaders: makeHeaders6,
191
+ makeUrl: makeUrl6
192
+ };
193
+
194
+ // src/providers/novita.ts
195
+ var NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
196
+ var makeBody7 = (params) => {
197
+ return {
198
+ ...params.args,
199
+ ...params.chatCompletion ? { model: params.model } : void 0
200
+ };
201
+ };
202
+ var makeHeaders7 = (params) => {
203
+ return { Authorization: `Bearer ${params.accessToken}` };
204
+ };
205
+ var makeUrl7 = (params) => {
206
+ if (params.task === "text-generation") {
207
+ if (params.chatCompletion) {
208
+ return `${params.baseUrl}/chat/completions`;
209
+ }
210
+ return `${params.baseUrl}/completions`;
211
+ }
212
+ return params.baseUrl;
213
+ };
214
+ var NOVITA_CONFIG = {
215
+ baseUrl: NOVITA_API_BASE_URL,
216
+ makeBody: makeBody7,
217
+ makeHeaders: makeHeaders7,
218
+ makeUrl: makeUrl7
219
+ };
53
220
 
54
221
  // src/providers/replicate.ts
55
222
  var REPLICATE_API_BASE_URL = "https://api.replicate.com";
223
+ var makeBody8 = (params) => {
224
+ return {
225
+ input: params.args,
226
+ version: params.model.includes(":") ? params.model.split(":")[1] : void 0
227
+ };
228
+ };
229
+ var makeHeaders8 = (params) => {
230
+ return { Authorization: `Bearer ${params.accessToken}` };
231
+ };
232
+ var makeUrl8 = (params) => {
233
+ if (params.model.includes(":")) {
234
+ return `${params.baseUrl}/v1/predictions`;
235
+ }
236
+ return `${params.baseUrl}/v1/models/${params.model}/predictions`;
237
+ };
238
+ var REPLICATE_CONFIG = {
239
+ baseUrl: REPLICATE_API_BASE_URL,
240
+ makeBody: makeBody8,
241
+ makeHeaders: makeHeaders8,
242
+ makeUrl: makeUrl8
243
+ };
56
244
 
57
245
  // src/providers/sambanova.ts
58
246
  var SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
247
+ var makeBody9 = (params) => {
248
+ return {
249
+ ...params.args,
250
+ ...params.chatCompletion ? { model: params.model } : void 0
251
+ };
252
+ };
253
+ var makeHeaders9 = (params) => {
254
+ return { Authorization: `Bearer ${params.accessToken}` };
255
+ };
256
+ var makeUrl9 = (params) => {
257
+ if (params.task === "text-generation" && params.chatCompletion) {
258
+ return `${params.baseUrl}/v1/chat/completions`;
259
+ }
260
+ return params.baseUrl;
261
+ };
262
+ var SAMBANOVA_CONFIG = {
263
+ baseUrl: SAMBANOVA_API_BASE_URL,
264
+ makeBody: makeBody9,
265
+ makeHeaders: makeHeaders9,
266
+ makeUrl: makeUrl9
267
+ };
59
268
 
60
269
  // src/providers/together.ts
61
270
  var TOGETHER_API_BASE_URL = "https://api.together.xyz";
62
-
63
- // src/providers/novita.ts
64
- var NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
65
-
66
- // src/providers/fireworks-ai.ts
67
- var FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
68
-
69
- // src/providers/hyperbolic.ts
70
- var HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
71
-
72
- // src/providers/black-forest-labs.ts
73
- var BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
271
+ var makeBody10 = (params) => {
272
+ return {
273
+ ...params.args,
274
+ model: params.model
275
+ };
276
+ };
277
+ var makeHeaders10 = (params) => {
278
+ return { Authorization: `Bearer ${params.accessToken}` };
279
+ };
280
+ var makeUrl10 = (params) => {
281
+ if (params.task === "text-to-image") {
282
+ return `${params.baseUrl}/v1/images/generations`;
283
+ }
284
+ if (params.task === "text-generation") {
285
+ if (params.chatCompletion) {
286
+ return `${params.baseUrl}/v1/chat/completions`;
287
+ }
288
+ return `${params.baseUrl}/v1/completions`;
289
+ }
290
+ return params.baseUrl;
291
+ };
292
+ var TOGETHER_CONFIG = {
293
+ baseUrl: TOGETHER_API_BASE_URL,
294
+ makeBody: makeBody10,
295
+ makeHeaders: makeHeaders10,
296
+ makeUrl: makeUrl10
297
+ };
74
298
 
75
299
  // src/lib/isUrl.ts
76
300
  function isUrl(modelOrUrl) {
@@ -79,7 +303,7 @@ function isUrl(modelOrUrl) {
79
303
 
80
304
  // package.json
81
305
  var name = "@huggingface/inference";
82
- var version = "3.3.6";
306
+ var version = "3.3.7";
83
307
 
84
308
  // src/providers/consts.ts
85
309
  var HARDCODED_MODEL_ID_MAPPING = {
@@ -95,10 +319,10 @@ var HARDCODED_MODEL_ID_MAPPING = {
95
319
  "hf-inference": {},
96
320
  hyperbolic: {},
97
321
  nebius: {},
322
+ novita: {},
98
323
  replicate: {},
99
324
  sambanova: {},
100
- together: {},
101
- novita: {}
325
+ together: {}
102
326
  };
103
327
 
104
328
  // src/lib/getProviderModelId.ts
@@ -107,10 +331,10 @@ async function getProviderModelId(params, args, options = {}) {
107
331
  if (params.provider === "hf-inference") {
108
332
  return params.model;
109
333
  }
110
- if (!options.taskHint) {
111
- throw new Error("taskHint must be specified when using a third-party provider");
334
+ if (!options.task) {
335
+ throw new Error("task must be specified when using a third-party provider");
112
336
  }
113
- const task = options.taskHint === "text-generation" && options.chatCompletion ? "conversational" : options.taskHint;
337
+ const task = options.task === "text-generation" && options.chatCompletion ? "conversational" : options.task;
114
338
  if (HARDCODED_MODEL_ID_MAPPING[params.provider]?.[params.model]) {
115
339
  return HARDCODED_MODEL_ID_MAPPING[params.provider][params.model];
116
340
  }
@@ -148,165 +372,82 @@ async function getProviderModelId(params, args, options = {}) {
148
372
  // src/lib/makeRequestOptions.ts
149
373
  var HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
150
374
  var tasks = null;
375
+ var providerConfigs = {
376
+ "black-forest-labs": BLACK_FOREST_LABS_CONFIG,
377
+ "fal-ai": FAL_AI_CONFIG,
378
+ "fireworks-ai": FIREWORKS_AI_CONFIG,
379
+ "hf-inference": HF_INFERENCE_CONFIG,
380
+ hyperbolic: HYPERBOLIC_CONFIG,
381
+ nebius: NEBIUS_CONFIG,
382
+ novita: NOVITA_CONFIG,
383
+ replicate: REPLICATE_CONFIG,
384
+ sambanova: SAMBANOVA_CONFIG,
385
+ together: TOGETHER_CONFIG
386
+ };
151
387
  async function makeRequestOptions(args, options) {
152
388
  const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
153
- let otherArgs = remainingArgs;
154
389
  const provider = maybeProvider ?? "hf-inference";
155
- const { includeCredentials, taskHint, chatCompletion: chatCompletion2 } = options ?? {};
390
+ const providerConfig = providerConfigs[provider];
391
+ const { includeCredentials, task, chatCompletion: chatCompletion2, signal } = options ?? {};
156
392
  if (endpointUrl && provider !== "hf-inference") {
157
393
  throw new Error(`Cannot use endpointUrl with a third-party provider.`);
158
394
  }
159
395
  if (maybeModel && isUrl(maybeModel)) {
160
396
  throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
161
397
  }
162
- if (!maybeModel && !taskHint) {
398
+ if (!maybeModel && !task) {
163
399
  throw new Error("No model provided, and no task has been specified.");
164
400
  }
165
- const hfModel = maybeModel ?? await loadDefaultModel(taskHint);
401
+ if (!providerConfig) {
402
+ throw new Error(`No provider config found for provider ${provider}`);
403
+ }
404
+ const hfModel = maybeModel ?? await loadDefaultModel(task);
166
405
  const model = await getProviderModelId({ model: hfModel, provider }, args, {
167
- taskHint,
406
+ task,
168
407
  chatCompletion: chatCompletion2,
169
408
  fetch: options?.fetch
170
409
  });
171
410
  const authMethod = accessToken ? accessToken.startsWith("hf_") ? "hf-token" : "provider-key" : includeCredentials === "include" ? "credentials-include" : "none";
172
- const url = endpointUrl ? chatCompletion2 ? endpointUrl + `/v1/chat/completions` : endpointUrl : makeUrl({
173
- authMethod,
174
- chatCompletion: chatCompletion2 ?? false,
411
+ const url = endpointUrl ? chatCompletion2 ? endpointUrl + `/v1/chat/completions` : endpointUrl : providerConfig.makeUrl({
412
+ baseUrl: authMethod !== "provider-key" ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider) : providerConfig.baseUrl,
175
413
  model,
176
- provider: provider ?? "hf-inference",
177
- taskHint
414
+ chatCompletion: chatCompletion2,
415
+ task
178
416
  });
179
- const headers = {};
180
- if (accessToken) {
181
- if (provider === "fal-ai" && authMethod === "provider-key") {
182
- headers["Authorization"] = `Key ${accessToken}`;
183
- } else if (provider === "black-forest-labs" && authMethod === "provider-key") {
184
- headers["X-Key"] = accessToken;
185
- } else {
186
- headers["Authorization"] = `Bearer ${accessToken}`;
187
- }
188
- }
189
- const ownUserAgent = `${name}/${version}`;
190
- headers["User-Agent"] = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : void 0].filter((x) => x !== void 0).join(" ");
191
417
  const binary = "data" in args && !!args.data;
418
+ const headers = providerConfig.makeHeaders({
419
+ accessToken,
420
+ authMethod
421
+ });
192
422
  if (!binary) {
193
423
  headers["Content-Type"] = "application/json";
194
424
  }
195
- if (provider === "replicate") {
196
- headers["Prefer"] = "wait";
197
- }
425
+ const ownUserAgent = `${name}/${version}`;
426
+ const userAgent = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : void 0].filter((x) => x !== void 0).join(" ");
427
+ headers["User-Agent"] = userAgent;
428
+ const body = binary ? args.data : JSON.stringify(
429
+ providerConfig.makeBody({
430
+ args: remainingArgs,
431
+ model,
432
+ task,
433
+ chatCompletion: chatCompletion2
434
+ })
435
+ );
198
436
  let credentials;
199
437
  if (typeof includeCredentials === "string") {
200
438
  credentials = includeCredentials;
201
439
  } else if (includeCredentials === true) {
202
440
  credentials = "include";
203
441
  }
204
- if (provider === "replicate") {
205
- const version2 = model.includes(":") ? model.split(":")[1] : void 0;
206
- otherArgs = { input: otherArgs, version: version2 };
207
- }
208
442
  const info = {
209
443
  headers,
210
444
  method: "POST",
211
- body: binary ? args.data : JSON.stringify({
212
- ...otherArgs,
213
- ...taskHint === "text-to-image" && provider === "hyperbolic" ? { model_name: model } : chatCompletion2 || provider === "together" || provider === "nebius" || provider === "hyperbolic" ? { model } : void 0
214
- }),
445
+ body,
215
446
  ...credentials ? { credentials } : void 0,
216
- signal: options?.signal
447
+ signal
217
448
  };
218
449
  return { url, info };
219
450
  }
220
- function makeUrl(params) {
221
- if (params.authMethod === "none" && params.provider !== "hf-inference") {
222
- throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
223
- }
224
- const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
225
- switch (params.provider) {
226
- case "black-forest-labs": {
227
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : BLACKFORESTLABS_AI_API_BASE_URL;
228
- return `${baseUrl}/${params.model}`;
229
- }
230
- case "fal-ai": {
231
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FAL_AI_API_BASE_URL;
232
- return `${baseUrl}/${params.model}`;
233
- }
234
- case "nebius": {
235
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : NEBIUS_API_BASE_URL;
236
- if (params.taskHint === "text-to-image") {
237
- return `${baseUrl}/v1/images/generations`;
238
- }
239
- if (params.taskHint === "text-generation") {
240
- if (params.chatCompletion) {
241
- return `${baseUrl}/v1/chat/completions`;
242
- }
243
- return `${baseUrl}/v1/completions`;
244
- }
245
- return baseUrl;
246
- }
247
- case "replicate": {
248
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : REPLICATE_API_BASE_URL;
249
- if (params.model.includes(":")) {
250
- return `${baseUrl}/v1/predictions`;
251
- }
252
- return `${baseUrl}/v1/models/${params.model}/predictions`;
253
- }
254
- case "sambanova": {
255
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : SAMBANOVA_API_BASE_URL;
256
- if (params.taskHint === "text-generation" && params.chatCompletion) {
257
- return `${baseUrl}/v1/chat/completions`;
258
- }
259
- return baseUrl;
260
- }
261
- case "together": {
262
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : TOGETHER_API_BASE_URL;
263
- if (params.taskHint === "text-to-image") {
264
- return `${baseUrl}/v1/images/generations`;
265
- }
266
- if (params.taskHint === "text-generation") {
267
- if (params.chatCompletion) {
268
- return `${baseUrl}/v1/chat/completions`;
269
- }
270
- return `${baseUrl}/v1/completions`;
271
- }
272
- return baseUrl;
273
- }
274
- case "fireworks-ai": {
275
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FIREWORKS_AI_API_BASE_URL;
276
- if (params.taskHint === "text-generation" && params.chatCompletion) {
277
- return `${baseUrl}/v1/chat/completions`;
278
- }
279
- return baseUrl;
280
- }
281
- case "hyperbolic": {
282
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : HYPERBOLIC_API_BASE_URL;
283
- if (params.taskHint === "text-to-image") {
284
- return `${baseUrl}/v1/images/generations`;
285
- }
286
- return `${baseUrl}/v1/chat/completions`;
287
- }
288
- case "novita": {
289
- const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : NOVITA_API_BASE_URL;
290
- if (params.taskHint === "text-generation") {
291
- if (params.chatCompletion) {
292
- return `${baseUrl}/chat/completions`;
293
- }
294
- return `${baseUrl}/completions`;
295
- }
296
- return baseUrl;
297
- }
298
- default: {
299
- const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
300
- if (params.taskHint && ["feature-extraction", "sentence-similarity"].includes(params.taskHint)) {
301
- return `${baseUrl}/pipeline/${params.taskHint}/${params.model}`;
302
- }
303
- if (params.taskHint === "text-generation" && params.chatCompletion) {
304
- return `${baseUrl}/models/${params.model}/v1/chat/completions`;
305
- }
306
- return `${baseUrl}/models/${params.model}`;
307
- }
308
- }
309
- }
310
451
  async function loadDefaultModel(task) {
311
452
  if (!tasks) {
312
453
  tasks = await loadTaskInfo();
@@ -573,7 +714,7 @@ async function audioClassification(args, options) {
573
714
  const payload = preparePayload(args);
574
715
  const res = await request(payload, {
575
716
  ...options,
576
- taskHint: "audio-classification"
717
+ task: "audio-classification"
577
718
  });
578
719
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
579
720
  if (!isValidOutput) {
@@ -600,7 +741,7 @@ async function automaticSpeechRecognition(args, options) {
600
741
  const payload = await buildPayload(args);
601
742
  const res = await request(payload, {
602
743
  ...options,
603
- taskHint: "automatic-speech-recognition"
744
+ task: "automatic-speech-recognition"
604
745
  });
605
746
  const isValidOutput = typeof res?.text === "string";
606
747
  if (!isValidOutput) {
@@ -644,7 +785,7 @@ async function textToSpeech(args, options) {
644
785
  } : args;
645
786
  const res = await request(payload, {
646
787
  ...options,
647
- taskHint: "text-to-speech"
788
+ task: "text-to-speech"
648
789
  });
649
790
  if (res instanceof Blob) {
650
791
  return res;
@@ -670,7 +811,7 @@ async function audioToAudio(args, options) {
670
811
  const payload = preparePayload(args);
671
812
  const res = await request(payload, {
672
813
  ...options,
673
- taskHint: "audio-to-audio"
814
+ task: "audio-to-audio"
674
815
  });
675
816
  return validateOutput(res);
676
817
  }
@@ -696,7 +837,7 @@ async function imageClassification(args, options) {
696
837
  const payload = preparePayload2(args);
697
838
  const res = await request(payload, {
698
839
  ...options,
699
- taskHint: "image-classification"
840
+ task: "image-classification"
700
841
  });
701
842
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
702
843
  if (!isValidOutput) {
@@ -710,7 +851,7 @@ async function imageSegmentation(args, options) {
710
851
  const payload = preparePayload2(args);
711
852
  const res = await request(payload, {
712
853
  ...options,
713
- taskHint: "image-segmentation"
854
+ task: "image-segmentation"
714
855
  });
715
856
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
716
857
  if (!isValidOutput) {
@@ -724,7 +865,7 @@ async function imageToText(args, options) {
724
865
  const payload = preparePayload2(args);
725
866
  const res = (await request(payload, {
726
867
  ...options,
727
- taskHint: "image-to-text"
868
+ task: "image-to-text"
728
869
  }))?.[0];
729
870
  if (typeof res?.generated_text !== "string") {
730
871
  throw new InferenceOutputError("Expected {generated_text: string}");
@@ -737,7 +878,7 @@ async function objectDetection(args, options) {
737
878
  const payload = preparePayload2(args);
738
879
  const res = await request(payload, {
739
880
  ...options,
740
- taskHint: "object-detection"
881
+ task: "object-detection"
741
882
  });
742
883
  const isValidOutput = Array.isArray(res) && res.every(
743
884
  (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
@@ -781,7 +922,7 @@ async function textToImage(args, options) {
781
922
  };
782
923
  const res = await request(payload, {
783
924
  ...options,
784
- taskHint: "text-to-image"
925
+ task: "text-to-image"
785
926
  });
786
927
  if (res && typeof res === "object") {
787
928
  if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") {
@@ -870,7 +1011,7 @@ async function imageToImage(args, options) {
870
1011
  }
871
1012
  const res = await request(reqArgs, {
872
1013
  ...options,
873
- taskHint: "image-to-image"
1014
+ task: "image-to-image"
874
1015
  });
875
1016
  const isValidOutput = res && res instanceof Blob;
876
1017
  if (!isValidOutput) {
@@ -905,7 +1046,7 @@ async function zeroShotImageClassification(args, options) {
905
1046
  const payload = await preparePayload3(args);
906
1047
  const res = await request(payload, {
907
1048
  ...options,
908
- taskHint: "zero-shot-image-classification"
1049
+ task: "zero-shot-image-classification"
909
1050
  });
910
1051
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
911
1052
  if (!isValidOutput) {
@@ -925,7 +1066,7 @@ async function textToVideo(args, options) {
925
1066
  const payload = args.provider === "fal-ai" || args.provider === "replicate" ? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs } : args;
926
1067
  const res = await request(payload, {
927
1068
  ...options,
928
- taskHint: "text-to-video"
1069
+ task: "text-to-video"
929
1070
  });
930
1071
  if (args.provider === "fal-ai") {
931
1072
  const isValidOutput = typeof res === "object" && !!res && "video" in res && typeof res.video === "object" && !!res.video && "url" in res.video && typeof res.video.url === "string" && isUrl(res.video.url);
@@ -948,7 +1089,7 @@ async function textToVideo(args, options) {
948
1089
  async function featureExtraction(args, options) {
949
1090
  const res = await request(args, {
950
1091
  ...options,
951
- taskHint: "feature-extraction"
1092
+ task: "feature-extraction"
952
1093
  });
953
1094
  let isValidOutput = true;
954
1095
  const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
@@ -971,7 +1112,7 @@ async function featureExtraction(args, options) {
971
1112
  async function fillMask(args, options) {
972
1113
  const res = await request(args, {
973
1114
  ...options,
974
- taskHint: "fill-mask"
1115
+ task: "fill-mask"
975
1116
  });
976
1117
  const isValidOutput = Array.isArray(res) && res.every(
977
1118
  (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
@@ -988,7 +1129,7 @@ async function fillMask(args, options) {
988
1129
  async function questionAnswering(args, options) {
989
1130
  const res = await request(args, {
990
1131
  ...options,
991
- taskHint: "question-answering"
1132
+ task: "question-answering"
992
1133
  });
993
1134
  const isValidOutput = Array.isArray(res) ? res.every(
994
1135
  (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
@@ -1003,7 +1144,7 @@ async function questionAnswering(args, options) {
1003
1144
  async function sentenceSimilarity(args, options) {
1004
1145
  const res = await request(prepareInput(args), {
1005
1146
  ...options,
1006
- taskHint: "sentence-similarity"
1147
+ task: "sentence-similarity"
1007
1148
  });
1008
1149
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
1009
1150
  if (!isValidOutput) {
@@ -1023,7 +1164,7 @@ function prepareInput(args) {
1023
1164
  async function summarization(args, options) {
1024
1165
  const res = await request(args, {
1025
1166
  ...options,
1026
- taskHint: "summarization"
1167
+ task: "summarization"
1027
1168
  });
1028
1169
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string");
1029
1170
  if (!isValidOutput) {
@@ -1036,7 +1177,7 @@ async function summarization(args, options) {
1036
1177
  async function tableQuestionAnswering(args, options) {
1037
1178
  const res = await request(args, {
1038
1179
  ...options,
1039
- taskHint: "table-question-answering"
1180
+ task: "table-question-answering"
1040
1181
  });
1041
1182
  const isValidOutput = Array.isArray(res) ? res.every((elem) => validate(elem)) : validate(res);
1042
1183
  if (!isValidOutput) {
@@ -1056,7 +1197,7 @@ function validate(elem) {
1056
1197
  async function textClassification(args, options) {
1057
1198
  const res = (await request(args, {
1058
1199
  ...options,
1059
- taskHint: "text-classification"
1200
+ task: "text-classification"
1060
1201
  }))?.[0];
1061
1202
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number");
1062
1203
  if (!isValidOutput) {
@@ -1079,7 +1220,7 @@ async function textGeneration(args, options) {
1079
1220
  args.prompt = args.inputs;
1080
1221
  const raw = await request(args, {
1081
1222
  ...options,
1082
- taskHint: "text-generation"
1223
+ task: "text-generation"
1083
1224
  });
1084
1225
  const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
1085
1226
  if (!isValidOutput) {
@@ -1100,7 +1241,7 @@ async function textGeneration(args, options) {
1100
1241
  };
1101
1242
  const raw = await request(payload, {
1102
1243
  ...options,
1103
- taskHint: "text-generation"
1244
+ task: "text-generation"
1104
1245
  });
1105
1246
  const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
1106
1247
  if (!isValidOutput) {
@@ -1114,7 +1255,7 @@ async function textGeneration(args, options) {
1114
1255
  const res = toArray(
1115
1256
  await request(args, {
1116
1257
  ...options,
1117
- taskHint: "text-generation"
1258
+ task: "text-generation"
1118
1259
  })
1119
1260
  );
1120
1261
  const isValidOutput = Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
@@ -1129,7 +1270,7 @@ async function textGeneration(args, options) {
1129
1270
  async function* textGenerationStream(args, options) {
1130
1271
  yield* streamingRequest(args, {
1131
1272
  ...options,
1132
- taskHint: "text-generation"
1273
+ task: "text-generation"
1133
1274
  });
1134
1275
  }
1135
1276
 
@@ -1138,7 +1279,7 @@ async function tokenClassification(args, options) {
1138
1279
  const res = toArray(
1139
1280
  await request(args, {
1140
1281
  ...options,
1141
- taskHint: "token-classification"
1282
+ task: "token-classification"
1142
1283
  })
1143
1284
  );
1144
1285
  const isValidOutput = Array.isArray(res) && res.every(
@@ -1156,7 +1297,7 @@ async function tokenClassification(args, options) {
1156
1297
  async function translation(args, options) {
1157
1298
  const res = await request(args, {
1158
1299
  ...options,
1159
- taskHint: "translation"
1300
+ task: "translation"
1160
1301
  });
1161
1302
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string");
1162
1303
  if (!isValidOutput) {
@@ -1170,7 +1311,7 @@ async function zeroShotClassification(args, options) {
1170
1311
  const res = toArray(
1171
1312
  await request(args, {
1172
1313
  ...options,
1173
- taskHint: "zero-shot-classification"
1314
+ task: "zero-shot-classification"
1174
1315
  })
1175
1316
  );
1176
1317
  const isValidOutput = Array.isArray(res) && res.every(
@@ -1186,7 +1327,7 @@ async function zeroShotClassification(args, options) {
1186
1327
  async function chatCompletion(args, options) {
1187
1328
  const res = await request(args, {
1188
1329
  ...options,
1189
- taskHint: "text-generation",
1330
+ task: "text-generation",
1190
1331
  chatCompletion: true
1191
1332
  });
1192
1333
  const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
@@ -1201,7 +1342,7 @@ async function chatCompletion(args, options) {
1201
1342
  async function* chatCompletionStream(args, options) {
1202
1343
  yield* streamingRequest(args, {
1203
1344
  ...options,
1204
- taskHint: "text-generation",
1345
+ task: "text-generation",
1205
1346
  chatCompletion: true
1206
1347
  });
1207
1348
  }
@@ -1219,7 +1360,7 @@ async function documentQuestionAnswering(args, options) {
1219
1360
  const res = toArray(
1220
1361
  await request(reqArgs, {
1221
1362
  ...options,
1222
- taskHint: "document-question-answering"
1363
+ task: "document-question-answering"
1223
1364
  })
1224
1365
  );
1225
1366
  const isValidOutput = Array.isArray(res) && res.every(
@@ -1243,7 +1384,7 @@ async function visualQuestionAnswering(args, options) {
1243
1384
  };
1244
1385
  const res = await request(reqArgs, {
1245
1386
  ...options,
1246
- taskHint: "visual-question-answering"
1387
+ task: "visual-question-answering"
1247
1388
  });
1248
1389
  const isValidOutput = Array.isArray(res) && res.every(
1249
1390
  (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
@@ -1258,7 +1399,7 @@ async function visualQuestionAnswering(args, options) {
1258
1399
  async function tabularRegression(args, options) {
1259
1400
  const res = await request(args, {
1260
1401
  ...options,
1261
- taskHint: "tabular-regression"
1402
+ task: "tabular-regression"
1262
1403
  });
1263
1404
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
1264
1405
  if (!isValidOutput) {
@@ -1271,7 +1412,7 @@ async function tabularRegression(args, options) {
1271
1412
  async function tabularClassification(args, options) {
1272
1413
  const res = await request(args, {
1273
1414
  ...options,
1274
- taskHint: "tabular-classification"
1415
+ task: "tabular-classification"
1275
1416
  });
1276
1417
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
1277
1418
  if (!isValidOutput) {