@huggingface/inference 3.3.6 → 3.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86) hide show
  1. package/README.md +2 -0
  2. package/dist/index.cjs +339 -174
  3. package/dist/index.js +339 -174
  4. package/dist/src/lib/getProviderModelId.d.ts +1 -1
  5. package/dist/src/lib/getProviderModelId.d.ts.map +1 -1
  6. package/dist/src/lib/makeRequestOptions.d.ts +2 -2
  7. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  8. package/dist/src/providers/black-forest-labs.d.ts +2 -1
  9. package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
  10. package/dist/src/providers/cohere.d.ts +19 -0
  11. package/dist/src/providers/cohere.d.ts.map +1 -0
  12. package/dist/src/providers/consts.d.ts.map +1 -1
  13. package/dist/src/providers/fal-ai.d.ts +2 -1
  14. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  15. package/dist/src/providers/fireworks-ai.d.ts +2 -1
  16. package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
  17. package/dist/src/providers/hf-inference.d.ts +3 -0
  18. package/dist/src/providers/hf-inference.d.ts.map +1 -0
  19. package/dist/src/providers/hyperbolic.d.ts +2 -1
  20. package/dist/src/providers/hyperbolic.d.ts.map +1 -1
  21. package/dist/src/providers/nebius.d.ts +2 -1
  22. package/dist/src/providers/nebius.d.ts.map +1 -1
  23. package/dist/src/providers/novita.d.ts +2 -1
  24. package/dist/src/providers/novita.d.ts.map +1 -1
  25. package/dist/src/providers/replicate.d.ts +3 -1
  26. package/dist/src/providers/replicate.d.ts.map +1 -1
  27. package/dist/src/providers/sambanova.d.ts +2 -1
  28. package/dist/src/providers/sambanova.d.ts.map +1 -1
  29. package/dist/src/providers/together.d.ts +2 -1
  30. package/dist/src/providers/together.d.ts.map +1 -1
  31. package/dist/src/tasks/custom/request.d.ts +2 -4
  32. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  33. package/dist/src/tasks/custom/streamingRequest.d.ts +2 -4
  34. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  35. package/dist/src/tasks/nlp/featureExtraction.d.ts +2 -9
  36. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  37. package/dist/src/types.d.ts +25 -4
  38. package/dist/src/types.d.ts.map +1 -1
  39. package/package.json +2 -2
  40. package/src/lib/getProviderModelId.ts +4 -4
  41. package/src/lib/makeRequestOptions.ts +74 -186
  42. package/src/providers/black-forest-labs.ts +26 -2
  43. package/src/providers/cohere.ts +42 -0
  44. package/src/providers/consts.ts +2 -1
  45. package/src/providers/fal-ai.ts +24 -2
  46. package/src/providers/fireworks-ai.ts +28 -2
  47. package/src/providers/hf-inference.ts +43 -0
  48. package/src/providers/hyperbolic.ts +28 -2
  49. package/src/providers/nebius.ts +34 -2
  50. package/src/providers/novita.ts +31 -2
  51. package/src/providers/replicate.ts +30 -2
  52. package/src/providers/sambanova.ts +28 -2
  53. package/src/providers/together.ts +34 -2
  54. package/src/tasks/audio/audioClassification.ts +1 -1
  55. package/src/tasks/audio/audioToAudio.ts +1 -1
  56. package/src/tasks/audio/automaticSpeechRecognition.ts +1 -1
  57. package/src/tasks/audio/textToSpeech.ts +1 -1
  58. package/src/tasks/custom/request.ts +2 -4
  59. package/src/tasks/custom/streamingRequest.ts +2 -4
  60. package/src/tasks/cv/imageClassification.ts +1 -1
  61. package/src/tasks/cv/imageSegmentation.ts +1 -1
  62. package/src/tasks/cv/imageToImage.ts +1 -1
  63. package/src/tasks/cv/imageToText.ts +1 -1
  64. package/src/tasks/cv/objectDetection.ts +1 -1
  65. package/src/tasks/cv/textToImage.ts +1 -1
  66. package/src/tasks/cv/textToVideo.ts +1 -1
  67. package/src/tasks/cv/zeroShotImageClassification.ts +1 -1
  68. package/src/tasks/multimodal/documentQuestionAnswering.ts +1 -1
  69. package/src/tasks/multimodal/visualQuestionAnswering.ts +1 -1
  70. package/src/tasks/nlp/chatCompletion.ts +1 -1
  71. package/src/tasks/nlp/chatCompletionStream.ts +1 -1
  72. package/src/tasks/nlp/featureExtraction.ts +3 -10
  73. package/src/tasks/nlp/fillMask.ts +1 -1
  74. package/src/tasks/nlp/questionAnswering.ts +1 -1
  75. package/src/tasks/nlp/sentenceSimilarity.ts +1 -1
  76. package/src/tasks/nlp/summarization.ts +1 -1
  77. package/src/tasks/nlp/tableQuestionAnswering.ts +1 -1
  78. package/src/tasks/nlp/textClassification.ts +1 -1
  79. package/src/tasks/nlp/textGeneration.ts +3 -3
  80. package/src/tasks/nlp/textGenerationStream.ts +1 -1
  81. package/src/tasks/nlp/tokenClassification.ts +1 -1
  82. package/src/tasks/nlp/translation.ts +1 -1
  83. package/src/tasks/nlp/zeroShotClassification.ts +1 -1
  84. package/src/tasks/tabular/tabularClassification.ts +1 -1
  85. package/src/tasks/tabular/tabularRegression.ts +1 -1
  86. package/src/types.ts +29 -2
