@huggingface/inference 3.2.0 → 3.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -46,7 +46,13 @@ Your access token should be kept private. If you need to protect it in front-end
 
  You can send inference requests to third-party providers with the inference client.
 
- Currently, we support the following providers: [Fal.ai](https://fal.ai), [Replicate](https://replicate.com), [Together](https://together.xyz) and [Sambanova](https://sambanova.ai).
+ Currently, we support the following providers:
+ - [Fal.ai](https://fal.ai)
+ - [Fireworks AI](https://fireworks.ai)
+ - [Nebius](https://studio.nebius.ai)
+ - [Replicate](https://replicate.com)
+ - [Sambanova](https://sambanova.ai)
+ - [Together](https://together.xyz)
 
  To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
  ```ts
@@ -64,13 +70,15 @@ When authenticated with a Hugging Face access token, the request is routed throu
  When authenticated with a third-party provider key, the request is made directly against that provider's inference API.
 
  Only a subset of models are supported when requesting third-party providers. You can check the list of supported models per pipeline tasks here:
- - [Fal.ai supported models](./src/providers/fal-ai.ts)
- - [Replicate supported models](./src/providers/replicate.ts)
- - [Sambanova supported models](./src/providers/sambanova.ts)
- - [Together supported models](./src/providers/together.ts)
+ - [Fal.ai supported models](https://huggingface.co/api/partners/fal-ai/models)
+ - [Fireworks AI supported models](https://huggingface.co/api/partners/fireworks-ai/models)
+ - [Nebius supported models](https://huggingface.co/api/partners/nebius/models)
+ - [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
+ - [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
+ - [Together supported models](https://huggingface.co/api/partners/together/models)
  - [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
 
- ❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
+ ❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
  This is not an issue for LLMs as everyone converged on the OpenAI API anyways, but can be more tricky for other tasks like "text-to-image" or "automatic-speech-recognition" where there exists no standard API. Let us know if any help is needed or if we can make things easier for you!
 
  👋**Want to add another provider?** Get in touch if you'd like to add support for another Inference provider, and/or request it on https://huggingface.co/spaces/huggingface/HuggingDiscussions/discussions/49
@@ -457,7 +465,7 @@ await hf.zeroShotImageClassification({
    model: 'openai/clip-vit-large-patch14-336',
    inputs: {
      image: await (await fetch('https://placekitten.com/300/300')).blob()
-   },
+   },
    parameters: {
      candidate_labels: ['cat', 'dog']
    }
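For context, a minimal sketch (not part of the diff) of what the new provider values look like from the caller's side, assuming a valid Hugging Face access token; the model ID is illustrative only and must be registered for the chosen provider.

```ts
import { HfInference } from "@huggingface/inference";

// Hypothetical token and model ID, for illustration only.
const hf = new HfInference("hf_xxx");

const out = await hf.chatCompletion({
  // "nebius" and "fireworks-ai" are the provider values added in this release.
  provider: "nebius",
  model: "meta-llama/Llama-3.1-8B-Instruct",
  messages: [{ role: "user", content: "What is the capital of France?" }],
  max_tokens: 128,
});

console.log(out.choices[0]?.message);
```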
package/dist/index.cjs CHANGED
@@ -103,6 +103,9 @@ var HF_ROUTER_URL = "https://router.huggingface.co";
  // src/providers/fal-ai.ts
  var FAL_AI_API_BASE_URL = "https://fal.run";
 
+ // src/providers/nebius.ts
+ var NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
+
  // src/providers/replicate.ts
  var REPLICATE_API_BASE_URL = "https://api.replicate.com";
 
@@ -112,6 +115,9 @@ var SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
  // src/providers/together.ts
  var TOGETHER_API_BASE_URL = "https://api.together.xyz";
 
+ // src/providers/fireworks-ai.ts
+ var FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
+
  // src/lib/isUrl.ts
  function isUrl(modelOrUrl) {
    return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
@@ -119,14 +125,23 @@ function isUrl(modelOrUrl) {
 
  // package.json
  var name = "@huggingface/inference";
- var version = "3.2.0";
+ var version = "3.3.1";
 
  // src/providers/consts.ts
  var HARDCODED_MODEL_ID_MAPPING = {
    /**
     * "HF model ID" => "Model ID on Inference Provider's side"
+    *
+    * Example:
+    * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
     */
-   // "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
+   "fal-ai": {},
+   "fireworks-ai": {},
+   "hf-inference": {},
+   nebius: {},
+   replicate: {},
+   sambanova: {},
+   together: {}
  };
 
  // src/lib/getProviderModelId.ts
@@ -139,8 +154,8 @@ async function getProviderModelId(params, args, options = {}) {
      throw new Error("taskHint must be specified when using a third-party provider");
    }
    const task = options.taskHint === "text-generation" && options.chatCompletion ? "conversational" : options.taskHint;
-   if (HARDCODED_MODEL_ID_MAPPING[params.model]) {
-     return HARDCODED_MODEL_ID_MAPPING[params.model];
+   if (HARDCODED_MODEL_ID_MAPPING[params.provider]?.[params.model]) {
+     return HARDCODED_MODEL_ID_MAPPING[params.provider][params.model];
    }
    let inferenceProviderMapping;
    if (inferenceProviderMappingCache.has(params.model)) {
@@ -180,7 +195,7 @@ async function makeRequestOptions(args, options) {
    const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
    let otherArgs = remainingArgs;
    const provider = maybeProvider ?? "hf-inference";
-   const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion: chatCompletion2 } = options ?? {};
+   const { forceTask, includeCredentials, taskHint, chatCompletion: chatCompletion2 } = options ?? {};
    if (endpointUrl && provider !== "hf-inference") {
      throw new Error(`Cannot use endpointUrl with a third-party provider.`);
    }
@@ -218,17 +233,6 @@ async function makeRequestOptions(args, options) {
    if (!binary) {
      headers["Content-Type"] = "application/json";
    }
-   if (provider === "hf-inference") {
-     if (wait_for_model) {
-       headers["X-Wait-For-Model"] = "true";
-     }
-     if (use_cache === false) {
-       headers["X-Use-Cache"] = "false";
-     }
-     if (dont_load_model) {
-       headers["X-Load-Model"] = "0";
-     }
-   }
    if (provider === "replicate") {
      headers["Prefer"] = "wait";
    }
@@ -247,7 +251,7 @@ async function makeRequestOptions(args, options) {
      method: "POST",
      body: binary ? args.data : JSON.stringify({
        ...otherArgs,
-       ...chatCompletion2 || provider === "together" ? { model } : void 0
+       ...chatCompletion2 || provider === "together" || provider === "nebius" ? { model } : void 0
      }),
      ...credentials ? { credentials } : void 0,
      signal: options?.signal
@@ -264,6 +268,19 @@ function makeUrl(params) {
      const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FAL_AI_API_BASE_URL;
      return `${baseUrl}/${params.model}`;
    }
+   case "nebius": {
+     const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : NEBIUS_API_BASE_URL;
+     if (params.taskHint === "text-to-image") {
+       return `${baseUrl}/v1/images/generations`;
+     }
+     if (params.taskHint === "text-generation") {
+       if (params.chatCompletion) {
+         return `${baseUrl}/v1/chat/completions`;
+       }
+       return `${baseUrl}/v1/completions`;
+     }
+     return baseUrl;
+   }
    case "replicate": {
      const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : REPLICATE_API_BASE_URL;
      if (params.model.includes(":")) {
@@ -291,6 +308,13 @@ function makeUrl(params) {
      }
      return baseUrl;
    }
+   case "fireworks-ai": {
+     const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FIREWORKS_AI_API_BASE_URL;
+     if (params.taskHint === "text-generation" && params.chatCompletion) {
+       return `${baseUrl}/v1/chat/completions`;
+     }
+     return baseUrl;
+   }
    default: {
      const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
      const url = params.forceTask ? `${baseUrl}/pipeline/${params.forceTask}/${params.model}` : `${baseUrl}/models/${params.model}`;
@@ -323,11 +347,8 @@ async function loadTaskInfo() {
  async function request(args, options) {
    const { url, info } = await makeRequestOptions(args, options);
    const response = await (options?.fetch ?? fetch)(url, info);
-   if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
-     return request(args, {
-       ...options,
-       wait_for_model: true
-     });
+   if (options?.retry_on_error !== false && response.status === 503) {
+     return request(args, options);
    }
    if (!response.ok) {
      const contentType = response.headers.get("Content-Type");
@@ -456,11 +477,8 @@ function newMessage() {
  async function* streamingRequest(args, options) {
    const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
    const response = await (options?.fetch ?? fetch)(url, info);
-   if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
-     return yield* streamingRequest(args, {
-       ...options,
-       wait_for_model: true
-     });
+   if (options?.retry_on_error !== false && response.status === 503) {
+     return yield* streamingRequest(args, options);
    }
    if (!response.ok) {
      if (response.headers.get("Content-Type")?.startsWith("application/json")) {
@@ -751,13 +769,27 @@ async function objectDetection(args, options) {
  }
 
  // src/tasks/cv/textToImage.ts
+ function getResponseFormatArg(provider) {
+   switch (provider) {
+     case "fal-ai":
+       return { sync_mode: true };
+     case "nebius":
+       return { response_format: "b64_json" };
+     case "replicate":
+       return void 0;
+     case "together":
+       return { response_format: "base64" };
+     default:
+       return void 0;
+   }
+ }
  async function textToImage(args, options) {
-   const payload = args.provider === "together" || args.provider === "fal-ai" || args.provider === "replicate" ? {
+   const payload = !args.provider || args.provider === "hf-inference" || args.provider === "sambanova" ? args : {
      ...omit(args, ["inputs", "parameters"]),
      ...args.parameters,
-     ...args.provider !== "replicate" ? { response_format: "base64" } : void 0,
+     ...getResponseFormatArg(args.provider),
      prompt: args.inputs
-   } : args;
+   };
    const res = await request(payload, {
      ...options,
      taskHint: "text-to-image"
@@ -1137,8 +1169,8 @@ async function chatCompletion(args, options) {
      taskHint: "text-generation",
      chatCompletion: true
    });
-   const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai does not output a system_fingerprint
-   (res.system_fingerprint === void 0 || typeof res.system_fingerprint === "string") && typeof res?.usage === "object";
+   const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
+   (res.system_fingerprint === void 0 || res.system_fingerprint === null || typeof res.system_fingerprint === "string") && typeof res?.usage === "object";
    if (!isValidOutput) {
      throw new InferenceOutputError("Expected ChatCompletionOutput");
    }
@@ -1269,7 +1301,15 @@ var HfInferenceEndpoint = class {
  };
 
  // src/types.ts
- var INFERENCE_PROVIDERS = ["fal-ai", "replicate", "sambanova", "together", "hf-inference"];
+ var INFERENCE_PROVIDERS = [
+   "fal-ai",
+   "fireworks-ai",
+   "nebius",
+   "hf-inference",
+   "replicate",
+   "sambanova",
+   "together"
+ ];
  // Annotate the CommonJS export names for ESM import in node:
  0 && (module.exports = {
    HfInference,
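The bundle above also drops the hf-inference-specific headers (`X-Wait-For-Model`, `X-Use-Cache`, `X-Load-Model`) that used to be derived from the now-removed options. A minimal sketch, assuming you still want to inject such a header yourself: the `fetch` option (unchanged in this release, documented as a way to proxy or edit headers) can wrap the default fetch. The header name comes from the removed code; whether the serverless API still honors it is not guaranteed by this diff, and the token and model ID are illustrative.

```ts
import { HfInference } from "@huggingface/inference";

const hf = new HfInference("hf_xxx"); // hypothetical token

// The library no longer sets X-Use-Cache itself; if you still need it,
// you can add headers through the custom fetch hook.
const result = await hf.textGeneration(
  { model: "gpt2", inputs: "Hello" }, // illustrative model
  {
    fetch: (url, init) =>
      fetch(url, {
        ...init,
        headers: {
          ...((init?.headers as Record<string, string>) ?? {}),
          "X-Use-Cache": "false",
        },
      }),
  }
);

console.log(result.generated_text);
```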
package/dist/index.js CHANGED
@@ -48,6 +48,9 @@ var HF_ROUTER_URL = "https://router.huggingface.co";
  // src/providers/fal-ai.ts
  var FAL_AI_API_BASE_URL = "https://fal.run";
 
+ // src/providers/nebius.ts
+ var NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
+
  // src/providers/replicate.ts
  var REPLICATE_API_BASE_URL = "https://api.replicate.com";
 
@@ -57,6 +60,9 @@ var SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
  // src/providers/together.ts
  var TOGETHER_API_BASE_URL = "https://api.together.xyz";
 
+ // src/providers/fireworks-ai.ts
+ var FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
+
  // src/lib/isUrl.ts
  function isUrl(modelOrUrl) {
    return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
@@ -64,14 +70,23 @@ function isUrl(modelOrUrl) {
 
  // package.json
  var name = "@huggingface/inference";
- var version = "3.2.0";
+ var version = "3.3.1";
 
  // src/providers/consts.ts
  var HARDCODED_MODEL_ID_MAPPING = {
    /**
     * "HF model ID" => "Model ID on Inference Provider's side"
+    *
+    * Example:
+    * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
     */
-   // "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
+   "fal-ai": {},
+   "fireworks-ai": {},
+   "hf-inference": {},
+   nebius: {},
+   replicate: {},
+   sambanova: {},
+   together: {}
  };
 
  // src/lib/getProviderModelId.ts
@@ -84,8 +99,8 @@ async function getProviderModelId(params, args, options = {}) {
      throw new Error("taskHint must be specified when using a third-party provider");
    }
    const task = options.taskHint === "text-generation" && options.chatCompletion ? "conversational" : options.taskHint;
-   if (HARDCODED_MODEL_ID_MAPPING[params.model]) {
-     return HARDCODED_MODEL_ID_MAPPING[params.model];
+   if (HARDCODED_MODEL_ID_MAPPING[params.provider]?.[params.model]) {
+     return HARDCODED_MODEL_ID_MAPPING[params.provider][params.model];
    }
    let inferenceProviderMapping;
    if (inferenceProviderMappingCache.has(params.model)) {
@@ -125,7 +140,7 @@ async function makeRequestOptions(args, options) {
    const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
    let otherArgs = remainingArgs;
    const provider = maybeProvider ?? "hf-inference";
-   const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion: chatCompletion2 } = options ?? {};
+   const { forceTask, includeCredentials, taskHint, chatCompletion: chatCompletion2 } = options ?? {};
    if (endpointUrl && provider !== "hf-inference") {
      throw new Error(`Cannot use endpointUrl with a third-party provider.`);
    }
@@ -163,17 +178,6 @@ async function makeRequestOptions(args, options) {
    if (!binary) {
      headers["Content-Type"] = "application/json";
    }
-   if (provider === "hf-inference") {
-     if (wait_for_model) {
-       headers["X-Wait-For-Model"] = "true";
-     }
-     if (use_cache === false) {
-       headers["X-Use-Cache"] = "false";
-     }
-     if (dont_load_model) {
-       headers["X-Load-Model"] = "0";
-     }
-   }
    if (provider === "replicate") {
      headers["Prefer"] = "wait";
    }
@@ -192,7 +196,7 @@ async function makeRequestOptions(args, options) {
      method: "POST",
      body: binary ? args.data : JSON.stringify({
        ...otherArgs,
-       ...chatCompletion2 || provider === "together" ? { model } : void 0
+       ...chatCompletion2 || provider === "together" || provider === "nebius" ? { model } : void 0
      }),
      ...credentials ? { credentials } : void 0,
      signal: options?.signal
@@ -209,6 +213,19 @@ function makeUrl(params) {
      const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FAL_AI_API_BASE_URL;
      return `${baseUrl}/${params.model}`;
    }
+   case "nebius": {
+     const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : NEBIUS_API_BASE_URL;
+     if (params.taskHint === "text-to-image") {
+       return `${baseUrl}/v1/images/generations`;
+     }
+     if (params.taskHint === "text-generation") {
+       if (params.chatCompletion) {
+         return `${baseUrl}/v1/chat/completions`;
+       }
+       return `${baseUrl}/v1/completions`;
+     }
+     return baseUrl;
+   }
    case "replicate": {
      const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : REPLICATE_API_BASE_URL;
      if (params.model.includes(":")) {
@@ -236,6 +253,13 @@ function makeUrl(params) {
      }
      return baseUrl;
    }
+   case "fireworks-ai": {
+     const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FIREWORKS_AI_API_BASE_URL;
+     if (params.taskHint === "text-generation" && params.chatCompletion) {
+       return `${baseUrl}/v1/chat/completions`;
+     }
+     return baseUrl;
+   }
    default: {
      const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
      const url = params.forceTask ? `${baseUrl}/pipeline/${params.forceTask}/${params.model}` : `${baseUrl}/models/${params.model}`;
@@ -268,11 +292,8 @@ async function loadTaskInfo() {
  async function request(args, options) {
    const { url, info } = await makeRequestOptions(args, options);
    const response = await (options?.fetch ?? fetch)(url, info);
-   if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
-     return request(args, {
-       ...options,
-       wait_for_model: true
-     });
+   if (options?.retry_on_error !== false && response.status === 503) {
+     return request(args, options);
    }
    if (!response.ok) {
      const contentType = response.headers.get("Content-Type");
@@ -401,11 +422,8 @@ function newMessage() {
  async function* streamingRequest(args, options) {
    const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
    const response = await (options?.fetch ?? fetch)(url, info);
-   if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
-     return yield* streamingRequest(args, {
-       ...options,
-       wait_for_model: true
-     });
+   if (options?.retry_on_error !== false && response.status === 503) {
+     return yield* streamingRequest(args, options);
    }
    if (!response.ok) {
      if (response.headers.get("Content-Type")?.startsWith("application/json")) {
@@ -696,13 +714,27 @@ async function objectDetection(args, options) {
  }
 
  // src/tasks/cv/textToImage.ts
+ function getResponseFormatArg(provider) {
+   switch (provider) {
+     case "fal-ai":
+       return { sync_mode: true };
+     case "nebius":
+       return { response_format: "b64_json" };
+     case "replicate":
+       return void 0;
+     case "together":
+       return { response_format: "base64" };
+     default:
+       return void 0;
+   }
+ }
  async function textToImage(args, options) {
-   const payload = args.provider === "together" || args.provider === "fal-ai" || args.provider === "replicate" ? {
+   const payload = !args.provider || args.provider === "hf-inference" || args.provider === "sambanova" ? args : {
      ...omit(args, ["inputs", "parameters"]),
      ...args.parameters,
-     ...args.provider !== "replicate" ? { response_format: "base64" } : void 0,
+     ...getResponseFormatArg(args.provider),
      prompt: args.inputs
-   } : args;
+   };
    const res = await request(payload, {
      ...options,
      taskHint: "text-to-image"
@@ -1082,8 +1114,8 @@ async function chatCompletion(args, options) {
      taskHint: "text-generation",
      chatCompletion: true
    });
-   const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai does not output a system_fingerprint
-   (res.system_fingerprint === void 0 || typeof res.system_fingerprint === "string") && typeof res?.usage === "object";
+   const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
+   (res.system_fingerprint === void 0 || res.system_fingerprint === null || typeof res.system_fingerprint === "string") && typeof res?.usage === "object";
    if (!isValidOutput) {
      throw new InferenceOutputError("Expected ChatCompletionOutput");
    }
@@ -1214,7 +1246,15 @@ var HfInferenceEndpoint = class {
  };
 
  // src/types.ts
- var INFERENCE_PROVIDERS = ["fal-ai", "replicate", "sambanova", "together", "hf-inference"];
+ var INFERENCE_PROVIDERS = [
+   "fal-ai",
+   "fireworks-ai",
+   "nebius",
+   "hf-inference",
+   "replicate",
+   "sambanova",
+   "together"
+ ];
  export {
    HfInference,
    HfInferenceEndpoint,
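As in the CJS bundle, the serialized request body now carries the resolved model ID for Nebius as well as Together (and whenever `chatCompletion` is set). A rough sketch of the payload shape implied by the code above for a Nebius chat-completion call; the field values are illustrative, not taken from the diff.

```ts
// Shape of the POST body built by makeRequestOptions for provider "nebius"
// with chatCompletion enabled (illustrative values only).
const body = {
  messages: [{ role: "user", content: "Hello" }],
  max_tokens: 128,
  // Injected because chatCompletion is true, or because the provider is
  // "together" or "nebius": the model ID as the provider knows it.
  model: "meta-llama/Llama-3.1-8B-Instruct",
};

// Sent to `${NEBIUS_API_BASE_URL}/v1/chat/completions` when the request is
// not proxied through the Hugging Face router.
console.log(JSON.stringify(body, null, 2));
```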
@@ -1 +1 @@
- {"version":3,"file":"makeRequestOptions.d.ts","sourceRoot":"","sources":["../../../src/lib/makeRequestOptions.ts"],"names":[],"mappings":"AAMA,OAAO,KAAK,EAAE,aAAa,EAAE,OAAO,EAAE,WAAW,EAAE,MAAM,UAAU,CAAC;AAapE;;GAEG;AACH,wBAAsB,kBAAkB,CACvC,IAAI,EAAE,WAAW,GAAG;IACnB,IAAI,CAAC,EAAE,IAAI,GAAG,WAAW,CAAC;IAC1B,MAAM,CAAC,EAAE,OAAO,CAAC;CACjB,EACD,OAAO,CAAC,EAAE,OAAO,GAAG;IACnB,yFAAyF;IACzF,SAAS,CAAC,EAAE,MAAM,GAAG,aAAa,CAAC;IACnC,sCAAsC;IACtC,QAAQ,CAAC,EAAE,aAAa,CAAC;IACzB,cAAc,CAAC,EAAE,OAAO,CAAC;CACzB,GACC,OAAO,CAAC;IAAE,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,WAAW,CAAA;CAAE,CAAC,CAqH7C"}
+ {"version":3,"file":"makeRequestOptions.d.ts","sourceRoot":"","sources":["../../../src/lib/makeRequestOptions.ts"],"names":[],"mappings":"AAQA,OAAO,KAAK,EAAE,aAAa,EAAE,OAAO,EAAE,WAAW,EAAE,MAAM,UAAU,CAAC;AAapE;;GAEG;AACH,wBAAsB,kBAAkB,CACvC,IAAI,EAAE,WAAW,GAAG;IACnB,IAAI,CAAC,EAAE,IAAI,GAAG,WAAW,CAAC;IAC1B,MAAM,CAAC,EAAE,OAAO,CAAC;CACjB,EACD,OAAO,CAAC,EAAE,OAAO,GAAG;IACnB,yFAAyF;IACzF,SAAS,CAAC,EAAE,MAAM,GAAG,aAAa,CAAC;IACnC,sCAAsC;IACtC,QAAQ,CAAC,EAAE,aAAa,CAAC;IACzB,cAAc,CAAC,EAAE,OAAO,CAAC;CACzB,GACC,OAAO,CAAC;IAAE,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,WAAW,CAAA;CAAE,CAAC,CAwG7C"}
@@ -1,10 +1,13 @@
- import type { ModelId } from "../types";
+ import type { InferenceProvider } from "../types";
+ import { type ModelId } from "../types";
  type ProviderId = string;
  /**
   * If you want to try to run inference for a new model locally before it's registered on huggingface.co
   * for a given Inference Provider,
   * you can add it to the following dictionary, for dev purposes.
+  *
+  * We also inject into this dictionary from tests.
   */
- export declare const HARDCODED_MODEL_ID_MAPPING: Record<ModelId, ProviderId>;
+ export declare const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelId, ProviderId>>;
  export {};
  //# sourceMappingURL=consts.d.ts.map
@@ -1 +1 @@
- {"version":3,"file":"consts.d.ts","sourceRoot":"","sources":["../../../src/providers/consts.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,OAAO,EAAE,MAAM,UAAU,CAAC;AAExC,KAAK,UAAU,GAAG,MAAM,CAAC;AAEzB;;;;GAIG;AACH,eAAO,MAAM,0BAA0B,EAAE,MAAM,CAAC,OAAO,EAAE,UAAU,CAKlE,CAAC"}
+ {"version":3,"file":"consts.d.ts","sourceRoot":"","sources":["../../../src/providers/consts.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,iBAAiB,EAAE,MAAM,UAAU,CAAC;AAClD,OAAO,EAAE,KAAK,OAAO,EAAE,MAAM,UAAU,CAAC;AAExC,KAAK,UAAU,GAAG,MAAM,CAAC;AACzB;;;;;;GAMG;AACH,eAAO,MAAM,0BAA0B,EAAE,MAAM,CAAC,iBAAiB,EAAE,MAAM,CAAC,OAAO,EAAE,UAAU,CAAC,CAc7F,CAAC"}
@@ -0,0 +1,18 @@
+ export declare const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
+ /**
+  * See the registered mapping of HF model ID => Fireworks model ID here:
+  *
+  * https://huggingface.co/api/partners/fireworks/models
+  *
+  * This is a publicly available mapping.
+  *
+  * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+  * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+  *
+  * - If you work at Fireworks and want to update this mapping, please use the model mapping API we provide on huggingface.co
+  * - If you're a community member and want to add a new supported HF model to Fireworks, please open an issue on the present repo
+  * and we will tag Fireworks team members.
+  *
+  * Thanks!
+  */
+ //# sourceMappingURL=fireworks-ai.d.ts.map
@@ -0,0 +1 @@
+ {"version":3,"file":"fireworks-ai.d.ts","sourceRoot":"","sources":["../../../src/providers/fireworks-ai.ts"],"names":[],"mappings":"AAAA,eAAO,MAAM,yBAAyB,uCAAuC,CAAC;AAE9E;;;;;;;;;;;;;;;GAeG"}
@@ -0,0 +1,18 @@
+ export declare const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
+ /**
+  * See the registered mapping of HF model ID => Nebius model ID here:
+  *
+  * https://huggingface.co/api/partners/nebius/models
+  *
+  * This is a publicly available mapping.
+  *
+  * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+  * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+  *
+  * - If you work at Nebius and want to update this mapping, please use the model mapping API we provide on huggingface.co
+  * - If you're a community member and want to add a new supported HF model to Nebius, please open an issue on the present repo
+  * and we will tag Nebius team members.
+  *
+  * Thanks!
+  */
+ //# sourceMappingURL=nebius.d.ts.map
@@ -0,0 +1 @@
+ {"version":3,"file":"nebius.d.ts","sourceRoot":"","sources":["../../../src/providers/nebius.ts"],"names":[],"mappings":"AAAA,eAAO,MAAM,mBAAmB,iCAAiC,CAAC;AAElE;;;;;;;;;;;;;;;GAeG"}
@@ -1 +1 @@
- {"version":3,"file":"request.d.ts","sourceRoot":"","sources":["../../../../src/tasks/custom/request.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,aAAa,EAAE,OAAO,EAAE,WAAW,EAAE,MAAM,aAAa,CAAC;AAGvE;;GAEG;AACH,wBAAsB,OAAO,CAAC,CAAC,EAC9B,IAAI,EAAE,WAAW,EACjB,OAAO,CAAC,EAAE,OAAO,GAAG;IACnB,yFAAyF;IACzF,IAAI,CAAC,EAAE,MAAM,GAAG,aAAa,CAAC;IAC9B,sCAAsC;IACtC,QAAQ,CAAC,EAAE,aAAa,CAAC;IACzB,oCAAoC;IACpC,cAAc,CAAC,EAAE,OAAO,CAAC;CACzB,GACC,OAAO,CAAC,CAAC,CAAC,CAmCZ"}
+ {"version":3,"file":"request.d.ts","sourceRoot":"","sources":["../../../../src/tasks/custom/request.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,aAAa,EAAE,OAAO,EAAE,WAAW,EAAE,MAAM,aAAa,CAAC;AAGvE;;GAEG;AACH,wBAAsB,OAAO,CAAC,CAAC,EAC9B,IAAI,EAAE,WAAW,EACjB,OAAO,CAAC,EAAE,OAAO,GAAG;IACnB,yFAAyF;IACzF,IAAI,CAAC,EAAE,MAAM,GAAG,aAAa,CAAC;IAC9B,sCAAsC;IACtC,QAAQ,CAAC,EAAE,aAAa,CAAC;IACzB,oCAAoC;IACpC,cAAc,CAAC,EAAE,OAAO,CAAC;CACzB,GACC,OAAO,CAAC,CAAC,CAAC,CAgCZ"}
@@ -1 +1 @@
- {"version":3,"file":"streamingRequest.d.ts","sourceRoot":"","sources":["../../../../src/tasks/custom/streamingRequest.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,aAAa,EAAE,OAAO,EAAE,WAAW,EAAE,MAAM,aAAa,CAAC;AAKvE;;GAEG;AACH,wBAAuB,gBAAgB,CAAC,CAAC,EACxC,IAAI,EAAE,WAAW,EACjB,OAAO,CAAC,EAAE,OAAO,GAAG;IACnB,yFAAyF;IACzF,IAAI,CAAC,EAAE,MAAM,GAAG,aAAa,CAAC;IAC9B,sCAAsC;IACtC,QAAQ,CAAC,EAAE,aAAa,CAAC;IACzB,oCAAoC;IACpC,cAAc,CAAC,EAAE,OAAO,CAAC;CACzB,GACC,cAAc,CAAC,CAAC,CAAC,CAsFnB"}
+ {"version":3,"file":"streamingRequest.d.ts","sourceRoot":"","sources":["../../../../src/tasks/custom/streamingRequest.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,aAAa,EAAE,OAAO,EAAE,WAAW,EAAE,MAAM,aAAa,CAAC;AAKvE;;GAEG;AACH,wBAAuB,gBAAgB,CAAC,CAAC,EACxC,IAAI,EAAE,WAAW,EACjB,OAAO,CAAC,EAAE,OAAO,GAAG;IACnB,yFAAyF;IACzF,IAAI,CAAC,EAAE,MAAM,GAAG,aAAa,CAAC;IAC9B,sCAAsC;IACtC,QAAQ,CAAC,EAAE,aAAa,CAAC;IACzB,oCAAoC;IACpC,cAAc,CAAC,EAAE,OAAO,CAAC;CACzB,GACC,cAAc,CAAC,CAAC,CAAC,CAmFnB"}
@@ -1 +1 @@
- {"version":3,"file":"textToImage.d.ts","sourceRoot":"","sources":["../../../../src/tasks/cv/textToImage.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,gBAAgB,EAAqB,MAAM,oBAAoB,CAAC;AAE9E,OAAO,KAAK,EAAE,QAAQ,EAAE,OAAO,EAAE,MAAM,aAAa,CAAC;AAIrD,MAAM,MAAM,eAAe,GAAG,QAAQ,GAAG,gBAAgB,CAAC;AAW1D;;;GAGG;AACH,wBAAsB,WAAW,CAAC,IAAI,EAAE,eAAe,EAAE,OAAO,CAAC,EAAE,OAAO,GAAG,OAAO,CAAC,IAAI,CAAC,CAoCzF"}
+ {"version":3,"file":"textToImage.d.ts","sourceRoot":"","sources":["../../../../src/tasks/cv/textToImage.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,gBAAgB,EAAqB,MAAM,oBAAoB,CAAC;AAE9E,OAAO,KAAK,EAAE,QAAQ,EAAqB,OAAO,EAAE,MAAM,aAAa,CAAC;AAIxE,MAAM,MAAM,eAAe,GAAG,QAAQ,GAAG,gBAAgB,CAAC;AA0B1D;;;GAGG;AACH,wBAAsB,WAAW,CAAC,IAAI,EAAE,eAAe,EAAE,OAAO,CAAC,EAAE,OAAO,GAAG,OAAO,CAAC,IAAI,CAAC,CAqCzF"}
@@ -1 +1 @@
- {"version":3,"file":"chatCompletion.d.ts","sourceRoot":"","sources":["../../../../src/tasks/nlp/chatCompletion.ts"],"names":[],"mappings":"AACA,OAAO,KAAK,EAAE,QAAQ,EAAE,OAAO,EAAE,MAAM,aAAa,CAAC;AAErD,OAAO,KAAK,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,MAAM,oBAAoB,CAAC;AAEpF;;GAEG;AACH,wBAAsB,cAAc,CACnC,IAAI,EAAE,QAAQ,GAAG,mBAAmB,EACpC,OAAO,CAAC,EAAE,OAAO,GACf,OAAO,CAAC,oBAAoB,CAAC,CAoB/B"}
+ {"version":3,"file":"chatCompletion.d.ts","sourceRoot":"","sources":["../../../../src/tasks/nlp/chatCompletion.ts"],"names":[],"mappings":"AACA,OAAO,KAAK,EAAE,QAAQ,EAAE,OAAO,EAAE,MAAM,aAAa,CAAC;AAErD,OAAO,KAAK,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,MAAM,oBAAoB,CAAC;AAEpF;;GAEG;AACH,wBAAsB,cAAc,CACnC,IAAI,EAAE,QAAQ,GAAG,mBAAmB,EACpC,OAAO,CAAC,EAAE,OAAO,GACf,OAAO,CAAC,oBAAoB,CAAC,CAuB/B"}
@@ -5,25 +5,9 @@ import type { ChatCompletionInput, PipelineType } from "@huggingface/tasks";
  export type ModelId = string;
  export interface Options {
    /**
-    * (Default: true) Boolean. If a request 503s and wait_for_model is set to false, the request will be retried with the same parameters but with wait_for_model set to true.
+    * (Default: true) Boolean. If a request 503s, the request will be retried with the same parameters.
     */
    retry_on_error?: boolean;
-   /**
-    * (Default: true). Boolean. There is a cache layer on Inference API (serverless) to speedup requests we have already seen. Most models can use those results as is as models are deterministic (meaning the results will be the same anyway). However if you use a non deterministic model, you can set this parameter to prevent the caching mechanism from being used resulting in a real new query.
-    */
-   use_cache?: boolean;
-   /**
-    * (Default: false). Boolean. Do not load the model if it's not already available.
-    */
-   dont_load_model?: boolean;
-   /**
-    * (Default: false). Boolean to use GPU instead of CPU for inference (requires Startup plan at least).
-    */
-   use_gpu?: boolean;
-   /**
-    * (Default: false) Boolean. If the model is not ready, wait for it instead of receiving 503. It limits the number of requests required to get your inference done. It is advised to only set this flag to true after receiving a 503 error as it will limit hanging in your application to known places.
-    */
-   wait_for_model?: boolean;
    /**
     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
     */
@@ -38,7 +22,7 @@ export interface Options {
    includeCredentials?: string | boolean;
  }
  export type InferenceTask = Exclude<PipelineType, "other">;
- export declare const INFERENCE_PROVIDERS: readonly ["fal-ai", "replicate", "sambanova", "together", "hf-inference"];
+ export declare const INFERENCE_PROVIDERS: readonly ["fal-ai", "fireworks-ai", "nebius", "hf-inference", "replicate", "sambanova", "together"];
  export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
  export interface BaseArgs {
    /**
@@ -1 +1 @@
- {"version":3,"file":"types.d.ts","sourceRoot":"","sources":["../../src/types.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,mBAAmB,EAAE,YAAY,EAAE,MAAM,oBAAoB,CAAC;AAE5E;;GAEG;AACH,MAAM,MAAM,OAAO,GAAG,MAAM,CAAC;AAE7B,MAAM,WAAW,OAAO;IACvB;;OAEG;IACH,cAAc,CAAC,EAAE,OAAO,CAAC;IACzB;;OAEG;IACH,SAAS,CAAC,EAAE,OAAO,CAAC;IACpB;;OAEG;IACH,eAAe,CAAC,EAAE,OAAO,CAAC;IAC1B;;OAEG;IACH,OAAO,CAAC,EAAE,OAAO,CAAC;IAElB;;OAEG;IACH,cAAc,CAAC,EAAE,OAAO,CAAC;IACzB;;OAEG;IACH,KAAK,CAAC,EAAE,OAAO,KAAK,CAAC;IACrB;;OAEG;IACH,MAAM,CAAC,EAAE,WAAW,CAAC;IAErB;;OAEG;IACH,kBAAkB,CAAC,EAAE,MAAM,GAAG,OAAO,CAAC;CACtC;AAED,MAAM,MAAM,aAAa,GAAG,OAAO,CAAC,YAAY,EAAE,OAAO,CAAC,CAAC;AAE3D,eAAO,MAAM,mBAAmB,2EAA4E,CAAC;AAC7G,MAAM,MAAM,iBAAiB,GAAG,CAAC,OAAO,mBAAmB,CAAC,CAAC,MAAM,CAAC,CAAC;AAErE,MAAM,WAAW,QAAQ;IACxB;;;;;;OAMG;IACH,WAAW,CAAC,EAAE,MAAM,CAAC;IAErB;;;;;;;OAOG;IACH,KAAK,CAAC,EAAE,OAAO,CAAC;IAEhB;;;;OAIG;IACH,WAAW,CAAC,EAAE,MAAM,CAAC;IAErB;;;;OAIG;IACH,QAAQ,CAAC,EAAE,iBAAiB,CAAC;CAC7B;AAED,MAAM,MAAM,WAAW,GAAG,QAAQ,GACjC,CACG;IAAE,IAAI,EAAE,IAAI,GAAG,WAAW,CAAA;CAAE,GAC5B;IAAE,MAAM,EAAE,OAAO,CAAA;CAAE,GACnB;IAAE,MAAM,EAAE,MAAM,CAAA;CAAE,GAClB;IAAE,IAAI,EAAE,MAAM,CAAA;CAAE,GAChB;IAAE,SAAS,EAAE,MAAM,CAAA;CAAE,GACrB,mBAAmB,CACrB,GAAG;IACH,UAAU,CAAC,EAAE,MAAM,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;IACrC,WAAW,CAAC,EAAE,MAAM,CAAC;CACrB,CAAC"}
+ {"version":3,"file":"types.d.ts","sourceRoot":"","sources":["../../src/types.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,mBAAmB,EAAE,YAAY,EAAE,MAAM,oBAAoB,CAAC;AAE5E;;GAEG;AACH,MAAM,MAAM,OAAO,GAAG,MAAM,CAAC;AAE7B,MAAM,WAAW,OAAO;IACvB;;OAEG;IACH,cAAc,CAAC,EAAE,OAAO,CAAC;IAEzB;;OAEG;IACH,KAAK,CAAC,EAAE,OAAO,KAAK,CAAC;IACrB;;OAEG;IACH,MAAM,CAAC,EAAE,WAAW,CAAC;IAErB;;OAEG;IACH,kBAAkB,CAAC,EAAE,MAAM,GAAG,OAAO,CAAC;CACtC;AAED,MAAM,MAAM,aAAa,GAAG,OAAO,CAAC,YAAY,EAAE,OAAO,CAAC,CAAC;AAE3D,eAAO,MAAM,mBAAmB,qGAQtB,CAAC;AACX,MAAM,MAAM,iBAAiB,GAAG,CAAC,OAAO,mBAAmB,CAAC,CAAC,MAAM,CAAC,CAAC;AAErE,MAAM,WAAW,QAAQ;IACxB;;;;;;OAMG;IACH,WAAW,CAAC,EAAE,MAAM,CAAC;IAErB;;;;;;;OAOG;IACH,KAAK,CAAC,EAAE,OAAO,CAAC;IAEhB;;;;OAIG;IACH,WAAW,CAAC,EAAE,MAAM,CAAC;IAErB;;;;OAIG;IACH,QAAQ,CAAC,EAAE,iBAAiB,CAAC;CAC7B;AAED,MAAM,MAAM,WAAW,GAAG,QAAQ,GACjC,CACG;IAAE,IAAI,EAAE,IAAI,GAAG,WAAW,CAAA;CAAE,GAC5B;IAAE,MAAM,EAAE,OAAO,CAAA;CAAE,GACnB;IAAE,MAAM,EAAE,MAAM,CAAA;CAAE,GAClB;IAAE,IAAI,EAAE,MAAM,CAAA;CAAE,GAChB;IAAE,SAAS,EAAE,MAAM,CAAA;CAAE,GACrB,mBAAmB,CACrB,GAAG;IACH,UAAU,CAAC,EAAE,MAAM,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;IACrC,WAAW,CAAC,EAAE,MAAM,CAAC;CACrB,CAAC"}
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
    "name": "@huggingface/inference",
-   "version": "3.2.0",
+   "version": "3.3.1",
    "packageManager": "pnpm@8.10.5",
    "license": "MIT",
    "author": "Tim Mikeladze <tim.mikeladze@gmail.com>",
@@ -30,8 +30,8 @@ export async function getProviderModelId(
    options.taskHint === "text-generation" && options.chatCompletion ? "conversational" : options.taskHint;
 
  // A dict called HARDCODED_MODEL_ID_MAPPING takes precedence in all cases (useful for dev purposes)
- if (HARDCODED_MODEL_ID_MAPPING[params.model]) {
-   return HARDCODED_MODEL_ID_MAPPING[params.model];
+ if (HARDCODED_MODEL_ID_MAPPING[params.provider]?.[params.model]) {
+   return HARDCODED_MODEL_ID_MAPPING[params.provider][params.model];
  }
 
  let inferenceProviderMapping: InferenceProviderMapping | null;
@@ -1,8 +1,10 @@
  import { HF_HUB_URL, HF_ROUTER_URL } from "../config";
  import { FAL_AI_API_BASE_URL } from "../providers/fal-ai";
+ import { NEBIUS_API_BASE_URL } from "../providers/nebius";
  import { REPLICATE_API_BASE_URL } from "../providers/replicate";
  import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
  import { TOGETHER_API_BASE_URL } from "../providers/together";
+ import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
  import type { InferenceProvider } from "../types";
  import type { InferenceTask, Options, RequestArgs } from "../types";
  import { isUrl } from "./isUrl";
@@ -37,8 +39,7 @@ export async function makeRequestOptions(
  let otherArgs = remainingArgs;
  const provider = maybeProvider ?? "hf-inference";
 
- const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion } =
-   options ?? {};
+ const { forceTask, includeCredentials, taskHint, chatCompletion } = options ?? {};
 
  if (endpointUrl && provider !== "hf-inference") {
    throw new Error(`Cannot use endpointUrl with a third-party provider.`);
@@ -100,18 +101,6 @@ export async function makeRequestOptions(
    headers["Content-Type"] = "application/json";
  }
 
- if (provider === "hf-inference") {
-   if (wait_for_model) {
-     headers["X-Wait-For-Model"] = "true";
-   }
-   if (use_cache === false) {
-     headers["X-Use-Cache"] = "false";
-   }
-   if (dont_load_model) {
-     headers["X-Load-Model"] = "0";
-   }
- }
-
  if (provider === "replicate") {
    headers["Prefer"] = "wait";
  }
@@ -142,7 +131,7 @@ export async function makeRequestOptions(
      ? args.data
      : JSON.stringify({
          ...otherArgs,
-         ...(chatCompletion || provider === "together" ? { model } : undefined),
+         ...(chatCompletion || provider === "together" || provider === "nebius" ? { model } : undefined),
        }),
    ...(credentials ? { credentials } : undefined),
    signal: options?.signal,
@@ -171,6 +160,22 @@ function makeUrl(params: {
        : FAL_AI_API_BASE_URL;
      return `${baseUrl}/${params.model}`;
    }
+   case "nebius": {
+     const baseUrl = shouldProxy
+       ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
+       : NEBIUS_API_BASE_URL;
+
+     if (params.taskHint === "text-to-image") {
+       return `${baseUrl}/v1/images/generations`;
+     }
+     if (params.taskHint === "text-generation") {
+       if (params.chatCompletion) {
+         return `${baseUrl}/v1/chat/completions`;
+       }
+       return `${baseUrl}/v1/completions`;
+     }
+     return baseUrl;
+   }
    case "replicate": {
      const baseUrl = shouldProxy
        ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
@@ -208,6 +213,15 @@ function makeUrl(params: {
    }
      return baseUrl;
    }
+   case "fireworks-ai": {
+     const baseUrl = shouldProxy
+       ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
+       : FIREWORKS_AI_API_BASE_URL;
+     if (params.taskHint === "text-generation" && params.chatCompletion) {
+       return `${baseUrl}/v1/chat/completions`;
+     }
+     return baseUrl;
+   }
    default: {
      const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
      const url = params.forceTask
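To summarize the routing added above, here is a small sketch of which direct URL `makeUrl` produces for the two new providers when requests are not proxied through the HF router. The task-hint values and base URLs come from the diff; the helper itself is illustrative and not part of the package.

```ts
// Illustrative helper mirroring the new makeUrl cases (not exported by the package).
function directUrlFor(
  provider: "nebius" | "fireworks-ai",
  taskHint: string,
  chatCompletion: boolean
): string {
  if (provider === "nebius") {
    const base = "https://api.studio.nebius.ai";
    if (taskHint === "text-to-image") return `${base}/v1/images/generations`;
    if (taskHint === "text-generation") {
      return chatCompletion ? `${base}/v1/chat/completions` : `${base}/v1/completions`;
    }
    return base;
  }
  const base = "https://api.fireworks.ai/inference";
  return taskHint === "text-generation" && chatCompletion ? `${base}/v1/chat/completions` : base;
}

console.log(directUrlFor("nebius", "text-to-image", false)); // https://api.studio.nebius.ai/v1/images/generations
console.log(directUrlFor("fireworks-ai", "text-generation", true)); // https://api.fireworks.ai/inference/v1/chat/completions
```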
@@ -1,15 +1,26 @@
- import type { ModelId } from "../types";
+ import type { InferenceProvider } from "../types";
+ import { type ModelId } from "../types";
 
  type ProviderId = string;
-
  /**
   * If you want to try to run inference for a new model locally before it's registered on huggingface.co
   * for a given Inference Provider,
   * you can add it to the following dictionary, for dev purposes.
+  *
+  * We also inject into this dictionary from tests.
   */
- export const HARDCODED_MODEL_ID_MAPPING: Record<ModelId, ProviderId> = {
+ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelId, ProviderId>> = {
    /**
     * "HF model ID" => "Model ID on Inference Provider's side"
+    *
+    * Example:
+    * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
     */
-   // "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
+   "fal-ai": {},
+   "fireworks-ai": {},
+   "hf-inference": {},
+   nebius: {},
+   replicate: {},
+   sambanova: {},
+   together: {},
  };
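A sketch of how a dev-time override now looks with the provider-scoped shape. The HF/provider model IDs are the example given in the comment above; registering it under "together" here is purely illustrative, and the import path refers to this in-repo module, not the published package.

```ts
import { HARDCODED_MODEL_ID_MAPPING } from "./consts";

// Dev-only: map an HF model ID to the ID the provider expects,
// now scoped per provider instead of one flat dictionary.
HARDCODED_MODEL_ID_MAPPING["together"]["Qwen/Qwen2.5-Coder-32B-Instruct"] =
  "Qwen2.5-Coder-32B-Instruct";
```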
@@ -0,0 +1,18 @@
+ export const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
+
+ /**
+  * See the registered mapping of HF model ID => Fireworks model ID here:
+  *
+  * https://huggingface.co/api/partners/fireworks/models
+  *
+  * This is a publicly available mapping.
+  *
+  * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+  * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+  *
+  * - If you work at Fireworks and want to update this mapping, please use the model mapping API we provide on huggingface.co
+  * - If you're a community member and want to add a new supported HF model to Fireworks, please open an issue on the present repo
+  * and we will tag Fireworks team members.
+  *
+  * Thanks!
+  */
@@ -0,0 +1,18 @@
+ export const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
+
+ /**
+  * See the registered mapping of HF model ID => Nebius model ID here:
+  *
+  * https://huggingface.co/api/partners/nebius/models
+  *
+  * This is a publicly available mapping.
+  *
+  * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+  * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+  *
+  * - If you work at Nebius and want to update this mapping, please use the model mapping API we provide on huggingface.co
+  * - If you're a community member and want to add a new supported HF model to Nebius, please open an issue on the present repo
+  * and we will tag Nebius team members.
+  *
+  * Thanks!
+  */
@@ -18,11 +18,8 @@ export async function request<T>(
  const { url, info } = await makeRequestOptions(args, options);
  const response = await (options?.fetch ?? fetch)(url, info);
 
- if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
-   return request(args, {
-     ...options,
-     wait_for_model: true,
-   });
+ if (options?.retry_on_error !== false && response.status === 503) {
+   return request(args, options);
  }
 
  if (!response.ok) {
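Since the retry can no longer flip `wait_for_model` (the option is gone), a 503 now simply re-issues the identical request; callers that prefer to fail fast can still disable the retry. A minimal sketch, with a hypothetical token and an illustrative model ID.

```ts
import { HfInference } from "@huggingface/inference";

const hf = new HfInference("hf_xxx"); // hypothetical token

// With retry_on_error: false, a 503 from the serverless API surfaces as an
// error instead of triggering the automatic re-request shown above.
const out = await hf.textGeneration(
  { model: "gpt2", inputs: "Once upon a time" }, // illustrative model
  { retry_on_error: false }
);
console.log(out.generated_text);
```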
@@ -20,11 +20,8 @@ export async function* streamingRequest<T>(
  const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
  const response = await (options?.fetch ?? fetch)(url, info);
 
- if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
-   return yield* streamingRequest(args, {
-     ...options,
-     wait_for_model: true,
-   });
+ if (options?.retry_on_error !== false && response.status === 503) {
+   return yield* streamingRequest(args, options);
  }
  if (!response.ok) {
    if (response.headers.get("Content-Type")?.startsWith("application/json")) {
@@ -1,6 +1,6 @@
  import type { TextToImageInput, TextToImageOutput } from "@huggingface/tasks";
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
- import type { BaseArgs, Options } from "../../types";
+ import type { BaseArgs, InferenceProvider, Options } from "../../types";
  import { omit } from "../../utils/omit";
  import { request } from "../custom/request";
 
@@ -15,24 +15,40 @@ interface OutputUrlImageGeneration {
    output: string[];
  }
 
+ function getResponseFormatArg(provider: InferenceProvider) {
+   switch (provider) {
+     case "fal-ai":
+       return { sync_mode: true };
+     case "nebius":
+       return { response_format: "b64_json" };
+     case "replicate":
+       return undefined;
+     case "together":
+       return { response_format: "base64" };
+     default:
+       return undefined;
+   }
+ }
+
  /**
   * This task reads some text input and outputs an image.
   * Recommended model: stabilityai/stable-diffusion-2
   */
  export async function textToImage(args: TextToImageArgs, options?: Options): Promise<Blob> {
    const payload =
-     args.provider === "together" || args.provider === "fal-ai" || args.provider === "replicate"
-       ? {
+     !args.provider || args.provider === "hf-inference" || args.provider === "sambanova"
+       ? args
+       : {
            ...omit(args, ["inputs", "parameters"]),
            ...args.parameters,
-           ...(args.provider !== "replicate" ? { response_format: "base64" } : undefined),
+           ...getResponseFormatArg(args.provider),
            prompt: args.inputs,
-         }
-       : args;
+         };
    const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration>(payload, {
      ...options,
      taskHint: "text-to-image",
    });
+
    if (res && typeof res === "object") {
      if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
        const image = await fetch(res.images[0].url);
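A sketch of the task this provider-specific plumbing enables, assuming a Nebius-supported text-to-image model; the token and model ID are illustrative. The `response_format: "b64_json"` argument and the decoding back to a Blob are handled internally by the code above.

```ts
import { HfInference } from "@huggingface/inference";

const hf = new HfInference("hf_xxx"); // hypothetical token

// getResponseFormatArg adds response_format: "b64_json" for Nebius;
// the caller just receives a Blob.
const image: Blob = await hf.textToImage({
  provider: "nebius",
  model: "black-forest-labs/FLUX.1-schnell", // illustrative model ID
  inputs: "A cat wearing a space suit, digital art",
});
```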
@@ -15,14 +15,17 @@ export async function chatCompletion(
    taskHint: "text-generation",
    chatCompletion: true,
  });
+
  const isValidOutput =
    typeof res === "object" &&
    Array.isArray(res?.choices) &&
    typeof res?.created === "number" &&
    typeof res?.id === "string" &&
    typeof res?.model === "string" &&
-   /// Together.ai does not output a system_fingerprint
-   (res.system_fingerprint === undefined || typeof res.system_fingerprint === "string") &&
+   /// Together.ai and Nebius do not output a system_fingerprint
+   (res.system_fingerprint === undefined ||
+     res.system_fingerprint === null ||
+     typeof res.system_fingerprint === "string") &&
    typeof res?.usage === "object";
 
  if (!isValidOutput) {
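For reference, a condensed sketch of the loosened fingerprint check, mirroring the validation above (this predicate is illustrative and not exported by the package): a missing, null, or string `system_fingerprint` is now all considered valid output.

```ts
// Illustrative predicate mirroring the updated check in chatCompletion.ts.
function hasValidFingerprint(res: { system_fingerprint?: string | null }): boolean {
  return (
    res.system_fingerprint === undefined ||
    res.system_fingerprint === null ||
    typeof res.system_fingerprint === "string"
  );
}

console.log(hasValidFingerprint({})); // true (field omitted entirely)
console.log(hasValidFingerprint({ system_fingerprint: null })); // true (field set to null by the provider)
console.log(hasValidFingerprint({ system_fingerprint: "fp_123" })); // true
```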
package/src/types.ts CHANGED
@@ -7,26 +7,10 @@ export type ModelId = string;
 
  export interface Options {
    /**
-    * (Default: true) Boolean. If a request 503s and wait_for_model is set to false, the request will be retried with the same parameters but with wait_for_model set to true.
+    * (Default: true) Boolean. If a request 503s, the request will be retried with the same parameters.
     */
    retry_on_error?: boolean;
-   /**
-    * (Default: true). Boolean. There is a cache layer on Inference API (serverless) to speedup requests we have already seen. Most models can use those results as is as models are deterministic (meaning the results will be the same anyway). However if you use a non deterministic model, you can set this parameter to prevent the caching mechanism from being used resulting in a real new query.
-    */
-   use_cache?: boolean;
-   /**
-    * (Default: false). Boolean. Do not load the model if it's not already available.
-    */
-   dont_load_model?: boolean;
-   /**
-    * (Default: false). Boolean to use GPU instead of CPU for inference (requires Startup plan at least).
-    */
-   use_gpu?: boolean;
 
-   /**
-    * (Default: false) Boolean. If the model is not ready, wait for it instead of receiving 503. It limits the number of requests required to get your inference done. It is advised to only set this flag to true after receiving a 503 error as it will limit hanging in your application to known places.
-    */
-   wait_for_model?: boolean;
 
    /**
     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
@@ -44,7 +28,15 @@ export interface Options {
 
  export type InferenceTask = Exclude<PipelineType, "other">;
 
- export const INFERENCE_PROVIDERS = ["fal-ai", "replicate", "sambanova", "together", "hf-inference"] as const;
+ export const INFERENCE_PROVIDERS = [
+   "fal-ai",
+   "fireworks-ai",
+   "nebius",
+   "hf-inference",
+   "replicate",
+   "sambanova",
+   "together",
+ ] as const;
 
  export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
 
  export interface BaseArgs {
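A small sketch of what the widened union gives consumers, assuming INFERENCE_PROVIDERS and InferenceProvider are re-exported from the package entry point as declared above; otherwise the same code works against the in-repo types module.

```ts
import { INFERENCE_PROVIDERS, type InferenceProvider } from "@huggingface/inference";

// "fireworks-ai" and "nebius" now type-check as providers.
const provider: InferenceProvider = "fireworks-ai";

// Runtime check against the full list, e.g. for validating user input.
function isProvider(value: string): value is InferenceProvider {
  return (INFERENCE_PROVIDERS as readonly string[]).includes(value);
}

console.log(provider, isProvider("nebius")); // fireworks-ai true
console.log(isProvider("openai")); // false
```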