@huggingface/inference 3.3.6 → 3.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. package/dist/index.cjs +315 -174
  2. package/dist/index.js +315 -174
  3. package/dist/src/lib/getProviderModelId.d.ts +1 -1
  4. package/dist/src/lib/getProviderModelId.d.ts.map +1 -1
  5. package/dist/src/lib/makeRequestOptions.d.ts +2 -2
  6. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  7. package/dist/src/providers/black-forest-labs.d.ts +2 -1
  8. package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
  9. package/dist/src/providers/fal-ai.d.ts +2 -1
  10. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  11. package/dist/src/providers/fireworks-ai.d.ts +2 -1
  12. package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
  13. package/dist/src/providers/hf-inference.d.ts +3 -0
  14. package/dist/src/providers/hf-inference.d.ts.map +1 -0
  15. package/dist/src/providers/hyperbolic.d.ts +2 -1
  16. package/dist/src/providers/hyperbolic.d.ts.map +1 -1
  17. package/dist/src/providers/nebius.d.ts +2 -1
  18. package/dist/src/providers/nebius.d.ts.map +1 -1
  19. package/dist/src/providers/novita.d.ts +2 -1
  20. package/dist/src/providers/novita.d.ts.map +1 -1
  21. package/dist/src/providers/replicate.d.ts +3 -1
  22. package/dist/src/providers/replicate.d.ts.map +1 -1
  23. package/dist/src/providers/sambanova.d.ts +2 -1
  24. package/dist/src/providers/sambanova.d.ts.map +1 -1
  25. package/dist/src/providers/together.d.ts +2 -1
  26. package/dist/src/providers/together.d.ts.map +1 -1
  27. package/dist/src/tasks/custom/request.d.ts +2 -4
  28. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  29. package/dist/src/tasks/custom/streamingRequest.d.ts +2 -4
  30. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  31. package/dist/src/tasks/nlp/featureExtraction.d.ts +2 -9
  32. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  33. package/dist/src/types.d.ts +24 -3
  34. package/dist/src/types.d.ts.map +1 -1
  35. package/package.json +2 -2
  36. package/src/lib/getProviderModelId.ts +4 -4
  37. package/src/lib/makeRequestOptions.ts +72 -186
  38. package/src/providers/black-forest-labs.ts +26 -2
  39. package/src/providers/consts.ts +1 -1
  40. package/src/providers/fal-ai.ts +24 -2
  41. package/src/providers/fireworks-ai.ts +28 -2
  42. package/src/providers/hf-inference.ts +43 -0
  43. package/src/providers/hyperbolic.ts +28 -2
  44. package/src/providers/nebius.ts +34 -2
  45. package/src/providers/novita.ts +31 -2
  46. package/src/providers/replicate.ts +30 -2
  47. package/src/providers/sambanova.ts +28 -2
  48. package/src/providers/together.ts +34 -2
  49. package/src/tasks/audio/audioClassification.ts +1 -1
  50. package/src/tasks/audio/audioToAudio.ts +1 -1
  51. package/src/tasks/audio/automaticSpeechRecognition.ts +1 -1
  52. package/src/tasks/audio/textToSpeech.ts +1 -1
  53. package/src/tasks/custom/request.ts +2 -4
  54. package/src/tasks/custom/streamingRequest.ts +2 -4
  55. package/src/tasks/cv/imageClassification.ts +1 -1
  56. package/src/tasks/cv/imageSegmentation.ts +1 -1
  57. package/src/tasks/cv/imageToImage.ts +1 -1
  58. package/src/tasks/cv/imageToText.ts +1 -1
  59. package/src/tasks/cv/objectDetection.ts +1 -1
  60. package/src/tasks/cv/textToImage.ts +1 -1
  61. package/src/tasks/cv/textToVideo.ts +1 -1
  62. package/src/tasks/cv/zeroShotImageClassification.ts +1 -1
  63. package/src/tasks/multimodal/documentQuestionAnswering.ts +1 -1
  64. package/src/tasks/multimodal/visualQuestionAnswering.ts +1 -1
  65. package/src/tasks/nlp/chatCompletion.ts +1 -1
  66. package/src/tasks/nlp/chatCompletionStream.ts +1 -1
  67. package/src/tasks/nlp/featureExtraction.ts +3 -10
  68. package/src/tasks/nlp/fillMask.ts +1 -1
  69. package/src/tasks/nlp/questionAnswering.ts +1 -1
  70. package/src/tasks/nlp/sentenceSimilarity.ts +1 -1
  71. package/src/tasks/nlp/summarization.ts +1 -1
  72. package/src/tasks/nlp/tableQuestionAnswering.ts +1 -1
  73. package/src/tasks/nlp/textClassification.ts +1 -1
  74. package/src/tasks/nlp/textGeneration.ts +3 -3
  75. package/src/tasks/nlp/textGenerationStream.ts +1 -1
  76. package/src/tasks/nlp/tokenClassification.ts +1 -1
  77. package/src/tasks/nlp/translation.ts +1 -1
  78. package/src/tasks/nlp/zeroShotClassification.ts +1 -1
  79. package/src/tasks/tabular/tabularClassification.ts +1 -1
  80. package/src/tasks/tabular/tabularRegression.ts +1 -1
  81. package/src/types.ts +28 -2