@@ -1,15 +1,16 @@
1
1
  import { HF_HUB_URL, HF_ROUTER_URL } from "../config";
2
- import { FAL_AI_API_BASE_URL } from "../providers/fal-ai";
3
- import { NEBIUS_API_BASE_URL } from "../providers/nebius";
4
- import { REPLICATE_API_BASE_URL } from "../providers/replicate";
5
- import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
6
- import { TOGETHER_API_BASE_URL } from "../providers/together";
7
- import { NOVITA_API_BASE_URL } from "../providers/novita";
8
- import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
9
- import { HYPERBOLIC_API_BASE_URL } from "../providers/hyperbolic";
10
- import { BLACKFORESTLABS_AI_API_BASE_URL } from "../providers/black-forest-labs";
11
- import type { InferenceProvider } from "../types";
12
- import type { InferenceTask, Options, RequestArgs } from "../types";
2
+ import { BLACK_FOREST_LABS_CONFIG } from "../providers/black-forest-labs";
3
+ import { COHERE_CONFIG } from "../providers/cohere";
4
+ import { FAL_AI_CONFIG } from "../providers/fal-ai";
5
+ import { FIREWORKS_AI_CONFIG } from "../providers/fireworks-ai";
6
+ import { HF_INFERENCE_CONFIG } from "../providers/hf-inference";
7
+ import { HYPERBOLIC_CONFIG } from "../providers/hyperbolic";
8
+ import { NEBIUS_CONFIG } from "../providers/nebius";
9
+ import { NOVITA_CONFIG } from "../providers/novita";
10
+ import { REPLICATE_CONFIG } from "../providers/replicate";
11
+ import { SAMBANOVA_CONFIG } from "../providers/sambanova";
12
+ import { TOGETHER_CONFIG } from "../providers/together";
13
+ import type { InferenceProvider, InferenceTask, Options, ProviderConfig, RequestArgs } from "../types";
13
14
  import { isUrl } from "./isUrl";
14
15
  import { version as packageVersion, name as packageName } from "../../package.json";
15
16
  import { getProviderModelId } from "./getProviderModelId";
@@ -22,6 +23,23 @@ const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
22
23
  */
