@huggingface/inference 3.3.3 → 3.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. package/README.md +4 -0
  2. package/dist/index.cjs +109 -49
  3. package/dist/index.js +109 -49
  4. package/dist/src/lib/makeRequestOptions.d.ts +0 -2
  5. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  6. package/dist/src/providers/black-forest-labs.d.ts +18 -0
  7. package/dist/src/providers/black-forest-labs.d.ts.map +1 -0
  8. package/dist/src/providers/consts.d.ts.map +1 -1
  9. package/dist/src/providers/hyperbolic.d.ts +18 -0
  10. package/dist/src/providers/hyperbolic.d.ts.map +1 -0
  11. package/dist/src/providers/novita.d.ts +18 -0
  12. package/dist/src/providers/novita.d.ts.map +1 -0
  13. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  14. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  15. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  16. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  17. package/dist/src/types.d.ts +1 -1
  18. package/dist/src/types.d.ts.map +1 -1
  19. package/dist/src/utils/delay.d.ts +2 -0
  20. package/dist/src/utils/delay.d.ts.map +1 -0
  21. package/dist/test/HfInference.spec.d.ts.map +1 -1
  22. package/package.json +1 -1
  23. package/src/lib/makeRequestOptions.ts +51 -16
  24. package/src/providers/black-forest-labs.ts +18 -0
  25. package/src/providers/consts.ts +3 -0
  26. package/src/providers/hyperbolic.ts +18 -0
  27. package/src/providers/novita.ts +18 -0
  28. package/src/tasks/cv/textToImage.ts +60 -1
  29. package/src/tasks/nlp/featureExtraction.ts +0 -4
  30. package/src/tasks/nlp/sentenceSimilarity.ts +0 -3
  31. package/src/tasks/nlp/textGeneration.ts +31 -0
  32. package/src/types.ts +5 -1
  33. package/src/utils/delay.ts +5 -0
package/README.md CHANGED
@@ -49,10 +49,13 @@ You can send inference requests to third-party providers with the inference clie
49
49
  Currently, we support the following providers:
50
50
  - [Fal.ai](https://fal.ai)
51
51
  - [Fireworks AI](https://fireworks.ai)
52
+ - [Hyperbolic](https://hyperbolic.xyz)
52
53
  - [Nebius](https://studio.nebius.ai)
54
+ - [Novita](https://novita.ai/?utm_source=github_huggingface&utm_medium=github_readme&utm_campaign=link)
53
55
  - [Replicate](https://replicate.com)
54
56
  - [Sambanova](https://sambanova.ai)
55
57
  - [Together](https://together.xyz)
58
+ - [Black Forest Labs](https://blackforestlabs.ai)
56
59
 
57
60
  To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
58
61
  ```ts
@@ -72,6 +75,7 @@ When authenticated with a third-party provider key, the request is made directly
72
75
  Only a subset of models are supported when requesting third-party providers. You can check the list of supported models per pipeline tasks here:
73
76
  - [Fal.ai supported models](https://huggingface.co/api/partners/fal-ai/models)
74
77
  - [Fireworks AI supported models](https://huggingface.co/api/partners/fireworks-ai/models)
78
+ - [Hyperbolic supported models](https://huggingface.co/api/partners/hyperbolic/models)
75
79
  - [Nebius supported models](https://huggingface.co/api/partners/nebius/models)
76
80
  - [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
77
81
  - [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
package/dist/index.cjs CHANGED
@@ -115,9 +115,18 @@ var SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
115
115
  // src/providers/together.ts
116
116
  var TOGETHER_API_BASE_URL = "https://api.together.xyz";
117
117
 
118
+ // src/providers/novita.ts
119
+ var NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
120
+
118
121
  // src/providers/fireworks-ai.ts
119
122
  var FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
120
123
 
124
+ // src/providers/hyperbolic.ts
125
+ var HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
126
+
127
+ // src/providers/black-forest-labs.ts
128
+ var BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
129
+
121
130
  // src/lib/isUrl.ts
122
131
  function isUrl(modelOrUrl) {
123
132
  return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
@@ -125,7 +134,7 @@ function isUrl(modelOrUrl) {
125
134
 
126
135
  // package.json
127
136
  var name = "@huggingface/inference";
128
- var version = "3.3.3";
137
+ var version = "3.3.4";
129
138
 
130
139
  // src/providers/consts.ts
131
140
  var HARDCODED_MODEL_ID_MAPPING = {
@@ -135,13 +144,16 @@ var HARDCODED_MODEL_ID_MAPPING = {
135
144
  * Example:
136
145
  * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
137
146
  */
147
+ "black-forest-labs": {},
138
148
  "fal-ai": {},
139
149
  "fireworks-ai": {},
140
150
  "hf-inference": {},
151
+ hyperbolic: {},
141
152
  nebius: {},
142
153
  replicate: {},
143
154
  sambanova: {},
144
- together: {}
155
+ together: {},
156
+ novita: {}
145
157
  };
146
158
 
147
159
  // src/lib/getProviderModelId.ts
@@ -195,13 +207,10 @@ async function makeRequestOptions(args, options) {
195
207
  const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
196
208
  let otherArgs = remainingArgs;
197
209
  const provider = maybeProvider ?? "hf-inference";
198
- const { forceTask, includeCredentials, taskHint, chatCompletion: chatCompletion2 } = options ?? {};
210
+ const { includeCredentials, taskHint, chatCompletion: chatCompletion2 } = options ?? {};
199
211
  if (endpointUrl && provider !== "hf-inference") {
200
212
  throw new Error(`Cannot use endpointUrl with a third-party provider.`);
201
213
  }
202
- if (forceTask && provider !== "hf-inference") {
203
- throw new Error(`Cannot use forceTask with a third-party provider.`);
204
- }
205
214
  if (maybeModel && isUrl(maybeModel)) {
206
215
  throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
207
216
  }
@@ -218,14 +227,19 @@ async function makeRequestOptions(args, options) {
218
227
  const url = endpointUrl ? chatCompletion2 ? endpointUrl + `/v1/chat/completions` : endpointUrl : makeUrl({
219
228
  authMethod,
220
229
  chatCompletion: chatCompletion2 ?? false,
221
- forceTask,
222
230
  model,
223
231
  provider: provider ?? "hf-inference",
224
232
  taskHint
225
233
  });
226
234
  const headers = {};
227
235
  if (accessToken) {
228
- headers["Authorization"] = provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
236
+ if (provider === "fal-ai" && authMethod === "provider-key") {
237
+ headers["Authorization"] = `Key ${accessToken}`;
238
+ } else if (provider === "black-forest-labs" && authMethod === "provider-key") {
239
+ headers["X-Key"] = accessToken;
240
+ } else {
241
+ headers["Authorization"] = `Bearer ${accessToken}`;
242
+ }
229
243
  }
230
244
  const ownUserAgent = `${name}/${version}`;
231
245
  headers["User-Agent"] = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : void 0].filter((x) => x !== void 0).join(" ");
@@ -251,7 +265,7 @@ async function makeRequestOptions(args, options) {
251
265
  method: "POST",
252
266
  body: binary ? args.data : JSON.stringify({
253
267
  ...otherArgs,
254
- ...chatCompletion2 || provider === "together" || provider === "nebius" ? { model } : void 0
268
+ ...taskHint === "text-to-image" && provider === "hyperbolic" ? { model_name: model } : chatCompletion2 || provider === "together" || provider === "nebius" || provider === "hyperbolic" ? { model } : void 0
255
269
  }),
256
270
  ...credentials ? { credentials } : void 0,
257
271
  signal: options?.signal
@@ -264,6 +278,10 @@ function makeUrl(params) {
264
278
  }
265
279
  const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
266
280
  switch (params.provider) {
281
+ case "black-forest-labs": {
282
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : BLACKFORESTLABS_AI_API_BASE_URL;
283
+ return `${baseUrl}/${params.model}`;
284
+ }
267
285
  case "fal-ai": {
268
286
  const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FAL_AI_API_BASE_URL;
269
287
  return `${baseUrl}/${params.model}`;
@@ -315,13 +333,32 @@ function makeUrl(params) {
315
333
  }
316
334
  return baseUrl;
317
335
  }
336
+ case "hyperbolic": {
337
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : HYPERBOLIC_API_BASE_URL;
338
+ if (params.taskHint === "text-to-image") {
339
+ return `${baseUrl}/v1/images/generations`;
340
+ }
341
+ return `${baseUrl}/v1/chat/completions`;
342
+ }
343
+ case "novita": {
344
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : NOVITA_API_BASE_URL;
345
+ if (params.taskHint === "text-generation") {
346
+ if (params.chatCompletion) {
347
+ return `${baseUrl}/chat/completions`;
348
+ }
349
+ return `${baseUrl}/completions`;
350
+ }
351
+ return baseUrl;
352
+ }
318
353
  default: {
319
354
  const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
320
- const url = params.forceTask ? `${baseUrl}/pipeline/${params.forceTask}/${params.model}` : `${baseUrl}/models/${params.model}`;
355
+ if (params.taskHint && ["feature-extraction", "sentence-similarity"].includes(params.taskHint)) {
356
+ return `${baseUrl}/pipeline/${params.taskHint}/${params.model}`;
357
+ }
321
358
  if (params.taskHint === "text-generation" && params.chatCompletion) {
322
- return url + `/v1/chat/completions`;
359
+ return `${baseUrl}/models/${params.model}/v1/chat/completions`;
323
360
  }
324
- return url;
361
+ return `${baseUrl}/models/${params.model}`;
325
362
  }
326
363
  }
327
364
  }
@@ -768,6 +805,13 @@ async function objectDetection(args, options) {
768
805
  return res;
769
806
  }
770
807
 
808
+ // src/utils/delay.ts
809
+ function delay(ms) {
810
+ return new Promise((resolve) => {
811
+ setTimeout(() => resolve(), ms);
812
+ });
813
+ }
814
+
771
815
  // src/tasks/cv/textToImage.ts
772
816
  function getResponseFormatArg(provider) {
773
817
  switch (provider) {
@@ -795,10 +839,18 @@ async function textToImage(args, options) {
795
839
  taskHint: "text-to-image"
796
840
  });
797
841
  if (res && typeof res === "object") {
842
+ if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") {
843
+ return await pollBflResponse(res.polling_url);
844
+ }
798
845
  if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
799
846
  const image = await fetch(res.images[0].url);
800
847
  return await image.blob();
801
848
  }
849
+ if (args.provider === "hyperbolic" && "images" in res && Array.isArray(res.images) && res.images[0] && typeof res.images[0].image === "string") {
850
+ const base64Response = await fetch(`data:image/jpeg;base64,${res.images[0].image}`);
851
+ const blob = await base64Response.blob();
852
+ return blob;
853
+ }
802
854
  if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
803
855
  const base64Data = res.data[0].b64_json;
804
856
  const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
@@ -817,6 +869,24 @@ async function textToImage(args, options) {
817
869
  }
818
870
  return res;
819
871
  }
872
+ async function pollBflResponse(url) {
873
+ const urlObj = new URL(url);
874
+ for (let step = 0; step < 5; step++) {
875
+ await delay(1e3);
876
+ console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
877
+ urlObj.searchParams.set("attempt", step.toString(10));
878
+ const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
879
+ if (!resp.ok) {
880
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
881
+ }
882
+ const payload = await resp.json();
883
+ if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
884
+ const image = await fetch(payload.result.sample);
885
+ return await image.blob();
886
+ }
887
+ }
888
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
889
+ }
820
890
 
821
891
  // src/tasks/cv/imageToImage.ts
822
892
  async function imageToImage(args, options) {
@@ -911,43 +981,11 @@ async function textToVideo(args, options) {
911
981
  }
912
982
  }
913
983
 
914
- // src/lib/getDefaultTask.ts
915
- var taskCache = /* @__PURE__ */ new Map();
916
- var CACHE_DURATION = 10 * 60 * 1e3;
917
- var MAX_CACHE_ITEMS = 1e3;
918
- async function getDefaultTask(model, accessToken, options) {
919
- if (isUrl(model)) {
920
- return null;
921
- }
922
- const key = `${model}:${accessToken}`;
923
- let cachedTask = taskCache.get(key);
924
- if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
925
- taskCache.delete(key);
926
- cachedTask = void 0;
927
- }
928
- if (cachedTask === void 0) {
929
- const modelTask = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
930
- headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
931
- }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
932
- if (!modelTask) {
933
- return null;
934
- }
935
- cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
936
- taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
937
- if (taskCache.size > MAX_CACHE_ITEMS) {
938
- taskCache.delete(taskCache.keys().next().value);
939
- }
940
- }
941
- return cachedTask.task;
942
- }
943
-
944
984
  // src/tasks/nlp/featureExtraction.ts
945
985
  async function featureExtraction(args, options) {
946
- const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : void 0;
947
986
  const res = await request(args, {
948
987
  ...options,
949
- taskHint: "feature-extraction",
950
- ...defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" }
988
+ taskHint: "feature-extraction"
951
989
  });
952
990
  let isValidOutput = true;
953
991
  const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
@@ -1000,11 +1038,9 @@ async function questionAnswering(args, options) {
1000
1038
 
1001
1039
  // src/tasks/nlp/sentenceSimilarity.ts
1002
1040
  async function sentenceSimilarity(args, options) {
1003
- const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : void 0;
1004
1041
  const res = await request(prepareInput(args), {
1005
1042
  ...options,
1006
- taskHint: "sentence-similarity",
1007
- ...defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }
1043
+ taskHint: "sentence-similarity"
1008
1044
  });
1009
1045
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
1010
1046
  if (!isValidOutput) {
@@ -1090,6 +1126,27 @@ async function textGeneration(args, options) {
1090
1126
  return {
1091
1127
  generated_text: completion.text
1092
1128
  };
1129
+ } else if (args.provider === "hyperbolic") {
1130
+ const payload = {
1131
+ messages: [{ content: args.inputs, role: "user" }],
1132
+ ...args.parameters ? {
1133
+ max_tokens: args.parameters.max_new_tokens,
1134
+ ...omit(args.parameters, "max_new_tokens")
1135
+ } : void 0,
1136
+ ...omit(args, ["inputs", "parameters"])
1137
+ };
1138
+ const raw = await request(payload, {
1139
+ ...options,
1140
+ taskHint: "text-generation"
1141
+ });
1142
+ const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
1143
+ if (!isValidOutput) {
1144
+ throw new InferenceOutputError("Expected ChatCompletionOutput");
1145
+ }
1146
+ const completion = raw.choices[0];
1147
+ return {
1148
+ generated_text: completion.message.content
1149
+ };
1093
1150
  } else {
1094
1151
  const res = toArray(
1095
1152
  await request(args, {
@@ -1302,10 +1359,13 @@ var HfInferenceEndpoint = class {
1302
1359
 
1303
1360
  // src/types.ts
1304
1361
  var INFERENCE_PROVIDERS = [
1362
+ "black-forest-labs",
1305
1363
  "fal-ai",
1306
1364
  "fireworks-ai",
1307
- "nebius",
1308
1365
  "hf-inference",
1366
+ "hyperbolic",
1367
+ "nebius",
1368
+ "novita",
1309
1369
  "replicate",
1310
1370
  "sambanova",
1311
1371
  "together"
package/dist/index.js CHANGED
@@ -60,9 +60,18 @@ var SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
60
60
  // src/providers/together.ts
61
61
  var TOGETHER_API_BASE_URL = "https://api.together.xyz";
62
62
 
63
+ // src/providers/novita.ts
64
+ var NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
65
+
63
66
  // src/providers/fireworks-ai.ts
64
67
  var FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
65
68
 
69
+ // src/providers/hyperbolic.ts
70
+ var HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
71
+
72
+ // src/providers/black-forest-labs.ts
73
+ var BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
74
+
66
75
  // src/lib/isUrl.ts
67
76
  function isUrl(modelOrUrl) {
68
77
  return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
@@ -70,7 +79,7 @@ function isUrl(modelOrUrl) {
70
79
 
71
80
  // package.json
72
81
  var name = "@huggingface/inference";
73
- var version = "3.3.3";
82
+ var version = "3.3.4";
74
83
 
75
84
  // src/providers/consts.ts
76
85
  var HARDCODED_MODEL_ID_MAPPING = {
@@ -80,13 +89,16 @@ var HARDCODED_MODEL_ID_MAPPING = {
80
89
  * Example:
81
90
  * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
82
91
  */
92
+ "black-forest-labs": {},
83
93
  "fal-ai": {},
84
94
  "fireworks-ai": {},
85
95
  "hf-inference": {},
96
+ hyperbolic: {},
86
97
  nebius: {},
87
98
  replicate: {},
88
99
  sambanova: {},
89
- together: {}
100
+ together: {},
101
+ novita: {}
90
102
  };
91
103
 
92
104
  // src/lib/getProviderModelId.ts
@@ -140,13 +152,10 @@ async function makeRequestOptions(args, options) {
140
152
  const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
141
153
  let otherArgs = remainingArgs;
142
154
  const provider = maybeProvider ?? "hf-inference";
143
- const { forceTask, includeCredentials, taskHint, chatCompletion: chatCompletion2 } = options ?? {};
155
+ const { includeCredentials, taskHint, chatCompletion: chatCompletion2 } = options ?? {};
144
156
  if (endpointUrl && provider !== "hf-inference") {
145
157
  throw new Error(`Cannot use endpointUrl with a third-party provider.`);
146
158
  }
147
- if (forceTask && provider !== "hf-inference") {
148
- throw new Error(`Cannot use forceTask with a third-party provider.`);
149
- }
150
159
  if (maybeModel && isUrl(maybeModel)) {
151
160
  throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
152
161
  }
@@ -163,14 +172,19 @@ async function makeRequestOptions(args, options) {
163
172
  const url = endpointUrl ? chatCompletion2 ? endpointUrl + `/v1/chat/completions` : endpointUrl : makeUrl({
164
173
  authMethod,
165
174
  chatCompletion: chatCompletion2 ?? false,
166
- forceTask,
167
175
  model,
168
176
  provider: provider ?? "hf-inference",
169
177
  taskHint
170
178
  });
171
179
  const headers = {};
172
180
  if (accessToken) {
173
- headers["Authorization"] = provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
181
+ if (provider === "fal-ai" && authMethod === "provider-key") {
182
+ headers["Authorization"] = `Key ${accessToken}`;
183
+ } else if (provider === "black-forest-labs" && authMethod === "provider-key") {
184
+ headers["X-Key"] = accessToken;
185
+ } else {
186
+ headers["Authorization"] = `Bearer ${accessToken}`;
187
+ }
174
188
  }
175
189
  const ownUserAgent = `${name}/${version}`;
176
190
  headers["User-Agent"] = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : void 0].filter((x) => x !== void 0).join(" ");
@@ -196,7 +210,7 @@ async function makeRequestOptions(args, options) {
196
210
  method: "POST",
197
211
  body: binary ? args.data : JSON.stringify({
198
212
  ...otherArgs,
199
- ...chatCompletion2 || provider === "together" || provider === "nebius" ? { model } : void 0
213
+ ...taskHint === "text-to-image" && provider === "hyperbolic" ? { model_name: model } : chatCompletion2 || provider === "together" || provider === "nebius" || provider === "hyperbolic" ? { model } : void 0
200
214
  }),
201
215
  ...credentials ? { credentials } : void 0,
202
216
  signal: options?.signal
@@ -209,6 +223,10 @@ function makeUrl(params) {
209
223
  }
210
224
  const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
211
225
  switch (params.provider) {
226
+ case "black-forest-labs": {
227
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : BLACKFORESTLABS_AI_API_BASE_URL;
228
+ return `${baseUrl}/${params.model}`;
229
+ }
212
230
  case "fal-ai": {
213
231
  const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FAL_AI_API_BASE_URL;
214
232
  return `${baseUrl}/${params.model}`;
@@ -260,13 +278,32 @@ function makeUrl(params) {
260
278
  }
261
279
  return baseUrl;
262
280
  }
281
+ case "hyperbolic": {
282
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : HYPERBOLIC_API_BASE_URL;
283
+ if (params.taskHint === "text-to-image") {
284
+ return `${baseUrl}/v1/images/generations`;
285
+ }
286
+ return `${baseUrl}/v1/chat/completions`;
287
+ }
288
+ case "novita": {
289
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : NOVITA_API_BASE_URL;
290
+ if (params.taskHint === "text-generation") {
291
+ if (params.chatCompletion) {
292
+ return `${baseUrl}/chat/completions`;
293
+ }
294
+ return `${baseUrl}/completions`;
295
+ }
296
+ return baseUrl;
297
+ }
263
298
  default: {
264
299
  const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
265
- const url = params.forceTask ? `${baseUrl}/pipeline/${params.forceTask}/${params.model}` : `${baseUrl}/models/${params.model}`;
300
+ if (params.taskHint && ["feature-extraction", "sentence-similarity"].includes(params.taskHint)) {
301
+ return `${baseUrl}/pipeline/${params.taskHint}/${params.model}`;
302
+ }
266
303
  if (params.taskHint === "text-generation" && params.chatCompletion) {
267
- return url + `/v1/chat/completions`;
304
+ return `${baseUrl}/models/${params.model}/v1/chat/completions`;
268
305
  }
269
- return url;
306
+ return `${baseUrl}/models/${params.model}`;
270
307
  }
271
308
  }
272
309
  }
@@ -713,6 +750,13 @@ async function objectDetection(args, options) {
713
750
  return res;
714
751
  }
715
752
 
753
+ // src/utils/delay.ts
754
+ function delay(ms) {
755
+ return new Promise((resolve) => {
756
+ setTimeout(() => resolve(), ms);
757
+ });
758
+ }
759
+
716
760
  // src/tasks/cv/textToImage.ts
717
761
  function getResponseFormatArg(provider) {
718
762
  switch (provider) {
@@ -740,10 +784,18 @@ async function textToImage(args, options) {
740
784
  taskHint: "text-to-image"
741
785
  });
742
786
  if (res && typeof res === "object") {
787
+ if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") {
788
+ return await pollBflResponse(res.polling_url);
789
+ }
743
790
  if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
744
791
  const image = await fetch(res.images[0].url);
745
792
  return await image.blob();
746
793
  }
794
+ if (args.provider === "hyperbolic" && "images" in res && Array.isArray(res.images) && res.images[0] && typeof res.images[0].image === "string") {
795
+ const base64Response = await fetch(`data:image/jpeg;base64,${res.images[0].image}`);
796
+ const blob = await base64Response.blob();
797
+ return blob;
798
+ }
747
799
  if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
748
800
  const base64Data = res.data[0].b64_json;
749
801
  const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
@@ -762,6 +814,24 @@ async function textToImage(args, options) {
762
814
  }
763
815
  return res;
764
816
  }
817
+ async function pollBflResponse(url) {
818
+ const urlObj = new URL(url);
819
+ for (let step = 0; step < 5; step++) {
820
+ await delay(1e3);
821
+ console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
822
+ urlObj.searchParams.set("attempt", step.toString(10));
823
+ const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
824
+ if (!resp.ok) {
825
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
826
+ }
827
+ const payload = await resp.json();
828
+ if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
829
+ const image = await fetch(payload.result.sample);
830
+ return await image.blob();
831
+ }
832
+ }
833
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
834
+ }
765
835
 
766
836
  // src/tasks/cv/imageToImage.ts
767
837
  async function imageToImage(args, options) {
@@ -856,43 +926,11 @@ async function textToVideo(args, options) {
856
926
  }
857
927
  }
858
928
 
859
- // src/lib/getDefaultTask.ts
860
- var taskCache = /* @__PURE__ */ new Map();
861
- var CACHE_DURATION = 10 * 60 * 1e3;
862
- var MAX_CACHE_ITEMS = 1e3;
863
- async function getDefaultTask(model, accessToken, options) {
864
- if (isUrl(model)) {
865
- return null;
866
- }
867
- const key = `${model}:${accessToken}`;
868
- let cachedTask = taskCache.get(key);
869
- if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
870
- taskCache.delete(key);
871
- cachedTask = void 0;
872
- }
873
- if (cachedTask === void 0) {
874
- const modelTask = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
875
- headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
876
- }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
877
- if (!modelTask) {
878
- return null;
879
- }
880
- cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
881
- taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
882
- if (taskCache.size > MAX_CACHE_ITEMS) {
883
- taskCache.delete(taskCache.keys().next().value);
884
- }
885
- }
886
- return cachedTask.task;
887
- }
888
-
889
929
  // src/tasks/nlp/featureExtraction.ts
890
930
  async function featureExtraction(args, options) {
891
- const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : void 0;
892
931
  const res = await request(args, {
893
932
  ...options,
894
- taskHint: "feature-extraction",
895
- ...defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" }
933
+ taskHint: "feature-extraction"
896
934
  });
897
935
  let isValidOutput = true;
898
936
  const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
@@ -945,11 +983,9 @@ async function questionAnswering(args, options) {
945
983
 
946
984
  // src/tasks/nlp/sentenceSimilarity.ts
947
985
  async function sentenceSimilarity(args, options) {
948
- const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : void 0;
949
986
  const res = await request(prepareInput(args), {
950
987
  ...options,
951
- taskHint: "sentence-similarity",
952
- ...defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }
988
+ taskHint: "sentence-similarity"
953
989
  });
954
990
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
955
991
  if (!isValidOutput) {
@@ -1035,6 +1071,27 @@ async function textGeneration(args, options) {
1035
1071
  return {
1036
1072
  generated_text: completion.text
1037
1073
  };
1074
+ } else if (args.provider === "hyperbolic") {
1075
+ const payload = {
1076
+ messages: [{ content: args.inputs, role: "user" }],
1077
+ ...args.parameters ? {
1078
+ max_tokens: args.parameters.max_new_tokens,
1079
+ ...omit(args.parameters, "max_new_tokens")
1080
+ } : void 0,
1081
+ ...omit(args, ["inputs", "parameters"])
1082
+ };
1083
+ const raw = await request(payload, {
1084
+ ...options,
1085
+ taskHint: "text-generation"
1086
+ });
1087
+ const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
1088
+ if (!isValidOutput) {
1089
+ throw new InferenceOutputError("Expected ChatCompletionOutput");
1090
+ }
1091
+ const completion = raw.choices[0];
1092
+ return {
1093
+ generated_text: completion.message.content
1094
+ };
1038
1095
  } else {
1039
1096
  const res = toArray(
1040
1097
  await request(args, {
@@ -1247,10 +1304,13 @@ var HfInferenceEndpoint = class {
1247
1304
 
1248
1305
  // src/types.ts
1249
1306
  var INFERENCE_PROVIDERS = [
1307
+ "black-forest-labs",
1250
1308
  "fal-ai",
1251
1309
  "fireworks-ai",
1252
- "nebius",
1253
1310
  "hf-inference",
1311
+ "hyperbolic",
1312
+ "nebius",
1313
+ "novita",
1254
1314
  "replicate",
1255
1315
  "sambanova",
1256
1316
  "together"
@@ -6,8 +6,6 @@ export declare function makeRequestOptions(args: RequestArgs & {
6
6
  data?: Blob | ArrayBuffer;
7
7
  stream?: boolean;
8
8
  }, options?: Options & {
9
- /** When a model can be used for multiple tasks, and we want to run a non-default task */
10
- forceTask?: string | InferenceTask;
11
9
  /** To load default model if needed */
12
10
  taskHint?: InferenceTask;
13
11
  chatCompletion?: boolean;
@@ -1 +1 @@
1
- {"version":3,"file":"makeRequestOptions.d.ts","sourceRoot":"","sources":["../../../src/lib/makeRequestOptions.ts"],"names":[],"mappings":"AAQA,OAAO,KAAK,EAAE,aAAa,EAAE,OAAO,EAAE,WAAW,EAAE,MAAM,UAAU,CAAC;AAapE;;GAEG;AACH,wBAAsB,kBAAkB,CACvC,IAAI,EAAE,WAAW,GAAG;IACnB,IAAI,CAAC,EAAE,IAAI,GAAG,WAAW,CAAC;IAC1B,MAAM,CAAC,EAAE,OAAO,CAAC;CACjB,EACD,OAAO,CAAC,EAAE,OAAO,GAAG;IACnB,yFAAyF;IACzF,SAAS,CAAC,EAAE,MAAM,GAAG,aAAa,CAAC;IACnC,sCAAsC;IACtC,QAAQ,CAAC,EAAE,aAAa,CAAC;IACzB,cAAc,CAAC,EAAE,OAAO,CAAC;CACzB,GACC,OAAO,CAAC;IAAE,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,WAAW,CAAA;CAAE,CAAC,CAwG7C"}
1
+ {"version":3,"file":"makeRequestOptions.d.ts","sourceRoot":"","sources":["../../../src/lib/makeRequestOptions.ts"],"names":[],"mappings":"AAWA,OAAO,KAAK,EAAE,aAAa,EAAE,OAAO,EAAE,WAAW,EAAE,MAAM,UAAU,CAAC;AAapE;;GAEG;AACH,wBAAsB,kBAAkB,CACvC,IAAI,EAAE,WAAW,GAAG;IACnB,IAAI,CAAC,EAAE,IAAI,GAAG,WAAW,CAAC;IAC1B,MAAM,CAAC,EAAE,OAAO,CAAC;CACjB,EACD,OAAO,CAAC,EAAE,OAAO,GAAG;IACnB,sCAAsC;IACtC,QAAQ,CAAC,EAAE,aAAa,CAAC;IACzB,cAAc,CAAC,EAAE,OAAO,CAAC;CACzB,GACC,OAAO,CAAC;IAAE,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,WAAW,CAAA;CAAE,CAAC,CA6G7C"}
@@ -0,0 +1,18 @@
1
+ export declare const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
2
+ /**
3
+ * See the registered mapping of HF model ID => Black Forest Labs model ID here:
4
+ *
5
+ * https://huggingface.co/api/partners/blackforestlabs/models
6
+ *
7
+ * This is a publicly available mapping.
8
+ *
9
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
10
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
11
+ *
12
+ * - If you work at Black Forest Labs and want to update this mapping, please use the model mapping API we provide on huggingface.co
13
+ * - If you're a community member and want to add a new supported HF model to Black Forest Labs, please open an issue on the present repo
14
+ * and we will tag Black Forest Labs team members.
15
+ *
16
+ * Thanks!
17
+ */
18
+ //# sourceMappingURL=black-forest-labs.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"black-forest-labs.d.ts","sourceRoot":"","sources":["../../../src/providers/black-forest-labs.ts"],"names":[],"mappings":"AAAA,eAAO,MAAM,+BAA+B,8BAA8B,CAAC;AAE3E;;;;;;;;;;;;;;;GAeG"}
@@ -1 +1 @@
1
- {"version":3,"file":"consts.d.ts","sourceRoot":"","sources":["../../../src/providers/consts.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,iBAAiB,EAAE,MAAM,UAAU,CAAC;AAClD,OAAO,EAAE,KAAK,OAAO,EAAE,MAAM,UAAU,CAAC;AAExC,KAAK,UAAU,GAAG,MAAM,CAAC;AACzB;;;;;;GAMG;AACH,eAAO,MAAM,0BAA0B,EAAE,MAAM,CAAC,iBAAiB,EAAE,MAAM,CAAC,OAAO,EAAE,UAAU,CAAC,CAc7F,CAAC"}
1
+ {"version":3,"file":"consts.d.ts","sourceRoot":"","sources":["../../../src/providers/consts.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,iBAAiB,EAAE,MAAM,UAAU,CAAC;AAClD,OAAO,EAAE,KAAK,OAAO,EAAE,MAAM,UAAU,CAAC;AAExC,KAAK,UAAU,GAAG,MAAM,CAAC;AACzB;;;;;;GAMG;AACH,eAAO,MAAM,0BAA0B,EAAE,MAAM,CAAC,iBAAiB,EAAE,MAAM,CAAC,OAAO,EAAE,UAAU,CAAC,CAiB7F,CAAC"}
@@ -0,0 +1,18 @@
1
+ export declare const HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
2
+ /**
3
+ * See the registered mapping of HF model ID => Hyperbolic model ID here:
4
+ *
5
+ * https://huggingface.co/api/partners/hyperbolic/models
6
+ *
7
+ * This is a publicly available mapping.
8
+ *
9
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
10
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
11
+ *
12
+ * - If you work at Hyperbolic and want to update this mapping, please use the model mapping API we provide on huggingface.co
13
+ * - If you're a community member and want to add a new supported HF model to Hyperbolic, please open an issue on the present repo
14
+ * and we will tag Hyperbolic team members.
15
+ *
16
+ * Thanks!
17
+ */
18
+ //# sourceMappingURL=hyperbolic.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"hyperbolic.d.ts","sourceRoot":"","sources":["../../../src/providers/hyperbolic.ts"],"names":[],"mappings":"AAAA,eAAO,MAAM,uBAAuB,+BAA+B,CAAC;AAEpE;;;;;;;;;;;;;;;GAeG"}
@@ -0,0 +1,18 @@
1
+ export declare const NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
2
+ /**
3
+ * See the registered mapping of HF model ID => Novita model ID here:
4
+ *
5
+ * https://huggingface.co/api/partners/novita/models
6
+ *
7
+ * This is a publicly available mapping.
8
+ *
9
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
10
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
11
+ *
12
+ * - If you work at Novita and want to update this mapping, please use the model mapping API we provide on huggingface.co
13
+ * - If you're a community member and want to add a new supported HF model to Novita, please open an issue on the present repo
14
+ * and we will tag Novita team members.
15
+ *
16
+ * Thanks!
17
+ */
18
+ //# sourceMappingURL=novita.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"novita.d.ts","sourceRoot":"","sources":["../../../src/providers/novita.ts"],"names":[],"mappings":"AAAA,eAAO,MAAM,mBAAmB,oCAAoC,CAAC;AAErE;;;;;;;;;;;;;;;GAeG"}
@@ -1 +1 @@
1
- {"version":3,"file":"textToImage.d.ts","sourceRoot":"","sources":["../../../../src/tasks/cv/textToImage.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,gBAAgB,EAAqB,MAAM,oBAAoB,CAAC;AAE9E,OAAO,KAAK,EAAE,QAAQ,EAAqB,OAAO,EAAE,MAAM,aAAa,CAAC;AAIxE,MAAM,MAAM,eAAe,GAAG,QAAQ,GAAG,gBAAgB,CAAC;AA0B1D;;;GAGG;AACH,wBAAsB,WAAW,CAAC,IAAI,EAAE,eAAe,EAAE,OAAO,CAAC,EAAE,OAAO,GAAG,OAAO,CAAC,IAAI,CAAC,CAqCzF"}
1
+ {"version":3,"file":"textToImage.d.ts","sourceRoot":"","sources":["../../../../src/tasks/cv/textToImage.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,gBAAgB,EAAqB,MAAM,oBAAoB,CAAC;AAE9E,OAAO,KAAK,EAAE,QAAQ,EAAqB,OAAO,EAAE,MAAM,aAAa,CAAC;AAKxE,MAAM,MAAM,eAAe,GAAG,QAAQ,GAAG,gBAAgB,CAAC;AAkC1D;;;GAGG;AACH,wBAAsB,WAAW,CAAC,IAAI,EAAE,eAAe,EAAE,OAAO,CAAC,EAAE,OAAO,GAAG,OAAO,CAAC,IAAI,CAAC,CAyDzF"}
@@ -1 +1 @@
1
- {"version":3,"file":"featureExtraction.d.ts","sourceRoot":"","sources":["../../../../src/tasks/nlp/featureExtraction.ts"],"names":[],"mappings":"AAEA,OAAO,KAAK,EAAE,QAAQ,EAAE,OAAO,EAAE,MAAM,aAAa,CAAC;AAGrD,MAAM,MAAM,qBAAqB,GAAG,QAAQ,GAAG;IAC9C;;;;;OAKG;IACH,MAAM,EAAE,MAAM,GAAG,MAAM,EAAE,CAAC;CAC1B,CAAC;AAEF;;GAEG;AACH,MAAM,MAAM,uBAAuB,GAAG,CAAC,MAAM,GAAG,MAAM,EAAE,GAAG,MAAM,EAAE,EAAE,CAAC,EAAE,CAAC;AAEzE;;GAEG;AACH,wBAAsB,iBAAiB,CACtC,IAAI,EAAE,qBAAqB,EAC3B,OAAO,CAAC,EAAE,OAAO,GACf,OAAO,CAAC,uBAAuB,CAAC,CAyBlC"}
1
+ {"version":3,"file":"featureExtraction.d.ts","sourceRoot":"","sources":["../../../../src/tasks/nlp/featureExtraction.ts"],"names":[],"mappings":"AACA,OAAO,KAAK,EAAE,QAAQ,EAAE,OAAO,EAAE,MAAM,aAAa,CAAC;AAGrD,MAAM,MAAM,qBAAqB,GAAG,QAAQ,GAAG;IAC9C;;;;;OAKG;IACH,MAAM,EAAE,MAAM,GAAG,MAAM,EAAE,CAAC;CAC1B,CAAC;AAEF;;GAEG;AACH,MAAM,MAAM,uBAAuB,GAAG,CAAC,MAAM,GAAG,MAAM,EAAE,GAAG,MAAM,EAAE,EAAE,CAAC,EAAE,CAAC;AAEzE;;GAEG;AACH,wBAAsB,iBAAiB,CACtC,IAAI,EAAE,qBAAqB,EAC3B,OAAO,CAAC,EAAE,OAAO,GACf,OAAO,CAAC,uBAAuB,CAAC,CAsBlC"}
@@ -1 +1 @@
1
- {"version":3,"file":"sentenceSimilarity.d.ts","sourceRoot":"","sources":["../../../../src/tasks/nlp/sentenceSimilarity.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,uBAAuB,EAAE,wBAAwB,EAAE,MAAM,oBAAoB,CAAC;AAG5F,OAAO,KAAK,EAAE,QAAQ,EAAE,OAAO,EAAE,MAAM,aAAa,CAAC;AAIrD,MAAM,MAAM,sBAAsB,GAAG,QAAQ,GAAG,uBAAuB,CAAC;AAExE;;GAEG;AACH,wBAAsB,kBAAkB,CACvC,IAAI,EAAE,sBAAsB,EAC5B,OAAO,CAAC,EAAE,OAAO,GACf,OAAO,CAAC,wBAAwB,CAAC,CAanC"}
1
+ {"version":3,"file":"sentenceSimilarity.d.ts","sourceRoot":"","sources":["../../../../src/tasks/nlp/sentenceSimilarity.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,uBAAuB,EAAE,wBAAwB,EAAE,MAAM,oBAAoB,CAAC;AAE5F,OAAO,KAAK,EAAE,QAAQ,EAAE,OAAO,EAAE,MAAM,aAAa,CAAC;AAIrD,MAAM,MAAM,sBAAsB,GAAG,QAAQ,GAAG,uBAAuB,CAAC;AAExE;;GAEG;AACH,wBAAsB,kBAAkB,CACvC,IAAI,EAAE,sBAAsB,EAC5B,OAAO,CAAC,EAAE,OAAO,GACf,OAAO,CAAC,wBAAwB,CAAC,CAWnC"}
@@ -1 +1 @@
1
- {"version":3,"file":"textGeneration.d.ts","sourceRoot":"","sources":["../../../../src/tasks/nlp/textGeneration.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAEX,mBAAmB,EACnB,oBAAoB,EAEpB,MAAM,oBAAoB,CAAC;AAE5B,OAAO,KAAK,EAAE,QAAQ,EAAE,OAAO,EAAE,MAAM,aAAa,CAAC;AAIrD,YAAY,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,CAAC;AAY1D;;GAEG;AACH,wBAAsB,cAAc,CACnC,IAAI,EAAE,QAAQ,GAAG,mBAAmB,EACpC,OAAO,CAAC,EAAE,OAAO,GACf,OAAO,CAAC,oBAAoB,CAAC,CA+B/B"}
1
+ {"version":3,"file":"textGeneration.d.ts","sourceRoot":"","sources":["../../../../src/tasks/nlp/textGeneration.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAEX,mBAAmB,EACnB,oBAAoB,EAEpB,MAAM,oBAAoB,CAAC;AAE5B,OAAO,KAAK,EAAE,QAAQ,EAAE,OAAO,EAAE,MAAM,aAAa,CAAC;AAKrD,YAAY,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,CAAC;AAkB1D;;GAEG;AACH,wBAAsB,cAAc,CACnC,IAAI,EAAE,QAAQ,GAAG,mBAAmB,EACpC,OAAO,CAAC,EAAE,OAAO,GACf,OAAO,CAAC,oBAAoB,CAAC,CAuD/B"}
@@ -22,7 +22,7 @@ export interface Options {
22
22
  includeCredentials?: string | boolean;
23
23
  }
24
24
  export type InferenceTask = Exclude<PipelineType, "other">;
25
- export declare const INFERENCE_PROVIDERS: readonly ["fal-ai", "fireworks-ai", "nebius", "hf-inference", "replicate", "sambanova", "together"];
25
+ export declare const INFERENCE_PROVIDERS: readonly ["black-forest-labs", "fal-ai", "fireworks-ai", "hf-inference", "hyperbolic", "nebius", "novita", "replicate", "sambanova", "together"];
26
26
  export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
27
27
  export interface BaseArgs {
28
28
  /**
@@ -1 +1 @@
1
- {"version":3,"file":"types.d.ts","sourceRoot":"","sources":["../../src/types.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,mBAAmB,EAAE,YAAY,EAAE,MAAM,oBAAoB,CAAC;AAE5E;;GAEG;AACH,MAAM,MAAM,OAAO,GAAG,MAAM,CAAC;AAE7B,MAAM,WAAW,OAAO;IACvB;;OAEG;IACH,cAAc,CAAC,EAAE,OAAO,CAAC;IAEzB;;OAEG;IACH,KAAK,CAAC,EAAE,OAAO,KAAK,CAAC;IACrB;;OAEG;IACH,MAAM,CAAC,EAAE,WAAW,CAAC;IAErB;;OAEG;IACH,kBAAkB,CAAC,EAAE,MAAM,GAAG,OAAO,CAAC;CACtC;AAED,MAAM,MAAM,aAAa,GAAG,OAAO,CAAC,YAAY,EAAE,OAAO,CAAC,CAAC;AAE3D,eAAO,MAAM,mBAAmB,qGAQtB,CAAC;AACX,MAAM,MAAM,iBAAiB,GAAG,CAAC,OAAO,mBAAmB,CAAC,CAAC,MAAM,CAAC,CAAC;AAErE,MAAM,WAAW,QAAQ;IACxB;;;;;;OAMG;IACH,WAAW,CAAC,EAAE,MAAM,CAAC;IAErB;;;;;;;OAOG;IACH,KAAK,CAAC,EAAE,OAAO,CAAC;IAEhB;;;;OAIG;IACH,WAAW,CAAC,EAAE,MAAM,CAAC;IAErB;;;;OAIG;IACH,QAAQ,CAAC,EAAE,iBAAiB,CAAC;CAC7B;AAED,MAAM,MAAM,WAAW,GAAG,QAAQ,GACjC,CACG;IAAE,IAAI,EAAE,IAAI,GAAG,WAAW,CAAA;CAAE,GAC5B;IAAE,MAAM,EAAE,OAAO,CAAA;CAAE,GACnB;IAAE,MAAM,EAAE,MAAM,CAAA;CAAE,GAClB;IAAE,IAAI,EAAE,MAAM,CAAA;CAAE,GAChB;IAAE,SAAS,EAAE,MAAM,CAAA;CAAE,GACrB,mBAAmB,CACrB,GAAG;IACH,UAAU,CAAC,EAAE,MAAM,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;IACrC,WAAW,CAAC,EAAE,MAAM,CAAC;CACrB,CAAC"}
1
+ {"version":3,"file":"types.d.ts","sourceRoot":"","sources":["../../src/types.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,mBAAmB,EAAE,YAAY,EAAE,MAAM,oBAAoB,CAAC;AAE5E;;GAEG;AACH,MAAM,MAAM,OAAO,GAAG,MAAM,CAAC;AAE7B,MAAM,WAAW,OAAO;IACvB;;OAEG;IACH,cAAc,CAAC,EAAE,OAAO,CAAC;IAEzB;;OAEG;IACH,KAAK,CAAC,EAAE,OAAO,KAAK,CAAC;IACrB;;OAEG;IACH,MAAM,CAAC,EAAE,WAAW,CAAC;IAErB;;OAEG;IACH,kBAAkB,CAAC,EAAE,MAAM,GAAG,OAAO,CAAC;CACtC;AAED,MAAM,MAAM,aAAa,GAAG,OAAO,CAAC,YAAY,EAAE,OAAO,CAAC,CAAC;AAE3D,eAAO,MAAM,mBAAmB,kJAWtB,CAAC;AAEX,MAAM,MAAM,iBAAiB,GAAG,CAAC,OAAO,mBAAmB,CAAC,CAAC,MAAM,CAAC,CAAC;AAErE,MAAM,WAAW,QAAQ;IACxB;;;;;;OAMG;IACH,WAAW,CAAC,EAAE,MAAM,CAAC;IAErB;;;;;;;OAOG;IACH,KAAK,CAAC,EAAE,OAAO,CAAC;IAEhB;;;;OAIG;IACH,WAAW,CAAC,EAAE,MAAM,CAAC;IAErB;;;;OAIG;IACH,QAAQ,CAAC,EAAE,iBAAiB,CAAC;CAC7B;AAED,MAAM,MAAM,WAAW,GAAG,QAAQ,GACjC,CACG;IAAE,IAAI,EAAE,IAAI,GAAG,WAAW,CAAA;CAAE,GAC5B;IAAE,MAAM,EAAE,OAAO,CAAA;CAAE,GACnB;IAAE,MAAM,EAAE,MAAM,CAAA;CAAE,GAClB;IAAE,IAAI,EAAE,MAAM,CAAA;CAAE,GAChB;IAAE,SAAS,EAAE,MAAM,CAAA;CAAE,GACrB,mBAAmB,CACrB,GAAG;IACH,UAAU,CAAC,EAAE,MAAM,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;IACrC,WAAW,CAAC,EAAE,MAAM,CAAC;CACrB,CAAC"}
@@ -0,0 +1,2 @@
1
+ export declare function delay(ms: number): Promise<void>;
2
+ //# sourceMappingURL=delay.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"delay.d.ts","sourceRoot":"","sources":["../../../src/utils/delay.ts"],"names":[],"mappings":"AAAA,wBAAgB,KAAK,CAAC,EAAE,EAAE,MAAM,GAAG,OAAO,CAAC,IAAI,CAAC,CAI/C"}
@@ -1 +1 @@
1
- {"version":3,"file":"HfInference.spec.d.ts","sourceRoot":"","sources":["../../test/HfInference.spec.ts"],"names":[],"mappings":"AAOA,OAAO,OAAO,CAAC"}
1
+ {"version":3,"file":"HfInference.spec.d.ts","sourceRoot":"","sources":["../../test/HfInference.spec.ts"],"names":[],"mappings":"AAQA,OAAO,OAAO,CAAC"}
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@huggingface/inference",
3
- "version": "3.3.3",
3
+ "version": "3.3.4",
4
4
  "packageManager": "pnpm@8.10.5",
5
5
  "license": "MIT",
6
6
  "author": "Tim Mikeladze <tim.mikeladze@gmail.com>",
@@ -4,7 +4,10 @@ import { NEBIUS_API_BASE_URL } from "../providers/nebius";
4
4
  import { REPLICATE_API_BASE_URL } from "../providers/replicate";
5
5
  import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
6
6
  import { TOGETHER_API_BASE_URL } from "../providers/together";
7
+ import { NOVITA_API_BASE_URL } from "../providers/novita";
7
8
  import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
9
+ import { HYPERBOLIC_API_BASE_URL } from "../providers/hyperbolic";
10
+ import { BLACKFORESTLABS_AI_API_BASE_URL } from "../providers/black-forest-labs";
8
11
  import type { InferenceProvider } from "../types";
9
12
  import type { InferenceTask, Options, RequestArgs } from "../types";
10
13
  import { isUrl } from "./isUrl";
@@ -28,8 +31,6 @@ export async function makeRequestOptions(
28
31
  stream?: boolean;
29
32
  },
30
33
  options?: Options & {
31
- /** When a model can be used for multiple tasks, and we want to run a non-default task */
32
- forceTask?: string | InferenceTask;
33
34
  /** To load default model if needed */
34
35
  taskHint?: InferenceTask;
35
36
  chatCompletion?: boolean;
@@ -39,14 +40,11 @@ export async function makeRequestOptions(
39
40
  let otherArgs = remainingArgs;
40
41
  const provider = maybeProvider ?? "hf-inference";
41
42
 
42
- const { forceTask, includeCredentials, taskHint, chatCompletion } = options ?? {};
43
+ const { includeCredentials, taskHint, chatCompletion } = options ?? {};
43
44
 
44
45
  if (endpointUrl && provider !== "hf-inference") {
45
46
  throw new Error(`Cannot use endpointUrl with a third-party provider.`);
46
47
  }
47
- if (forceTask && provider !== "hf-inference") {
48
- throw new Error(`Cannot use forceTask with a third-party provider.`);
49
- }
50
48
  if (maybeModel && isUrl(maybeModel)) {
51
49
  throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
52
50
  }
@@ -77,7 +75,6 @@ export async function makeRequestOptions(
77
75
  : makeUrl({
78
76
  authMethod,
79
77
  chatCompletion: chatCompletion ?? false,
80
- forceTask,
81
78
  model,
82
79
  provider: provider ?? "hf-inference",
83
80
  taskHint,
@@ -85,8 +82,13 @@ export async function makeRequestOptions(
85
82
 
86
83
  const headers: Record<string, string> = {};
87
84
  if (accessToken) {
88
- headers["Authorization"] =
89
- provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
85
+ if (provider === "fal-ai" && authMethod === "provider-key") {
86
+ headers["Authorization"] = `Key ${accessToken}`;
87
+ } else if (provider === "black-forest-labs" && authMethod === "provider-key") {
88
+ headers["X-Key"] = accessToken;
89
+ } else {
90
+ headers["Authorization"] = `Bearer ${accessToken}`;
91
+ }
90
92
  }
91
93
 
92
94
  // e.g. @huggingface/inference/3.1.3
@@ -131,7 +133,11 @@ export async function makeRequestOptions(
131
133
  ? args.data
132
134
  : JSON.stringify({
133
135
  ...otherArgs,
134
- ...(chatCompletion || provider === "together" || provider === "nebius" ? { model } : undefined),
136
+ ...(taskHint === "text-to-image" && provider === "hyperbolic"
137
+ ? { model_name: model }
138
+ : chatCompletion || provider === "together" || provider === "nebius" || provider === "hyperbolic"
139
+ ? { model }
140
+ : undefined),
135
141
  }),
136
142
  ...(credentials ? { credentials } : undefined),
137
143
  signal: options?.signal,
@@ -146,7 +152,6 @@ function makeUrl(params: {
146
152
  model: string;
147
153
  provider: InferenceProvider;
148
154
  taskHint: InferenceTask | undefined;
149
- forceTask?: string | InferenceTask;
150
155
  }): string {
151
156
  if (params.authMethod === "none" && params.provider !== "hf-inference") {
152
157
  throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
@@ -154,6 +159,12 @@ function makeUrl(params: {
154
159
 
155
160
  const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
156
161
  switch (params.provider) {
162
+ case "black-forest-labs": {
163
+ const baseUrl = shouldProxy
164
+ ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
165
+ : BLACKFORESTLABS_AI_API_BASE_URL;
166
+ return `${baseUrl}/${params.model}`;
167
+ }
157
168
  case "fal-ai": {
158
169
  const baseUrl = shouldProxy
159
170
  ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
@@ -213,6 +224,7 @@ function makeUrl(params: {
213
224
  }
214
225
  return baseUrl;
215
226
  }
227
+
216
228
  case "fireworks-ai": {
217
229
  const baseUrl = shouldProxy
218
230
  ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
@@ -222,15 +234,38 @@ function makeUrl(params: {
222
234
  }
223
235
  return baseUrl;
224
236
  }
237
+ case "hyperbolic": {
238
+ const baseUrl = shouldProxy
239
+ ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
240
+ : HYPERBOLIC_API_BASE_URL;
241
+
242
+ if (params.taskHint === "text-to-image") {
243
+ return `${baseUrl}/v1/images/generations`;
244
+ }
245
+ return `${baseUrl}/v1/chat/completions`;
246
+ }
247
+ case "novita": {
248
+ const baseUrl = shouldProxy
249
+ ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
250
+ : NOVITA_API_BASE_URL;
251
+ if (params.taskHint === "text-generation") {
252
+ if (params.chatCompletion) {
253
+ return `${baseUrl}/chat/completions`;
254
+ }
255
+ return `${baseUrl}/completions`;
256
+ }
257
+ return baseUrl;
258
+ }
225
259
  default: {
226
260
  const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
227
- const url = params.forceTask
228
- ? `${baseUrl}/pipeline/${params.forceTask}/${params.model}`
229
- : `${baseUrl}/models/${params.model}`;
261
+ if (params.taskHint && ["feature-extraction", "sentence-similarity"].includes(params.taskHint)) {
262
+ /// when deployed on hf-inference, those two tasks are automatically compatible with one another.
263
+ return `${baseUrl}/pipeline/${params.taskHint}/${params.model}`;
264
+ }
230
265
  if (params.taskHint === "text-generation" && params.chatCompletion) {
231
- return url + `/v1/chat/completions`;
266
+ return `${baseUrl}/models/${params.model}/v1/chat/completions`;
232
267
  }
233
- return url;
268
+ return `${baseUrl}/models/${params.model}`;
234
269
  }
235
270
  }
236
271
  }
@@ -0,0 +1,18 @@
1
+ export const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
2
+
3
+ /**
4
+ * See the registered mapping of HF model ID => Black Forest Labs model ID here:
5
+ *
6
+ * https://huggingface.co/api/partners/blackforestlabs/models
7
+ *
8
+ * This is a publicly available mapping.
9
+ *
10
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
11
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
12
+ *
13
+ * - If you work at Black Forest Labs and want to update this mapping, please use the model mapping API we provide on huggingface.co
14
+ * - If you're a community member and want to add a new supported HF model to Black Forest Labs, please open an issue on the present repo
15
+ * and we will tag Black Forest Labs team members.
16
+ *
17
+ * Thanks!
18
+ */
@@ -16,11 +16,14 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
16
16
  * Example:
17
17
  * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
18
18
  */
19
+ "black-forest-labs": {},
19
20
  "fal-ai": {},
20
21
  "fireworks-ai": {},
21
22
  "hf-inference": {},
23
+ hyperbolic: {},
22
24
  nebius: {},
23
25
  replicate: {},
24
26
  sambanova: {},
25
27
  together: {},
28
+ novita: {},
26
29
  };
@@ -0,0 +1,18 @@
1
+ export const HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
2
+
3
+ /**
4
+ * See the registered mapping of HF model ID => Hyperbolic model ID here:
5
+ *
6
+ * https://huggingface.co/api/partners/hyperbolic/models
7
+ *
8
+ * This is a publicly available mapping.
9
+ *
10
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
11
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
12
+ *
13
+ * - If you work at Hyperbolic and want to update this mapping, please use the model mapping API we provide on huggingface.co
14
+ * - If you're a community member and want to add a new supported HF model to Hyperbolic, please open an issue on the present repo
15
+ * and we will tag Hyperbolic team members.
16
+ *
17
+ * Thanks!
18
+ */
@@ -0,0 +1,18 @@
1
+ export const NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
2
+
3
+ /**
4
+ * See the registered mapping of HF model ID => Novita model ID here:
5
+ *
6
+ * https://huggingface.co/api/partners/novita/models
7
+ *
8
+ * This is a publicly available mapping.
9
+ *
10
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
11
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
12
+ *
13
+ * - If you work at Novita and want to update this mapping, please use the model mapping API we provide on huggingface.co
14
+ * - If you're a community member and want to add a new supported HF model to Novita, please open an issue on the present repo
15
+ * and we will tag Novita team members.
16
+ *
17
+ * Thanks!
18
+ */
@@ -3,6 +3,7 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";
3
3
  import type { BaseArgs, InferenceProvider, Options } from "../../types";
4
4
  import { omit } from "../../utils/omit";
5
5
  import { request } from "../custom/request";
6
+ import { delay } from "../../utils/delay";
6
7
 
7
8
  export type TextToImageArgs = BaseArgs & TextToImageInput;
8
9
 
@@ -14,6 +15,14 @@ interface Base64ImageGeneration {
14
15
  interface OutputUrlImageGeneration {
15
16
  output: string[];
16
17
  }
18
+ interface HyperbolicTextToImageOutput {
19
+ images: Array<{ image: string }>;
20
+ }
21
+
22
+ interface BlackForestLabsResponse {
23
+ id: string;
24
+ polling_url: string;
25
+ }
17
26
 
18
27
  function getResponseFormatArg(provider: InferenceProvider) {
19
28
  switch (provider) {
@@ -44,16 +53,36 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
44
53
  ...getResponseFormatArg(args.provider),
45
54
  prompt: args.inputs,
46
55
  };
47
- const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration>(payload, {
56
+ const res = await request<
57
+ | TextToImageOutput
58
+ | Base64ImageGeneration
59
+ | OutputUrlImageGeneration
60
+ | BlackForestLabsResponse
61
+ | HyperbolicTextToImageOutput
62
+ >(payload, {
48
63
  ...options,
49
64
  taskHint: "text-to-image",
50
65
  });
51
66
 
52
67
  if (res && typeof res === "object") {
68
+ if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") {
69
+ return await pollBflResponse(res.polling_url);
70
+ }
53
71
  if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
54
72
  const image = await fetch(res.images[0].url);
55
73
  return await image.blob();
56
74
  }
75
+ if (
76
+ args.provider === "hyperbolic" &&
77
+ "images" in res &&
78
+ Array.isArray(res.images) &&
79
+ res.images[0] &&
80
+ typeof res.images[0].image === "string"
81
+ ) {
82
+ const base64Response = await fetch(`data:image/jpeg;base64,${res.images[0].image}`);
83
+ const blob = await base64Response.blob();
84
+ return blob;
85
+ }
57
86
  if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
58
87
  const base64Data = res.data[0].b64_json;
59
88
  const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
@@ -72,3 +101,33 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
72
101
  }
73
102
  return res;
74
103
  }
104
+
105
+ async function pollBflResponse(url: string): Promise<Blob> {
106
+ const urlObj = new URL(url);
107
+ for (let step = 0; step < 5; step++) {
108
+ await delay(1000);
109
+ console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
110
+ urlObj.searchParams.set("attempt", step.toString(10));
111
+ const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
112
+ if (!resp.ok) {
113
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
114
+ }
115
+ const payload = await resp.json();
116
+ if (
117
+ typeof payload === "object" &&
118
+ payload &&
119
+ "status" in payload &&
120
+ typeof payload.status === "string" &&
121
+ payload.status === "Ready" &&
122
+ "result" in payload &&
123
+ typeof payload.result === "object" &&
124
+ payload.result &&
125
+ "sample" in payload.result &&
126
+ typeof payload.result.sample === "string"
127
+ ) {
128
+ const image = await fetch(payload.result.sample);
129
+ return await image.blob();
130
+ }
131
+ }
132
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
133
+ }
@@ -1,5 +1,4 @@
1
1
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
- import { getDefaultTask } from "../../lib/getDefaultTask";
3
2
  import type { BaseArgs, Options } from "../../types";
4
3
  import { request } from "../custom/request";
5
4
 
@@ -25,12 +24,9 @@ export async function featureExtraction(
25
24
  args: FeatureExtractionArgs,
26
25
  options?: Options
27
26
  ): Promise<FeatureExtractionOutput> {
28
- const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : undefined;
29
-
30
27
  const res = await request<FeatureExtractionOutput>(args, {
31
28
  ...options,
32
29
  taskHint: "feature-extraction",
33
- ...(defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" }),
34
30
  });
35
31
  let isValidOutput = true;
36
32
 
@@ -1,6 +1,5 @@
1
1
  import type { SentenceSimilarityInput, SentenceSimilarityOutput } from "@huggingface/tasks";
2
2
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
3
- import { getDefaultTask } from "../../lib/getDefaultTask";
4
3
  import type { BaseArgs, Options } from "../../types";
5
4
  import { request } from "../custom/request";
6
5
  import { omit } from "../../utils/omit";
@@ -14,11 +13,9 @@ export async function sentenceSimilarity(
14
13
  args: SentenceSimilarityArgs,
15
14
  options?: Options
16
15
  ): Promise<SentenceSimilarityOutput> {
17
- const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : undefined;
18
16
  const res = await request<SentenceSimilarityOutput>(prepareInput(args), {
19
17
  ...options,
20
18
  taskHint: "sentence-similarity",
21
- ...(defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }),
22
19
  });
23
20
 
24
21
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
@@ -8,6 +8,7 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";
8
8
  import type { BaseArgs, Options } from "../../types";
9
9
  import { toArray } from "../../utils/toArray";
10
10
  import { request } from "../custom/request";
11
+ import { omit } from "../../utils/omit";
11
12
 
12
13
  export type { TextGenerationInput, TextGenerationOutput };
13
14
 
@@ -21,6 +22,12 @@ interface TogeteherTextCompletionOutput extends Omit<ChatCompletionOutput, "choi
21
22
  }>;
22
23
  }
23
24
 
25
+ interface HyperbolicTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
26
+ choices: Array<{
27
+ message: { content: string };
28
+ }>;
29
+ }
30
+
24
31
  /**
25
32
  * Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
26
33
  */
@@ -43,6 +50,30 @@ export async function textGeneration(
43
50
  return {
44
51
  generated_text: completion.text,
45
52
  };
53
+ } else if (args.provider === "hyperbolic") {
54
+ const payload = {
55
+ messages: [{ content: args.inputs, role: "user" }],
56
+ ...(args.parameters
57
+ ? {
58
+ max_tokens: args.parameters.max_new_tokens,
59
+ ...omit(args.parameters, "max_new_tokens"),
60
+ }
61
+ : undefined),
62
+ ...omit(args, ["inputs", "parameters"]),
63
+ };
64
+ const raw = await request<HyperbolicTextCompletionOutput>(payload, {
65
+ ...options,
66
+ taskHint: "text-generation",
67
+ });
68
+ const isValidOutput =
69
+ typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
70
+ if (!isValidOutput) {
71
+ throw new InferenceOutputError("Expected ChatCompletionOutput");
72
+ }
73
+ const completion = raw.choices[0];
74
+ return {
75
+ generated_text: completion.message.content,
76
+ };
46
77
  } else {
47
78
  const res = toArray(
48
79
  await request<TextGenerationOutput | TextGenerationOutput[]>(args, {
package/src/types.ts CHANGED
@@ -29,14 +29,18 @@ export interface Options {
29
29
  export type InferenceTask = Exclude<PipelineType, "other">;
30
30
 
31
31
  export const INFERENCE_PROVIDERS = [
32
+ "black-forest-labs",
32
33
  "fal-ai",
33
34
  "fireworks-ai",
34
- "nebius",
35
35
  "hf-inference",
36
+ "hyperbolic",
37
+ "nebius",
38
+ "novita",
36
39
  "replicate",
37
40
  "sambanova",
38
41
  "together",
39
42
  ] as const;
43
+
40
44
  export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
41
45
 
42
46
  export interface BaseArgs {
@@ -0,0 +1,5 @@
1
+ export function delay(ms: number): Promise<void> {
2
+ return new Promise((resolve) => {
3
+ setTimeout(() => resolve(), ms);
4
+ });
5
+ }