@@ -1,15 +1,15 @@
1
1
  import { HF_HUB_URL, HF_ROUTER_URL } from "../config";
2
- import { FAL_AI_API_BASE_URL } from "../providers/fal-ai";
3
- import { NEBIUS_API_BASE_URL } from "../providers/nebius";
4
- import { REPLICATE_API_BASE_URL } from "../providers/replicate";
5
- import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
6
- import { TOGETHER_API_BASE_URL } from "../providers/together";
7
- import { NOVITA_API_BASE_URL } from "../providers/novita";
8
- import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
9
- import { HYPERBOLIC_API_BASE_URL } from "../providers/hyperbolic";
10
- import { BLACKFORESTLABS_AI_API_BASE_URL } from "../providers/black-forest-labs";
11
- import type { InferenceProvider } from "../types";
12
- import type { InferenceTask, Options, RequestArgs } from "../types";
2
+ import { BLACK_FOREST_LABS_CONFIG } from "../providers/black-forest-labs";
3
+ import { FAL_AI_CONFIG } from "../providers/fal-ai";
4
+ import { FIREWORKS_AI_CONFIG } from "../providers/fireworks-ai";
5
+ import { HF_INFERENCE_CONFIG } from "../providers/hf-inference";
6
+ import { HYPERBOLIC_CONFIG } from "../providers/hyperbolic";
7
+ import { NEBIUS_CONFIG } from "../providers/nebius";
8
+ import { NOVITA_CONFIG } from "../providers/novita";
9
+ import { REPLICATE_CONFIG } from "../providers/replicate";
10
+ import { SAMBANOVA_CONFIG } from "../providers/sambanova";
11
+ import { TOGETHER_CONFIG } from "../providers/together";
12
+ import type { InferenceProvider, InferenceTask, Options, ProviderConfig, RequestArgs } from "../types";
13
13
  import { isUrl } from "./isUrl";
14
14
  import { version as packageVersion, name as packageName } from "../../package.json";
15
15
  import { getProviderModelId } from "./getProviderModelId";
@@ -22,6 +22,22 @@ const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
22
22
  */
23
23
  let tasks: Record<string, { models: { id: string }[] }> | null = null;
24
24
 
25
+ /**
26
+ * Config to define how to serialize requests for each provider
27
+ */
28
+ const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
29
+ "black-forest-labs": BLACK_FOREST_LABS_CONFIG,
30
+ "fal-ai": FAL_AI_CONFIG,
31
+ "fireworks-ai": FIREWORKS_AI_CONFIG,
32
+ "hf-inference": HF_INFERENCE_CONFIG,
33
+ hyperbolic: HYPERBOLIC_CONFIG,
34
+ nebius: NEBIUS_CONFIG,
35
+ novita: NOVITA_CONFIG,
36
+ replicate: REPLICATE_CONFIG,
37
+ sambanova: SAMBANOVA_CONFIG,
38
+ together: TOGETHER_CONFIG,
39
+ };
40
+
25
41
  /**
26
42
  * Helper that prepares request arguments
27
43
  */