23
24
  let tasks: Record<string, { models: { id: string }[] }> | null = null;
24
25
 
26
+ /**
27
+ * Config to define how to serialize requests for each provider
28
+ */
29
+ const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
30
+ "black-forest-labs": BLACK_FOREST_LABS_CONFIG,
31
+ cohere: COHERE_CONFIG,
32
+ "fal-ai": FAL_AI_CONFIG,
33
+ "fireworks-ai": FIREWORKS_AI_CONFIG,
34
+ "hf-inference": HF_INFERENCE_CONFIG,
35
+ hyperbolic: HYPERBOLIC_CONFIG,
36
+ nebius: NEBIUS_CONFIG,
37
+ novita: NOVITA_CONFIG,
38
+ replicate: REPLICATE_CONFIG,
39
+ sambanova: SAMBANOVA_CONFIG,
40
+ together: TOGETHER_CONFIG,
41
+ };
42
+
25
43
  /**
26
44
  * Helper that prepares request arguments
27
45
  */
@@ -31,16 +49,16 @@ export async function makeRequestOptions(
31
49
  stream?: boolean;
32
50
  },
33
51
  options?: Options & {
34
- /** To load default model if needed */
35
- taskHint?: InferenceTask;
52
+ /** In most cases (unless we pass an endpointUrl) we know the task */
53
+ task?: InferenceTask;
36
54
  chatCompletion?: boolean;
37
55
  }