@@ -31,16 +47,16 @@ export async function makeRequestOptions(
31
47
  stream?: boolean;
32
48
  },
33
49
  options?: Options & {
34
- /** To load default model if needed */
35
- taskHint?: InferenceTask;
50
+ /** In most cases (unless we pass an endpointUrl) we know the task */
51
+ task?: InferenceTask;
36
52
  chatCompletion?: boolean;
37
53
  }
38
54
  ): Promise<{ url: string; info: RequestInit }> {
39
55
  const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
40
- let otherArgs = remainingArgs;
41
56
  const provider = maybeProvider ?? "hf-inference";
57
+ const providerConfig = providerConfigs[provider];
42
58
 
43
- const { includeCredentials, taskHint, chatCompletion } = options ?? {};
59
+ const { includeCredentials, task, chatCompletion, signal } = options ?? {};
44
60
 
45
61
  if (endpointUrl && provider !== "hf-inference") {
46
62
  throw new Error(`Cannot use endpointUrl with a third-party provider.`);
@@ -48,13 +64,16 @@ export async function makeRequestOptions(
48
64
  if (maybeModel && isUrl(maybeModel)) {
49
65
  throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
50
66
  }
51
- if (!maybeModel && !taskHint) {
67
+ if (!maybeModel && !task) {
52
68
  throw new Error("No model provided, and no task has been specified.");
53
69
  }
70
+ if (!providerConfig) {
71
+ throw new Error(`No provider config found for provider ${provider}`);
72
+ }
54
73
  // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
55
- const hfModel = maybeModel ?? (await loadDefaultModel(taskHint!));
74
+ const hfModel = maybeModel ?? (await loadDefaultModel(task!));
56
75
  const model = await getProviderModelId({ model: hfModel, provider }, args, {
57
- taskHint,
76
+ task,
58
77
  chatCompletion,
59
78
  fetch: options?.fetch,
60
79
  });
@@ -68,44 +87,52 @@ export async function makeRequestOptions(
68
87
  ? "credentials-include"
69
88
  : "none";
70
89
 
90
+ // Make URL
71
91
  const url = endpointUrl
72
92
  ? chatCompletion
73
93
  ? endpointUrl + `/v1/chat/completions`
74
94
  : endpointUrl
75
- : makeUrl({
76
- authMethod,
77
- chatCompletion: chatCompletion ?? false,
95
+ : providerConfig.makeUrl({
96
+ baseUrl:
97
+ authMethod !== "provider-key"
98
+ ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider)
99
+ : providerConfig.baseUrl,
78
100
  model,
79
- provider: provider ?? "hf-inference",
80
- taskHint,
101
+ chatCompletion,
102
+ task,
81
103
  });
82
104
 
83
- const headers: Record<string, string> = {};
84
- if (accessToken) {
85
- if (provider === "fal-ai" && authMethod === "provider-key") {
86
- headers["Authorization"] = `Key ${accessToken}`;
87
- } else if (provider === "black-forest-labs" && authMethod === "provider-key") {
88
- headers["X-Key"] = accessToken;
89
- } else {
90
- headers["Authorization"] = `Bearer ${accessToken}`;
91
- }
92
- }
93
-
94
- // e.g. @huggingface/inference/3.1.3
95
- const ownUserAgent = `${packageName}/${packageVersion}`;
96
- headers["User-Agent"] = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : undefined]
97
- .filter((x) => x !== undefined)
98
- .join(" ");
99
-
105
+ // Make headers
100
106
  const binary = "data" in args && !!args.data;
107
+ const headers = providerConfig.makeHeaders({
108
+ accessToken,
109
+ authMethod,
110
+ });
101
111
 
112
+ // Add content-type to headers
102
113
  if (!binary) {
103
114
  headers["Content-Type"] = "application/json";
104
115
  }
105
116
 
106
- if (provider === "replicate") {
107
- headers["Prefer"] = "wait";
108
- }
117
+ // Add user-agent to headers
118
+ // e.g. @huggingface/inference/3.1.3
119
+ const ownUserAgent = `${packageName}/${packageVersion}`;
120
+ const userAgent = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : undefined]
121
+ .filter((x) => x !== undefined)
122
+ .join(" ");
123
+ headers["User-Agent"] = userAgent;
124
+
125
+ // Make body
126
+ const body = binary
127
+ ? args.data
128
+ : JSON.stringify(
129
+ providerConfig.makeBody({
130
+ args: remainingArgs as Record<string, unknown>,
131
+ model,
132
+ task,
133
+ chatCompletion,
134
+ })
135
+ );
109
136
 
110
137
  /**
111
138
  * For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
@@ -117,158 +144,17 @@ export async function makeRequestOptions(
117
144
  credentials = "include";
118
145
  }
119
146
 
120
- /**
121
- * Replicate models wrap all inputs inside { input: ... }
122
- * Versioned Replicate models in the format `owner/model:version` expect the version in the body
123
- */
124
- if (provider === "replicate") {
125
- const version = model.includes(":") ? model.split(":")[1] : undefined;
126
- (otherArgs as unknown) = { input: otherArgs, version };
127
- }
128
-
129
147
  const info: RequestInit = {
130
148
  headers,
131
149
  method: "POST",
132
- body: binary
133
- ? args.data
134
- : JSON.stringify({
135
- ...otherArgs,
136
- ...(taskHint === "text-to-image" && provider === "hyperbolic"
137
- ? { model_name: model }
138
- : chatCompletion || provider === "together" || provider === "nebius" || provider === "hyperbolic"
139
- ? { model }
140
- : undefined),
141
- }),
150
+ body,
142
151
  ...(credentials ? { credentials } : undefined),
143
- signal: options?.signal,
152
+ signal,
144
153
  };
145
154
 
146
155
  return { url, info };
147
156
  }
148
157
 
149
- function makeUrl(params: {
150
- authMethod: "none" | "hf-token" | "credentials-include" | "provider-key";
151
- chatCompletion: boolean;
152
- model: string;
153
- provider: InferenceProvider;
154
- taskHint: InferenceTask | undefined;
155
- }): string {
156
- if (params.authMethod === "none" && params.provider !== "hf-inference") {
157
- throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
158
- }
159
-
160
- const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
161
- switch (params.provider) {
162
- case "black-forest-labs": {
163
- const baseUrl = shouldProxy
164
- ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
165
- : BLACKFORESTLABS_AI_API_BASE_URL;
166
- return `${baseUrl}/${params.model}`;
167
- }
168
- case "fal-ai": {
169
- const baseUrl = shouldProxy
170
- ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
171
- : FAL_AI_API_BASE_URL;
172
- return `${baseUrl}/${params.model}`;
173
- }
174
- case "nebius": {
175
- const baseUrl = shouldProxy
176
- ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
177
- : NEBIUS_API_BASE_URL;
178
-
179
- if (params.taskHint === "text-to-image") {
180
- return `${baseUrl}/v1/images/generations`;
181
- }
182
- if (params.taskHint === "text-generation") {
183
- if (params.chatCompletion) {
184
- return `${baseUrl}/v1/chat/completions`;
185
- }
186
- return `${baseUrl}/v1/completions`;
187
- }
188
- return baseUrl;
189
- }
190
- case "replicate": {
191
- const baseUrl = shouldProxy
192
- ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
193
- : REPLICATE_API_BASE_URL;
194
- if (params.model.includes(":")) {
195
- /// Versioned model
196
- return `${baseUrl}/v1/predictions`;
197
- }
198
- /// Evergreen / Canonical model
199
- return `${baseUrl}/v1/models/${params.model}/predictions`;
200
- }
201
- case "sambanova": {
202
- const baseUrl = shouldProxy
203
- ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
204
- : SAMBANOVA_API_BASE_URL;
205
- /// Sambanova API matches OpenAI-like APIs: model is defined in the request body
206
- if (params.taskHint === "text-generation" && params.chatCompletion) {
207
- return `${baseUrl}/v1/chat/completions`;
208
- }
209
- return baseUrl;
210
- }
211
- case "together": {
212
- const baseUrl = shouldProxy
213
- ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
214
- : TOGETHER_API_BASE_URL;
215
- /// Together API matches OpenAI-like APIs: model is defined in the request body
216
- if (params.taskHint === "text-to-image") {
217
- return `${baseUrl}/v1/images/generations`;
218
- }
219
- if (params.taskHint === "text-generation") {
220
- if (params.chatCompletion) {
221
- return `${baseUrl}/v1/chat/completions`;
222
- }
223
- return `${baseUrl}/v1/completions`;
224
- }
225
- return baseUrl;
226
- }
227
-
228
- case "fireworks-ai": {
229
- const baseUrl = shouldProxy
230
- ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
231
- : FIREWORKS_AI_API_BASE_URL;
232
- if (params.taskHint === "text-generation" && params.chatCompletion) {
233
- return `${baseUrl}/v1/chat/completions`;
234
- }
235
- return baseUrl;
236
- }
237
- case "hyperbolic": {
238
- const baseUrl = shouldProxy
239
- ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
240
- : HYPERBOLIC_API_BASE_URL;
241
-
242
- if (params.taskHint === "text-to-image") {
243
- return `${baseUrl}/v1/images/generations`;
244
- }
245
- return `${baseUrl}/v1/chat/completions`;
246
- }
247
- case "novita": {
248
- const baseUrl = shouldProxy
249
- ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
250
- : NOVITA_API_BASE_URL;
251
- if (params.taskHint === "text-generation") {
252
- if (params.chatCompletion) {
253
- return `${baseUrl}/chat/completions`;
254
- }
255
- return `${baseUrl}/completions`;
256
- }
257
- return baseUrl;
258
- }
259
- default: {
260
- const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
261
- if (params.taskHint && ["feature-extraction", "sentence-similarity"].includes(params.taskHint)) {
262
- /// when deployed on hf-inference, those two tasks are automatically compatible with one another.
263
- return `${baseUrl}/pipeline/${params.taskHint}/${params.model}`;
264
- }
265
- if (params.taskHint === "text-generation" && params.chatCompletion) {
266
- return `${baseUrl}/models/${params.model}/v1/chat/completions`;
267
- }
268
- return `${baseUrl}/models/${params.model}`;
269
- }
270
- }
271
- }
272
158
  async function loadDefaultModel(task: InferenceTask): Promise<string> {
273
159
  if (!tasks) {
274
160
  tasks = await loadTaskInfo();
@@ -1,5 +1,3 @@
1
- export const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
2
-
3
1
  /**
4
2
  * See the registered mapping of HF model ID => Black Forest Labs model ID here:
5
3
  *
@@ -16,3 +14,29 @@ export const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
16
14
  *
17
15
  * Thanks!
18
16
  */
17
+ import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18
+
19
+ const BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
20
+
21
+ const makeBody = (params: BodyParams): Record<string, unknown> => {
22
+ return params.args;
23
+ };
24
+
25
+ const makeHeaders = (params: HeaderParams): Record<string, string> => {
26
+ if (params.authMethod === "provider-key") {
27
+ return { "X-Key": `${params.accessToken}` };
28
+ } else {
29
+ return { Authorization: `Bearer ${params.accessToken}` };
30
+ }
31
+ };
32
+
33
+ const makeUrl = (params: UrlParams): string => {
34
+ return `${params.baseUrl}/${params.model}`;
35
+ };
36
+
37
+ export const BLACK_FOREST_LABS_CONFIG: ProviderConfig = {
38
+ baseUrl: BLACK_FOREST_LABS_AI_API_BASE_URL,
39
+ makeBody,
40
+ makeHeaders,
41
+ makeUrl,
42
+ };
@@ -22,8 +22,8 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
22
22
  "hf-inference": {},
23
23
  hyperbolic: {},
24
24
  nebius: {},
25
+ novita: {},
25
26
  replicate: {},
26
27
  sambanova: {},
27
28
  together: {},
28
- novita: {},
29
29
  };
@@ -1,5 +1,3 @@
1
- export const FAL_AI_API_BASE_URL = "https://fal.run";
2
-
3
1
  /**
4
2
  * See the registered mapping of HF model ID => Fal model ID here:
5
3
  *
@@ -16,3 +14,27 @@ export const FAL_AI_API_BASE_URL = "https://fal.run";
16
14
  *
17
15
  * Thanks!
18
16
  */
17
+ import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18
+
19
+ const FAL_AI_API_BASE_URL = "https://fal.run";
20
+
21
+ const makeBody = (params: BodyParams): Record<string, unknown> => {
22
+ return params.args;
23
+ };
24
+
25
+ const makeHeaders = (params: HeaderParams): Record<string, string> => {
26
+ return {
27
+ Authorization: params.authMethod === "provider-key" ? `Key ${params.accessToken}` : `Bearer ${params.accessToken}`,
28
+ };
29
+ };
30
+
31
+ const makeUrl = (params: UrlParams): string => {
32
+ return `${params.baseUrl}/${params.model}`;
33
+ };
34
+
35
+ export const FAL_AI_CONFIG: ProviderConfig = {
36
+ baseUrl: FAL_AI_API_BASE_URL,
37
+ makeBody,
38
+ makeHeaders,
39
+ makeUrl,
40
+ };
@@ -1,5 +1,3 @@
1
- export const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
2
-
3
1
  /**
4
2
  * See the registered mapping of HF model ID => Fireworks model ID here:
5
3
  *
@@ -16,3 +14,31 @@ export const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
16
14
  *
17
15
  * Thanks!
18
16
  */
17
+ import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18
+
19
+ const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
20
+
21
+ const makeBody = (params: BodyParams): Record<string, unknown> => {
22
+ return {
23
+ ...params.args,
24
+ ...(params.chatCompletion ? { model: params.model } : undefined),
25
+ };
26
+ };
27
+
28
+ const makeHeaders = (params: HeaderParams): Record<string, string> => {
29
+ return { Authorization: `Bearer ${params.accessToken}` };
30
+ };
31
+
32
+ const makeUrl = (params: UrlParams): string => {
33
+ if (params.task === "text-generation" && params.chatCompletion) {
34
+ return `${params.baseUrl}/v1/chat/completions`;
35
+ }
36
+ return params.baseUrl;
37
+ };
38
+
39
+ export const FIREWORKS_AI_CONFIG: ProviderConfig = {
40
+ baseUrl: FIREWORKS_AI_API_BASE_URL,
41
+ makeBody,
42
+ makeHeaders,
43
+ makeUrl,
44
+ };
@@ -0,0 +1,43 @@
1
+ /**
2
+ * HF-Inference does not have a mapping since all models use IDs from the Hub.
3
+ *
4
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
5
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
6
+ *
7
+ * - If you work at HF and want to update this mapping, please use the model mapping API we provide on huggingface.co
8
+ * - If you're a community member and want to add a new supported HF model to HF, please open an issue on the present repo
9
+ * and we will tag HF team members.
10
+ *
11
+ * Thanks!
12
+ */
13
+ import { HF_ROUTER_URL } from "../config";
14
+ import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
15
+
16
+ const makeBody = (params: BodyParams): Record<string, unknown> => {
17
+ return {
18
+ ...params.args,
19
+ ...(params.chatCompletion ? { model: params.model } : undefined),
20
+ };
21
+ };
22
+
23
+ const makeHeaders = (params: HeaderParams): Record<string, string> => {
24
+ return { Authorization: `Bearer ${params.accessToken}` };
25
+ };
26
+
27
+ const makeUrl = (params: UrlParams): string => {
28
+ if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
29
+ /// when deployed on hf-inference, those two tasks are automatically compatible with one another.
30
+ return `${params.baseUrl}/pipeline/${params.task}/${params.model}`;
31
+ }
32
+ if (params.task === "text-generation" && params.chatCompletion) {
33
+ return `${params.baseUrl}/models/${params.model}/v1/chat/completions`;
34
+ }
35
+ return `${params.baseUrl}/models/${params.model}`;
36
+ };
37
+
38
+ export const HF_INFERENCE_CONFIG: ProviderConfig = {
39
+ baseUrl: `${HF_ROUTER_URL}/hf-inference`,
40
+ makeBody,
41
+ makeHeaders,
42
+ makeUrl,
43
+ };
@@ -1,5 +1,3 @@
1
- export const HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
2
-
3
1
  /**
4
2
  * See the registered mapping of HF model ID => Hyperbolic model ID here:
5
3
  *
@@ -16,3 +14,31 @@ export const HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
16
14
  *
17
15
  * Thanks!
18
16
  */
17
+ import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18
+
19
+ const HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
20
+
21
+ const makeBody = (params: BodyParams): Record<string, unknown> => {
22
+ return {
23
+ ...params.args,
24
+ ...(params.task === "text-to-image" ? { model_name: params.model } : { model: params.model }),
25
+ };
26
+ };
27
+
28
+ const makeHeaders = (params: HeaderParams): Record<string, string> => {
29
+ return { Authorization: `Bearer ${params.accessToken}` };
30
+ };
31
+
32
+ const makeUrl = (params: UrlParams): string => {
33
+ if (params.task === "text-to-image") {
34
+ return `${params.baseUrl}/v1/images/generations`;
35
+ }
36
+ return `${params.baseUrl}/v1/chat/completions`;
37
+ };
38
+
39
+ export const HYPERBOLIC_CONFIG: ProviderConfig = {
40
+ baseUrl: HYPERBOLIC_API_BASE_URL,
41
+ makeBody,
42
+ makeHeaders,
43
+ makeUrl,
44
+ };
@@ -1,5 +1,3 @@
1
- export const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
2
-
3
1
  /**
4
2
  * See the registered mapping of HF model ID => Nebius model ID here:
5
3
  *
@@ -16,3 +14,37 @@ export const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
16
14
  *
17
15
  * Thanks!
18
16
  */
17
+ import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18
+
19
+ const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
20
+
21
+ const makeBody = (params: BodyParams): Record<string, unknown> => {
22
+ return {
23
+ ...params.args,
24
+ model: params.model,
25
+ };
26
+ };
27
+
28
+ const makeHeaders = (params: HeaderParams): Record<string, string> => {
29
+ return { Authorization: `Bearer ${params.accessToken}` };
30
+ };
31
+
32
+ const makeUrl = (params: UrlParams): string => {
33
+ if (params.task === "text-to-image") {
34
+ return `${params.baseUrl}/v1/images/generations`;
35
+ }
36
+ if (params.task === "text-generation") {
37
+ if (params.chatCompletion) {
38
+ return `${params.baseUrl}/v1/chat/completions`;
39
+ }
40
+ return `${params.baseUrl}/v1/completions`;
41
+ }
42
+ return params.baseUrl;
43
+ };
44
+
45
+ export const NEBIUS_CONFIG: ProviderConfig = {
46
+ baseUrl: NEBIUS_API_BASE_URL,
47
+ makeBody,
48
+ makeHeaders,
49
+ makeUrl,
50
+ };
@@ -1,5 +1,3 @@
1
- export const NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
2
-
3
1
  /**
4
2
  * See the registered mapping of HF model ID => Novita model ID here:
5
3
  *
@@ -16,3 +14,34 @@ export const NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
16
14
  *
17
15
  * Thanks!
18
16
  */
17
+ import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18
+
19
+ const NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
20
+
21
+ const makeBody = (params: BodyParams): Record<string, unknown> => {
22
+ return {
23
+ ...params.args,
24
+ ...(params.chatCompletion ? { model: params.model } : undefined),
25
+ };
26
+ };
27
+
28
+ const makeHeaders = (params: HeaderParams): Record<string, string> => {
29
+ return { Authorization: `Bearer ${params.accessToken}` };
30
+ };
31
+
32
+ const makeUrl = (params: UrlParams): string => {
33
+ if (params.task === "text-generation") {
34
+ if (params.chatCompletion) {
35
+ return `${params.baseUrl}/chat/completions`;
36
+ }
37
+ return `${params.baseUrl}/completions`;
38
+ }
39
+ return params.baseUrl;
40
+ };
41
+
42
+ export const NOVITA_CONFIG: ProviderConfig = {
43
+ baseUrl: NOVITA_API_BASE_URL,
44
+ makeBody,
45
+ makeHeaders,
46
+ makeUrl,
47
+ };
@@ -1,5 +1,3 @@
1
- export const REPLICATE_API_BASE_URL = "https://api.replicate.com";
2
-
3
1
  /**
4
2
  * See the registered mapping of HF model ID => Replicate model ID here:
5
3
  *
@@ -16,3 +14,33 @@ export const REPLICATE_API_BASE_URL = "https://api.replicate.com";
16
14
  *
17
15
  * Thanks!
18
16
  */
17
+ import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18
+
19
+ export const REPLICATE_API_BASE_URL = "https://api.replicate.com";
20
+
21
+ const makeBody = (params: BodyParams): Record<string, unknown> => {
22
+ return {
23
+ input: params.args,
24
+ version: params.model.includes(":") ? params.model.split(":")[1] : undefined,
25
+ };
26
+ };
27
+
28
+ const makeHeaders = (params: HeaderParams): Record<string, string> => {
29
+ return { Authorization: `Bearer ${params.accessToken}` };
30
+ };
31
+
32
+ const makeUrl = (params: UrlParams): string => {
33
+ if (params.model.includes(":")) {
34
+ /// Versioned model
35
+ return `${params.baseUrl}/v1/predictions`;
36
+ }
37
+ /// Evergreen / Canonical model
38
+ return `${params.baseUrl}/v1/models/${params.model}/predictions`;
39
+ };
40
+
41
+ export const REPLICATE_CONFIG: ProviderConfig = {
42
+ baseUrl: REPLICATE_API_BASE_URL,
43
+ makeBody,
44
+ makeHeaders,
45
+ makeUrl,
46
+ };
@@ -1,5 +1,3 @@
1
- export const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
2
-
3
1
  /**
4
2
  * See the registered mapping of HF model ID => Sambanova model ID here:
5
3
  *
@@ -16,3 +14,31 @@ export const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
16
14
  *
17
15
  * Thanks!
18
16
  */
17
+ import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18
+
19
+ const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
20
+
21
+ const makeBody = (params: BodyParams): Record<string, unknown> => {
22
+ return {
23
+ ...params.args,
24
+ ...(params.chatCompletion ? { model: params.model } : undefined),
25
+ };
26
+ };
27
+
28
+ const makeHeaders = (params: HeaderParams): Record<string, string> => {
29
+ return { Authorization: `Bearer ${params.accessToken}` };
30
+ };
31
+
32
+ const makeUrl = (params: UrlParams): string => {
33
+ if (params.task === "text-generation" && params.chatCompletion) {
34
+ return `${params.baseUrl}/v1/chat/completions`;
35
+ }
36
+ return params.baseUrl;
37
+ };
38
+
39
+ export const SAMBANOVA_CONFIG: ProviderConfig = {
40
+ baseUrl: SAMBANOVA_API_BASE_URL,
41
+ makeBody,
42
+ makeHeaders,
43
+ makeUrl,
44
+ };