38
56
  ): Promise<{ url: string; info: RequestInit }> {
39
57
  const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
40
- let otherArgs = remainingArgs;
41
58
  const provider = maybeProvider ?? "hf-inference";
59
+ const providerConfig = providerConfigs[provider];
42
60
 
43
- const { includeCredentials, taskHint, chatCompletion } = options ?? {};
61
+ const { includeCredentials, task, chatCompletion, signal } = options ?? {};
44
62
 
45
63
  if (endpointUrl && provider !== "hf-inference") {
46
64
  throw new Error(`Cannot use endpointUrl with a third-party provider.`);
@@ -48,13 +66,16 @@ export async function makeRequestOptions(
48
66
  if (maybeModel && isUrl(maybeModel)) {
49
67
  throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
50
68
  }
51
- if (!maybeModel && !taskHint) {
69
+ if (!maybeModel && !task) {
52
70
  throw new Error("No model provided, and no task has been specified.");
53
71
  }
72
+ if (!providerConfig) {
73
+ throw new Error(`No provider config found for provider ${provider}`);
74
+ }
54
75
  // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
55
- const hfModel = maybeModel ?? (await loadDefaultModel(taskHint!));
76
+ const hfModel = maybeModel ?? (await loadDefaultModel(task!));
56
77
  const model = await getProviderModelId({ model: hfModel, provider }, args, {
57
- taskHint,
78
+ task,
58
79
  chatCompletion,
59
80
  fetch: options?.fetch,
60
81
  });
@@ -68,44 +89,52 @@ export async function makeRequestOptions(
68
89
  ? "credentials-include"
69
90
  : "none";
70
91
 
92
+ // Make URL
71
93
  const url = endpointUrl
72
94
  ? chatCompletion
73
95
  ? endpointUrl + `/v1/chat/completions`
74
96
  : endpointUrl
75
- : makeUrl({
76
- authMethod,
77
- chatCompletion: chatCompletion ?? false,
97
+ : providerConfig.makeUrl({
98
+ baseUrl:
99
+ authMethod !== "provider-key"
100
+ ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider)
101
+ : providerConfig.baseUrl,
78
102
  model,
79
- provider: provider ?? "hf-inference",
80
- taskHint,
103
+ chatCompletion,
104
+ task,
81
105
  });
82
106
 
83
- const headers: Record<string, string> = {};
84
- if (accessToken) {
85
- if (provider === "fal-ai" && authMethod === "provider-key") {
86
- headers["Authorization"] = `Key ${accessToken}`;
87
- } else if (provider === "black-forest-labs" && authMethod === "provider-key") {
88
- headers["X-Key"] = accessToken;
89
- } else {
90
- headers["Authorization"] = `Bearer ${accessToken}`;
91
- }
92
- }
93
-
94
- // e.g. @huggingface/inference/3.1.3
95
- const ownUserAgent = `${packageName}/${packageVersion}`;
96
- headers["User-Agent"] = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : undefined]
97
- .filter((x) => x !== undefined)
98
- .join(" ");
99
-
107
+ // Make headers
100
108
  const binary = "data" in args && !!args.data;
109
+ const headers = providerConfig.makeHeaders({
110
+ accessToken,
111
+ authMethod,
112
+ });
101
113
 
114
+ // Add content-type to headers
102
115
  if (!binary) {
103
116
  headers["Content-Type"] = "application/json";
104
117
  }
105
118
 
106
- if (provider === "replicate") {
107
- headers["Prefer"] = "wait";
108
- }
119
+ // Add user-agent to headers
120
+ // e.g. @huggingface/inference/3.1.3
121
+ const ownUserAgent = `${packageName}/${packageVersion}`;
122
+ const userAgent = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : undefined]
123
+ .filter((x) => x !== undefined)
124
+ .join(" ");
125
+ headers["User-Agent"] = userAgent;
126
+
127
+ // Make body
128
+ const body = binary
129
+ ? args.data
130
+ : JSON.stringify(
131
+ providerConfig.makeBody({
132
+ args: remainingArgs as Record<string, unknown>,
133
+ model,
134
+ task,
135
+ chatCompletion,
136
+ })
137
+ );
109
138
 
110
139
  /**
111
140
  * For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
@@ -117,158 +146,17 @@ export async function makeRequestOptions(
117
146
  credentials = "include";
118
147
  }
119
148
 
120
- /**
121
- * Replicate models wrap all inputs inside { input: ... }
122
- * Versioned Replicate models in the format `owner/model:version` expect the version in the body
123
- */
124
- if (provider === "replicate") {
125
- const version = model.includes(":") ? model.split(":")[1] : undefined;
126
- (otherArgs as unknown) = { input: otherArgs, version };
127
- }
128
-
129
149
  const info: RequestInit = {
130
150
  headers,
131
151
  method: "POST",
132
- body: binary
133
- ? args.data
134
- : JSON.stringify({
135
- ...otherArgs,
136
- ...(taskHint === "text-to-image" && provider === "hyperbolic"
137
- ? { model_name: model }
138
- : chatCompletion || provider === "together" || provider === "nebius" || provider === "hyperbolic"
139
- ? { model }
140
- : undefined),
141
- }),
152
+ body,
142
153
  ...(credentials ? { credentials } : undefined),
143
- signal: options?.signal,
154
+ signal,
144
155
  };
145
156
 
146
157
  return { url, info };
147
158
  }
148
159
 
149
- function makeUrl(params: {
150
- authMethod: "none" | "hf-token" | "credentials-include" | "provider-key";
151
- chatCompletion: boolean;
152
- model: string;
153
- provider: InferenceProvider;
154
- taskHint: InferenceTask | undefined;
155
- }): string {
156
- if (params.authMethod === "none" && params.provider !== "hf-inference") {
157
- throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
158
- }
159
-
160
- const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
161
- switch (params.provider) {
162
- case "black-forest-labs": {
163
- const baseUrl = shouldProxy
164
- ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
165
- : BLACKFORESTLABS_AI_API_BASE_URL;
166
- return `${baseUrl}/${params.model}`;
167
- }
168
- case "fal-ai": {
169
- const baseUrl = shouldProxy
170
- ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
171
- : FAL_AI_API_BASE_URL;
172
- return `${baseUrl}/${params.model}`;
173
- }
174
- case "nebius": {
175
- const baseUrl = shouldProxy
176
- ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
177
- : NEBIUS_API_BASE_URL;
178
-
179
- if (params.taskHint === "text-to-image") {
180
- return `${baseUrl}/v1/images/generations`;
181
- }
182
- if (params.taskHint === "text-generation") {
183
- if (params.chatCompletion) {
184
- return `${baseUrl}/v1/chat/completions`;
185
- }
186
- return `${baseUrl}/v1/completions`;
187
- }
188
- return baseUrl;
189
- }
190
- case "replicate": {
191
- const baseUrl = shouldProxy
192
- ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
193
- : REPLICATE_API_BASE_URL;
194
- if (params.model.includes(":")) {
195
- /// Versioned model
196
- return `${baseUrl}/v1/predictions`;
197
- }
198
- /// Evergreen / Canonical model
199
- return `${baseUrl}/v1/models/${params.model}/predictions`;
200
- }
201
- case "sambanova": {
202
- const baseUrl = shouldProxy
203
- ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
204
- : SAMBANOVA_API_BASE_URL;
205
- /// Sambanova API matches OpenAI-like APIs: model is defined in the request body
206
- if (params.taskHint === "text-generation" && params.chatCompletion) {
207
- return `${baseUrl}/v1/chat/completions`;
208
- }
209
- return baseUrl;
210
- }
211
- case "together": {
212
- const baseUrl = shouldProxy
213
- ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
214
- : TOGETHER_API_BASE_URL;
215
- /// Together API matches OpenAI-like APIs: model is defined in the request body
216
- if (params.taskHint === "text-to-image") {
217
- return `${baseUrl}/v1/images/generations`;
218
- }
219
- if (params.taskHint === "text-generation") {
220
- if (params.chatCompletion) {
221
- return `${baseUrl}/v1/chat/completions`;
222
- }
223
- return `${baseUrl}/v1/completions`;
224
- }
225
- return baseUrl;
226
- }
227
-
228
- case "fireworks-ai": {
229
- const baseUrl = shouldProxy
230
- ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
231
- : FIREWORKS_AI_API_BASE_URL;
232
- if (params.taskHint === "text-generation" && params.chatCompletion) {
233
- return `${baseUrl}/v1/chat/completions`;
234
- }
235
- return baseUrl;
236
- }
237
- case "hyperbolic": {
238
- const baseUrl = shouldProxy
239
- ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
240
- : HYPERBOLIC_API_BASE_URL;
241
-
242
- if (params.taskHint === "text-to-image") {
243
- return `${baseUrl}/v1/images/generations`;
244
- }
245
- return `${baseUrl}/v1/chat/completions`;
246
- }
247
- case "novita": {
248
- const baseUrl = shouldProxy
249
- ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
250
- : NOVITA_API_BASE_URL;
251
- if (params.taskHint === "text-generation") {
252
- if (params.chatCompletion) {
253
- return `${baseUrl}/chat/completions`;
254
- }
255
- return `${baseUrl}/completions`;
256
- }
257
- return baseUrl;
258
- }
259
- default: {
260
- const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
261
- if (params.taskHint && ["feature-extraction", "sentence-similarity"].includes(params.taskHint)) {
262
- /// when deployed on hf-inference, those two tasks are automatically compatible with one another.
263
- return `${baseUrl}/pipeline/${params.taskHint}/${params.model}`;
264
- }
265
- if (params.taskHint === "text-generation" && params.chatCompletion) {
266
- return `${baseUrl}/models/${params.model}/v1/chat/completions`;
267
- }
268
- return `${baseUrl}/models/${params.model}`;
269
- }
270
- }
271
- }
272
160
  async function loadDefaultModel(task: InferenceTask): Promise<string> {
273
161
  if (!tasks) {
274
162
  tasks = await loadTaskInfo();
@@ -1,5 +1,3 @@
1
- export const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
2
-
3
1
  /**
4
2
  * See the registered mapping of HF model ID => Black Forest Labs model ID here:
5
3
  *
@@ -16,3 +14,29 @@ export const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
16
14
  *
17
15
  * Thanks!
18
16
  */
17
+ import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18
+
19
+ const BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
20
+
21
+ const makeBody = (params: BodyParams): Record<string, unknown> => {
22
+ return params.args;
23
+ };
24
+
25
+ const makeHeaders = (params: HeaderParams): Record<string, string> => {
26
+ if (params.authMethod === "provider-key") {
27
+ return { "X-Key": `${params.accessToken}` };
28
+ } else {
29
+ return { Authorization: `Bearer ${params.accessToken}` };
30
+ }
31
+ };
32
+
33
+ const makeUrl = (params: UrlParams): string => {
34
+ return `${params.baseUrl}/${params.model}`;
35
+ };
36
+
37
+ export const BLACK_FOREST_LABS_CONFIG: ProviderConfig = {
38
+ baseUrl: BLACK_FOREST_LABS_AI_API_BASE_URL,
39
+ makeBody,
40
+ makeHeaders,
41
+ makeUrl,
42
+ };
@@ -0,0 +1,42 @@
1
+ /**
2
+ * See the registered mapping of HF model ID => Cohere model ID here:
3
+ *
4
+ * https://huggingface.co/api/partners/cohere/models
5
+ *
6
+ * This is a publicly available mapping.
7
+ *
8
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
9
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
10
+ *
11
+ * - If you work at Cohere and want to update this mapping, please use the model mapping API we provide on huggingface.co
12
+ * - If you're a community member and want to add a new supported HF model to Cohere, please open an issue on the present repo
13
+ * and we will tag Cohere team members.
14
+ *
15
+ * Thanks!
16
+ */
17
+ import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18
+
19
+ const COHERE_API_BASE_URL = "https://api.cohere.com";
20
+
21
+
22
+ const makeBody = (params: BodyParams): Record<string, unknown> => {
23
+ return {
24
+ ...params.args,
25
+ model: params.model,
26
+ };
27
+ };
28
+
29
+ const makeHeaders = (params: HeaderParams): Record<string, string> => {
30
+ return { Authorization: `Bearer ${params.accessToken}` };
31
+ };
32
+
33
+ const makeUrl = (params: UrlParams): string => {
34
+ return `${params.baseUrl}/compatibility/v1/chat/completions`;
35
+ };
36
+
37
+ export const COHERE_CONFIG: ProviderConfig = {
38
+ baseUrl: COHERE_API_BASE_URL,
39
+ makeBody,
40
+ makeHeaders,
41
+ makeUrl,
42
+ };
@@ -17,13 +17,14 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
17
17
  * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
18
18
  */
19
19
  "black-forest-labs": {},
20
+ cohere: {},
20
21
  "fal-ai": {},
21
22
  "fireworks-ai": {},
22
23
  "hf-inference": {},
23
24
  hyperbolic: {},
24
25
  nebius: {},
26
+ novita: {},
25
27
  replicate: {},
26
28
  sambanova: {},
27
29
  together: {},
28
- novita: {},
29
30
  };
@@ -1,5 +1,3 @@
1
- export const FAL_AI_API_BASE_URL = "https://fal.run";
2
-
3
1
  /**
4
2
  * See the registered mapping of HF model ID => Fal model ID here:
5
3
  *
@@ -16,3 +14,27 @@ export const FAL_AI_API_BASE_URL = "https://fal.run";
16
14
  *
17
15
  * Thanks!
18
16
  */
17
+ import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18
+
19
+ const FAL_AI_API_BASE_URL = "https://fal.run";
20
+
21
+ const makeBody = (params: BodyParams): Record<string, unknown> => {
22
+ return params.args;
23
+ };
24
+
25
+ const makeHeaders = (params: HeaderParams): Record<string, string> => {
26
+ return {
27
+ Authorization: params.authMethod === "provider-key" ? `Key ${params.accessToken}` : `Bearer ${params.accessToken}`,
28
+ };
29
+ };
30
+
31
+ const makeUrl = (params: UrlParams): string => {
32
+ return `${params.baseUrl}/${params.model}`;
33
+ };
34
+
35
+ export const FAL_AI_CONFIG: ProviderConfig = {
36
+ baseUrl: FAL_AI_API_BASE_URL,
37
+ makeBody,
38
+ makeHeaders,
39
+ makeUrl,
40
+ };
@@ -1,5 +1,3 @@
1
- export const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
2
-
3
1
  /**
4
2
  * See the registered mapping of HF model ID => Fireworks model ID here:
5
3
  *
@@ -16,3 +14,31 @@ export const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
16
14
  *
17
15
  * Thanks!
18
16
  */
17
+ import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18
+
19
+ const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
20
+
21
+ const makeBody = (params: BodyParams): Record<string, unknown> => {
22
+ return {
23
+ ...params.args,
24
+ ...(params.chatCompletion ? { model: params.model } : undefined),
25
+ };
26
+ };
27
+
28
+ const makeHeaders = (params: HeaderParams): Record<string, string> => {
29
+ return { Authorization: `Bearer ${params.accessToken}` };
30
+ };
31
+
32
+ const makeUrl = (params: UrlParams): string => {
33
+ if (params.task === "text-generation" && params.chatCompletion) {
34
+ return `${params.baseUrl}/v1/chat/completions`;
35
+ }
36
+ return params.baseUrl;
37
+ };
38
+
39
+ export const FIREWORKS_AI_CONFIG: ProviderConfig = {
40
+ baseUrl: FIREWORKS_AI_API_BASE_URL,
41
+ makeBody,
42
+ makeHeaders,
43
+ makeUrl,
44
+ };
@@ -0,0 +1,43 @@
1
+ /**
2
+ * HF-Inference does not have a mapping since all models use IDs from the Hub.
3
+ *
4
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
5
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
6
+ *
7
+ * - If you work at HF and want to update this mapping, please use the model mapping API we provide on huggingface.co
8
+ * - If you're a community member and want to add a new supported HF model to HF, please open an issue on the present repo
9
+ * and we will tag HF team members.
10
+ *
11
+ * Thanks!
12
+ */
13
+ import { HF_ROUTER_URL } from "../config";
14
+ import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
15
+
16
+ const makeBody = (params: BodyParams): Record<string, unknown> => {
17
+ return {
18
+ ...params.args,
19
+ ...(params.chatCompletion ? { model: params.model } : undefined),
20
+ };
21
+ };
22
+
23
+ const makeHeaders = (params: HeaderParams): Record<string, string> => {
24
+ return { Authorization: `Bearer ${params.accessToken}` };
25
+ };
26
+
27
+ const makeUrl = (params: UrlParams): string => {
28
+ if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
29
+ /// when deployed on hf-inference, those two tasks are automatically compatible with one another.
30
+ return `${params.baseUrl}/pipeline/${params.task}/${params.model}`;
31
+ }
32
+ if (params.task === "text-generation" && params.chatCompletion) {
33
+ return `${params.baseUrl}/models/${params.model}/v1/chat/completions`;
34
+ }
35
+ return `${params.baseUrl}/models/${params.model}`;
36
+ };
37
+
38
+ export const HF_INFERENCE_CONFIG: ProviderConfig = {
39
+ baseUrl: `${HF_ROUTER_URL}/hf-inference`,
40
+ makeBody,
41
+ makeHeaders,
42
+ makeUrl,
43
+ };
@@ -1,5 +1,3 @@
1
- export const HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
2
-
3
1
  /**
4
2
  * See the registered mapping of HF model ID => Hyperbolic model ID here:
5
3
  *
@@ -16,3 +14,31 @@ export const HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
16
14
  *
17
15
  * Thanks!
18
16
  */
17
+ import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18
+
19
+ const HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
20
+
21
+ const makeBody = (params: BodyParams): Record<string, unknown> => {
22
+ return {
23
+ ...params.args,
24
+ ...(params.task === "text-to-image" ? { model_name: params.model } : { model: params.model }),
25
+ };
26
+ };
27
+
28
+ const makeHeaders = (params: HeaderParams): Record<string, string> => {
29
+ return { Authorization: `Bearer ${params.accessToken}` };
30
+ };
31
+
32
+ const makeUrl = (params: UrlParams): string => {
33
+ if (params.task === "text-to-image") {
34
+ return `${params.baseUrl}/v1/images/generations`;
35
+ }
36
+ return `${params.baseUrl}/v1/chat/completions`;
37
+ };
38
+
39
+ export const HYPERBOLIC_CONFIG: ProviderConfig = {
40
+ baseUrl: HYPERBOLIC_API_BASE_URL,
41
+ makeBody,
42
+ makeHeaders,
43
+ makeUrl,
44
+ };
@@ -1,5 +1,3 @@
1
- export const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
2
-
3
1
  /**
4
2
  * See the registered mapping of HF model ID => Nebius model ID here:
5
3
  *
@@ -16,3 +14,37 @@ export const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
16
14
  *
17
15
  * Thanks!
18
16
  */
17
+ import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18
+
19
+ const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
20
+
21
+ const makeBody = (params: BodyParams): Record<string, unknown> => {
22
+ return {
23
+ ...params.args,
24
+ model: params.model,
25
+ };
26
+ };
27
+
28
+ const makeHeaders = (params: HeaderParams): Record<string, string> => {
29
+ return { Authorization: `Bearer ${params.accessToken}` };
30
+ };
31
+
32
+ const makeUrl = (params: UrlParams): string => {
33
+ if (params.task === "text-to-image") {
34
+ return `${params.baseUrl}/v1/images/generations`;
35
+ }
36
+ if (params.task === "text-generation") {
37
+ if (params.chatCompletion) {
38
+ return `${params.baseUrl}/v1/chat/completions`;
39
+ }
40
+ return `${params.baseUrl}/v1/completions`;
41
+ }
42
+ return params.baseUrl;
43
+ };
44
+
45
+ export const NEBIUS_CONFIG: ProviderConfig = {
46
+ baseUrl: NEBIUS_API_BASE_URL,
47
+ makeBody,
48
+ makeHeaders,
49
+ makeUrl,
50
+ };
@@ -1,5 +1,3 @@
1
- export const NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
2
-
3
1
  /**
4
2
  * See the registered mapping of HF model ID => Novita model ID here:
5
3
  *
@@ -16,3 +14,34 @@ export const NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
16
14
  *
17
15
  * Thanks!
18
16
  */
17
+ import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18
+
19
+ const NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
20
+
21
+ const makeBody = (params: BodyParams): Record<string, unknown> => {
22
+ return {
23
+ ...params.args,
24
+ ...(params.chatCompletion ? { model: params.model } : undefined),
25
+ };
26
+ };
27
+
28
+ const makeHeaders = (params: HeaderParams): Record<string, string> => {
29
+ return { Authorization: `Bearer ${params.accessToken}` };
30
+ };
31
+
32
+ const makeUrl = (params: UrlParams): string => {
33
+ if (params.task === "text-generation") {
34
+ if (params.chatCompletion) {
35
+ return `${params.baseUrl}/chat/completions`;
36
+ }
37
+ return `${params.baseUrl}/completions`;
38
+ }
39
+ return params.baseUrl;
40
+ };
41
+
42
+ export const NOVITA_CONFIG: ProviderConfig = {
43
+ baseUrl: NOVITA_API_BASE_URL,
44
+ makeBody,
45
+ makeHeaders,
46
+ makeUrl,
47
+ };