@huggingface/inference 3.3.6 → 3.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +2 -0
- package/dist/index.cjs +339 -174
- package/dist/index.js +339 -174
- package/dist/src/lib/getProviderModelId.d.ts +1 -1
- package/dist/src/lib/getProviderModelId.d.ts.map +1 -1
- package/dist/src/lib/makeRequestOptions.d.ts +2 -2
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/black-forest-labs.d.ts +2 -1
- package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
- package/dist/src/providers/cohere.d.ts +19 -0
- package/dist/src/providers/cohere.d.ts.map +1 -0
- package/dist/src/providers/consts.d.ts.map +1 -1
- package/dist/src/providers/fal-ai.d.ts +2 -1
- package/dist/src/providers/fal-ai.d.ts.map +1 -1
- package/dist/src/providers/fireworks-ai.d.ts +2 -1
- package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
- package/dist/src/providers/hf-inference.d.ts +3 -0
- package/dist/src/providers/hf-inference.d.ts.map +1 -0
- package/dist/src/providers/hyperbolic.d.ts +2 -1
- package/dist/src/providers/hyperbolic.d.ts.map +1 -1
- package/dist/src/providers/nebius.d.ts +2 -1
- package/dist/src/providers/nebius.d.ts.map +1 -1
- package/dist/src/providers/novita.d.ts +2 -1
- package/dist/src/providers/novita.d.ts.map +1 -1
- package/dist/src/providers/replicate.d.ts +3 -1
- package/dist/src/providers/replicate.d.ts.map +1 -1
- package/dist/src/providers/sambanova.d.ts +2 -1
- package/dist/src/providers/sambanova.d.ts.map +1 -1
- package/dist/src/providers/together.d.ts +2 -1
- package/dist/src/providers/together.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts +2 -4
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts +2 -4
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/nlp/featureExtraction.d.ts +2 -9
- package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
- package/dist/src/types.d.ts +25 -4
- package/dist/src/types.d.ts.map +1 -1
- package/package.json +2 -2
- package/src/lib/getProviderModelId.ts +4 -4
- package/src/lib/makeRequestOptions.ts +74 -186
- package/src/providers/black-forest-labs.ts +26 -2
- package/src/providers/cohere.ts +42 -0
- package/src/providers/consts.ts +2 -1
- package/src/providers/fal-ai.ts +24 -2
- package/src/providers/fireworks-ai.ts +28 -2
- package/src/providers/hf-inference.ts +43 -0
- package/src/providers/hyperbolic.ts +28 -2
- package/src/providers/nebius.ts +34 -2
- package/src/providers/novita.ts +31 -2
- package/src/providers/replicate.ts +30 -2
- package/src/providers/sambanova.ts +28 -2
- package/src/providers/together.ts +34 -2
- package/src/tasks/audio/audioClassification.ts +1 -1
- package/src/tasks/audio/audioToAudio.ts +1 -1
- package/src/tasks/audio/automaticSpeechRecognition.ts +1 -1
- package/src/tasks/audio/textToSpeech.ts +1 -1
- package/src/tasks/custom/request.ts +2 -4
- package/src/tasks/custom/streamingRequest.ts +2 -4
- package/src/tasks/cv/imageClassification.ts +1 -1
- package/src/tasks/cv/imageSegmentation.ts +1 -1
- package/src/tasks/cv/imageToImage.ts +1 -1
- package/src/tasks/cv/imageToText.ts +1 -1
- package/src/tasks/cv/objectDetection.ts +1 -1
- package/src/tasks/cv/textToImage.ts +1 -1
- package/src/tasks/cv/textToVideo.ts +1 -1
- package/src/tasks/cv/zeroShotImageClassification.ts +1 -1
- package/src/tasks/multimodal/documentQuestionAnswering.ts +1 -1
- package/src/tasks/multimodal/visualQuestionAnswering.ts +1 -1
- package/src/tasks/nlp/chatCompletion.ts +1 -1
- package/src/tasks/nlp/chatCompletionStream.ts +1 -1
- package/src/tasks/nlp/featureExtraction.ts +3 -10
- package/src/tasks/nlp/fillMask.ts +1 -1
- package/src/tasks/nlp/questionAnswering.ts +1 -1
- package/src/tasks/nlp/sentenceSimilarity.ts +1 -1
- package/src/tasks/nlp/summarization.ts +1 -1
- package/src/tasks/nlp/tableQuestionAnswering.ts +1 -1
- package/src/tasks/nlp/textClassification.ts +1 -1
- package/src/tasks/nlp/textGeneration.ts +3 -3
- package/src/tasks/nlp/textGenerationStream.ts +1 -1
- package/src/tasks/nlp/tokenClassification.ts +1 -1
- package/src/tasks/nlp/translation.ts +1 -1
- package/src/tasks/nlp/zeroShotClassification.ts +1 -1
- package/src/tasks/tabular/tabularClassification.ts +1 -1
- package/src/tasks/tabular/tabularRegression.ts +1 -1
- package/src/types.ts +29 -2
|
@@ -1,15 +1,16 @@
|
|
|
1
1
|
import { HF_HUB_URL, HF_ROUTER_URL } from "../config";
|
|
2
|
-
import {
|
|
3
|
-
import {
|
|
4
|
-
import {
|
|
5
|
-
import {
|
|
6
|
-
import {
|
|
7
|
-
import {
|
|
8
|
-
import {
|
|
9
|
-
import {
|
|
10
|
-
import {
|
|
11
|
-
import
|
|
12
|
-
import
|
|
2
|
+
import { BLACK_FOREST_LABS_CONFIG } from "../providers/black-forest-labs";
|
|
3
|
+
import { COHERE_CONFIG } from "../providers/cohere";
|
|
4
|
+
import { FAL_AI_CONFIG } from "../providers/fal-ai";
|
|
5
|
+
import { FIREWORKS_AI_CONFIG } from "../providers/fireworks-ai";
|
|
6
|
+
import { HF_INFERENCE_CONFIG } from "../providers/hf-inference";
|
|
7
|
+
import { HYPERBOLIC_CONFIG } from "../providers/hyperbolic";
|
|
8
|
+
import { NEBIUS_CONFIG } from "../providers/nebius";
|
|
9
|
+
import { NOVITA_CONFIG } from "../providers/novita";
|
|
10
|
+
import { REPLICATE_CONFIG } from "../providers/replicate";
|
|
11
|
+
import { SAMBANOVA_CONFIG } from "../providers/sambanova";
|
|
12
|
+
import { TOGETHER_CONFIG } from "../providers/together";
|
|
13
|
+
import type { InferenceProvider, InferenceTask, Options, ProviderConfig, RequestArgs } from "../types";
|
|
13
14
|
import { isUrl } from "./isUrl";
|
|
14
15
|
import { version as packageVersion, name as packageName } from "../../package.json";
|
|
15
16
|
import { getProviderModelId } from "./getProviderModelId";
|
|
@@ -22,6 +23,23 @@ const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
|
|
|
22
23
|
*/
|
|
23
24
|
let tasks: Record<string, { models: { id: string }[] }> | null = null;
|
|
24
25
|
|
|
26
|
+
/**
 * Lookup table from provider name to the config object that knows how to
 * serialize a request (URL, headers, body) for that provider.
 */
const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
	"black-forest-labs": BLACK_FOREST_LABS_CONFIG,
	"cohere": COHERE_CONFIG,
	"fal-ai": FAL_AI_CONFIG,
	"fireworks-ai": FIREWORKS_AI_CONFIG,
	"hf-inference": HF_INFERENCE_CONFIG,
	"hyperbolic": HYPERBOLIC_CONFIG,
	"nebius": NEBIUS_CONFIG,
	"novita": NOVITA_CONFIG,
	"replicate": REPLICATE_CONFIG,
	"sambanova": SAMBANOVA_CONFIG,
	"together": TOGETHER_CONFIG,
};
|
|
42
|
+
|
|
25
43
|
/**
|
|
26
44
|
* Helper that prepares request arguments
|
|
27
45
|
*/
|
|
@@ -31,16 +49,16 @@ export async function makeRequestOptions(
|
|
|
31
49
|
stream?: boolean;
|
|
32
50
|
},
|
|
33
51
|
options?: Options & {
|
|
34
|
-
/**
|
|
35
|
-
|
|
52
|
+
/** In most cases (unless we pass an endpointUrl) we know the task */
|
|
53
|
+
task?: InferenceTask;
|
|
36
54
|
chatCompletion?: boolean;
|
|
37
55
|
}
|
|
38
56
|
): Promise<{ url: string; info: RequestInit }> {
|
|
39
57
|
const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
|
|
40
|
-
let otherArgs = remainingArgs;
|
|
41
58
|
const provider = maybeProvider ?? "hf-inference";
|
|
59
|
+
const providerConfig = providerConfigs[provider];
|
|
42
60
|
|
|
43
|
-
const { includeCredentials,
|
|
61
|
+
const { includeCredentials, task, chatCompletion, signal } = options ?? {};
|
|
44
62
|
|
|
45
63
|
if (endpointUrl && provider !== "hf-inference") {
|
|
46
64
|
throw new Error(`Cannot use endpointUrl with a third-party provider.`);
|
|
@@ -48,13 +66,16 @@ export async function makeRequestOptions(
|
|
|
48
66
|
if (maybeModel && isUrl(maybeModel)) {
|
|
49
67
|
throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
|
|
50
68
|
}
|
|
51
|
-
if (!maybeModel && !
|
|
69
|
+
if (!maybeModel && !task) {
|
|
52
70
|
throw new Error("No model provided, and no task has been specified.");
|
|
53
71
|
}
|
|
72
|
+
if (!providerConfig) {
|
|
73
|
+
throw new Error(`No provider config found for provider ${provider}`);
|
|
74
|
+
}
|
|
54
75
|
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
|
55
|
-
const hfModel = maybeModel ?? (await loadDefaultModel(
|
|
76
|
+
const hfModel = maybeModel ?? (await loadDefaultModel(task!));
|
|
56
77
|
const model = await getProviderModelId({ model: hfModel, provider }, args, {
|
|
57
|
-
|
|
78
|
+
task,
|
|
58
79
|
chatCompletion,
|
|
59
80
|
fetch: options?.fetch,
|
|
60
81
|
});
|
|
@@ -68,44 +89,52 @@ export async function makeRequestOptions(
|
|
|
68
89
|
? "credentials-include"
|
|
69
90
|
: "none";
|
|
70
91
|
|
|
92
|
+
// Make URL
|
|
71
93
|
const url = endpointUrl
|
|
72
94
|
? chatCompletion
|
|
73
95
|
? endpointUrl + `/v1/chat/completions`
|
|
74
96
|
: endpointUrl
|
|
75
|
-
: makeUrl({
|
|
76
|
-
|
|
77
|
-
|
|
97
|
+
: providerConfig.makeUrl({
|
|
98
|
+
baseUrl:
|
|
99
|
+
authMethod !== "provider-key"
|
|
100
|
+
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider)
|
|
101
|
+
: providerConfig.baseUrl,
|
|
78
102
|
model,
|
|
79
|
-
|
|
80
|
-
|
|
103
|
+
chatCompletion,
|
|
104
|
+
task,
|
|
81
105
|
});
|
|
82
106
|
|
|
83
|
-
|
|
84
|
-
if (accessToken) {
|
|
85
|
-
if (provider === "fal-ai" && authMethod === "provider-key") {
|
|
86
|
-
headers["Authorization"] = `Key ${accessToken}`;
|
|
87
|
-
} else if (provider === "black-forest-labs" && authMethod === "provider-key") {
|
|
88
|
-
headers["X-Key"] = accessToken;
|
|
89
|
-
} else {
|
|
90
|
-
headers["Authorization"] = `Bearer ${accessToken}`;
|
|
91
|
-
}
|
|
92
|
-
}
|
|
93
|
-
|
|
94
|
-
// e.g. @huggingface/inference/3.1.3
|
|
95
|
-
const ownUserAgent = `${packageName}/${packageVersion}`;
|
|
96
|
-
headers["User-Agent"] = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : undefined]
|
|
97
|
-
.filter((x) => x !== undefined)
|
|
98
|
-
.join(" ");
|
|
99
|
-
|
|
107
|
+
// Make headers
|
|
100
108
|
const binary = "data" in args && !!args.data;
|
|
109
|
+
const headers = providerConfig.makeHeaders({
|
|
110
|
+
accessToken,
|
|
111
|
+
authMethod,
|
|
112
|
+
});
|
|
101
113
|
|
|
114
|
+
// Add content-type to headers
|
|
102
115
|
if (!binary) {
|
|
103
116
|
headers["Content-Type"] = "application/json";
|
|
104
117
|
}
|
|
105
118
|
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
}
|
|
119
|
+
// Add user-agent to headers
|
|
120
|
+
// e.g. @huggingface/inference/3.1.3
|
|
121
|
+
const ownUserAgent = `${packageName}/${packageVersion}`;
|
|
122
|
+
const userAgent = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : undefined]
|
|
123
|
+
.filter((x) => x !== undefined)
|
|
124
|
+
.join(" ");
|
|
125
|
+
headers["User-Agent"] = userAgent;
|
|
126
|
+
|
|
127
|
+
// Make body
|
|
128
|
+
const body = binary
|
|
129
|
+
? args.data
|
|
130
|
+
: JSON.stringify(
|
|
131
|
+
providerConfig.makeBody({
|
|
132
|
+
args: remainingArgs as Record<string, unknown>,
|
|
133
|
+
model,
|
|
134
|
+
task,
|
|
135
|
+
chatCompletion,
|
|
136
|
+
})
|
|
137
|
+
);
|
|
109
138
|
|
|
110
139
|
/**
|
|
111
140
|
* For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
|
|
@@ -117,158 +146,17 @@ export async function makeRequestOptions(
|
|
|
117
146
|
credentials = "include";
|
|
118
147
|
}
|
|
119
148
|
|
|
120
|
-
/**
|
|
121
|
-
* Replicate models wrap all inputs inside { input: ... }
|
|
122
|
-
* Versioned Replicate models in the format `owner/model:version` expect the version in the body
|
|
123
|
-
*/
|
|
124
|
-
if (provider === "replicate") {
|
|
125
|
-
const version = model.includes(":") ? model.split(":")[1] : undefined;
|
|
126
|
-
(otherArgs as unknown) = { input: otherArgs, version };
|
|
127
|
-
}
|
|
128
|
-
|
|
129
149
|
const info: RequestInit = {
|
|
130
150
|
headers,
|
|
131
151
|
method: "POST",
|
|
132
|
-
body
|
|
133
|
-
? args.data
|
|
134
|
-
: JSON.stringify({
|
|
135
|
-
...otherArgs,
|
|
136
|
-
...(taskHint === "text-to-image" && provider === "hyperbolic"
|
|
137
|
-
? { model_name: model }
|
|
138
|
-
: chatCompletion || provider === "together" || provider === "nebius" || provider === "hyperbolic"
|
|
139
|
-
? { model }
|
|
140
|
-
: undefined),
|
|
141
|
-
}),
|
|
152
|
+
body,
|
|
142
153
|
...(credentials ? { credentials } : undefined),
|
|
143
|
-
signal
|
|
154
|
+
signal,
|
|
144
155
|
};
|
|
145
156
|
|
|
146
157
|
return { url, info };
|
|
147
158
|
}
|
|
148
159
|
|
|
149
|
-
function makeUrl(params: {
|
|
150
|
-
authMethod: "none" | "hf-token" | "credentials-include" | "provider-key";
|
|
151
|
-
chatCompletion: boolean;
|
|
152
|
-
model: string;
|
|
153
|
-
provider: InferenceProvider;
|
|
154
|
-
taskHint: InferenceTask | undefined;
|
|
155
|
-
}): string {
|
|
156
|
-
if (params.authMethod === "none" && params.provider !== "hf-inference") {
|
|
157
|
-
throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
|
|
158
|
-
}
|
|
159
|
-
|
|
160
|
-
const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
|
|
161
|
-
switch (params.provider) {
|
|
162
|
-
case "black-forest-labs": {
|
|
163
|
-
const baseUrl = shouldProxy
|
|
164
|
-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
|
|
165
|
-
: BLACKFORESTLABS_AI_API_BASE_URL;
|
|
166
|
-
return `${baseUrl}/${params.model}`;
|
|
167
|
-
}
|
|
168
|
-
case "fal-ai": {
|
|
169
|
-
const baseUrl = shouldProxy
|
|
170
|
-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
|
|
171
|
-
: FAL_AI_API_BASE_URL;
|
|
172
|
-
return `${baseUrl}/${params.model}`;
|
|
173
|
-
}
|
|
174
|
-
case "nebius": {
|
|
175
|
-
const baseUrl = shouldProxy
|
|
176
|
-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
|
|
177
|
-
: NEBIUS_API_BASE_URL;
|
|
178
|
-
|
|
179
|
-
if (params.taskHint === "text-to-image") {
|
|
180
|
-
return `${baseUrl}/v1/images/generations`;
|
|
181
|
-
}
|
|
182
|
-
if (params.taskHint === "text-generation") {
|
|
183
|
-
if (params.chatCompletion) {
|
|
184
|
-
return `${baseUrl}/v1/chat/completions`;
|
|
185
|
-
}
|
|
186
|
-
return `${baseUrl}/v1/completions`;
|
|
187
|
-
}
|
|
188
|
-
return baseUrl;
|
|
189
|
-
}
|
|
190
|
-
case "replicate": {
|
|
191
|
-
const baseUrl = shouldProxy
|
|
192
|
-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
|
|
193
|
-
: REPLICATE_API_BASE_URL;
|
|
194
|
-
if (params.model.includes(":")) {
|
|
195
|
-
/// Versioned model
|
|
196
|
-
return `${baseUrl}/v1/predictions`;
|
|
197
|
-
}
|
|
198
|
-
/// Evergreen / Canonical model
|
|
199
|
-
return `${baseUrl}/v1/models/${params.model}/predictions`;
|
|
200
|
-
}
|
|
201
|
-
case "sambanova": {
|
|
202
|
-
const baseUrl = shouldProxy
|
|
203
|
-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
|
|
204
|
-
: SAMBANOVA_API_BASE_URL;
|
|
205
|
-
/// Sambanova API matches OpenAI-like APIs: model is defined in the request body
|
|
206
|
-
if (params.taskHint === "text-generation" && params.chatCompletion) {
|
|
207
|
-
return `${baseUrl}/v1/chat/completions`;
|
|
208
|
-
}
|
|
209
|
-
return baseUrl;
|
|
210
|
-
}
|
|
211
|
-
case "together": {
|
|
212
|
-
const baseUrl = shouldProxy
|
|
213
|
-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
|
|
214
|
-
: TOGETHER_API_BASE_URL;
|
|
215
|
-
/// Together API matches OpenAI-like APIs: model is defined in the request body
|
|
216
|
-
if (params.taskHint === "text-to-image") {
|
|
217
|
-
return `${baseUrl}/v1/images/generations`;
|
|
218
|
-
}
|
|
219
|
-
if (params.taskHint === "text-generation") {
|
|
220
|
-
if (params.chatCompletion) {
|
|
221
|
-
return `${baseUrl}/v1/chat/completions`;
|
|
222
|
-
}
|
|
223
|
-
return `${baseUrl}/v1/completions`;
|
|
224
|
-
}
|
|
225
|
-
return baseUrl;
|
|
226
|
-
}
|
|
227
|
-
|
|
228
|
-
case "fireworks-ai": {
|
|
229
|
-
const baseUrl = shouldProxy
|
|
230
|
-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
|
|
231
|
-
: FIREWORKS_AI_API_BASE_URL;
|
|
232
|
-
if (params.taskHint === "text-generation" && params.chatCompletion) {
|
|
233
|
-
return `${baseUrl}/v1/chat/completions`;
|
|
234
|
-
}
|
|
235
|
-
return baseUrl;
|
|
236
|
-
}
|
|
237
|
-
case "hyperbolic": {
|
|
238
|
-
const baseUrl = shouldProxy
|
|
239
|
-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
|
|
240
|
-
: HYPERBOLIC_API_BASE_URL;
|
|
241
|
-
|
|
242
|
-
if (params.taskHint === "text-to-image") {
|
|
243
|
-
return `${baseUrl}/v1/images/generations`;
|
|
244
|
-
}
|
|
245
|
-
return `${baseUrl}/v1/chat/completions`;
|
|
246
|
-
}
|
|
247
|
-
case "novita": {
|
|
248
|
-
const baseUrl = shouldProxy
|
|
249
|
-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
|
|
250
|
-
: NOVITA_API_BASE_URL;
|
|
251
|
-
if (params.taskHint === "text-generation") {
|
|
252
|
-
if (params.chatCompletion) {
|
|
253
|
-
return `${baseUrl}/chat/completions`;
|
|
254
|
-
}
|
|
255
|
-
return `${baseUrl}/completions`;
|
|
256
|
-
}
|
|
257
|
-
return baseUrl;
|
|
258
|
-
}
|
|
259
|
-
default: {
|
|
260
|
-
const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
|
|
261
|
-
if (params.taskHint && ["feature-extraction", "sentence-similarity"].includes(params.taskHint)) {
|
|
262
|
-
/// when deployed on hf-inference, those two tasks are automatically compatible with one another.
|
|
263
|
-
return `${baseUrl}/pipeline/${params.taskHint}/${params.model}`;
|
|
264
|
-
}
|
|
265
|
-
if (params.taskHint === "text-generation" && params.chatCompletion) {
|
|
266
|
-
return `${baseUrl}/models/${params.model}/v1/chat/completions`;
|
|
267
|
-
}
|
|
268
|
-
return `${baseUrl}/models/${params.model}`;
|
|
269
|
-
}
|
|
270
|
-
}
|
|
271
|
-
}
|
|
272
160
|
async function loadDefaultModel(task: InferenceTask): Promise<string> {
|
|
273
161
|
if (!tasks) {
|
|
274
162
|
tasks = await loadTaskInfo();
|
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
export const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
|
|
2
|
-
|
|
3
1
|
/**
|
|
4
2
|
* See the registered mapping of HF model ID => Black Forest Labs model ID here:
|
|
5
3
|
*
|
|
@@ -16,3 +14,29 @@ export const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
|
|
|
16
14
|
*
|
|
17
15
|
* Thanks!
|
|
18
16
|
*/
|
|
17
|
+
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
|
|
18
|
+
|
|
19
|
+
const BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
|
|
20
|
+
|
|
21
|
+
/**
 * Builds the request payload for Black Forest Labs.
 * The caller-supplied arguments are returned as-is (same object, nothing added).
 */
const makeBody = ({ args }: BodyParams): Record<string, unknown> => args;
|
|
24
|
+
|
|
25
|
+
/**
 * Builds the authentication headers for Black Forest Labs.
 *
 * - With a provider key, the key is sent in the `X-Key` header.
 * - With any other auth method, the token is sent as a `Bearer` authorization header.
 * - Without an access token, no authentication header is produced, so the
 *   literal string "undefined" is never sent as a credential.
 *
 * @param params - access token (may be absent) and the resolved auth method
 * @returns the headers to merge into the outgoing request
 */
const makeHeaders = (params: HeaderParams): Record<string, string> => {
	if (!params.accessToken) {
		return {};
	}
	if (params.authMethod === "provider-key") {
		return { "X-Key": params.accessToken };
	}
	return { Authorization: `Bearer ${params.accessToken}` };
};
|
|
32
|
+
|
|
33
|
+
/**
 * Builds the request URL for Black Forest Labs: the model ID is appended
 * to the base URL as a path segment.
 */
const makeUrl = ({ baseUrl, model }: UrlParams): string => {
	const endpoint = `${baseUrl}/${model}`;
	return endpoint;
};
|
|
36
|
+
|
|
37
|
+
/**
 * Provider config for Black Forest Labs: direct-API base URL plus the
 * body / headers / URL builders defined in this module.
 */
export const BLACK_FOREST_LABS_CONFIG: ProviderConfig = {
	baseUrl: BLACK_FOREST_LABS_AI_API_BASE_URL,
	makeBody: makeBody,
	makeHeaders: makeHeaders,
	makeUrl: makeUrl,
};
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* See the registered mapping of HF model ID => Cohere model ID here:
|
|
3
|
+
*
|
|
4
|
+
* https://huggingface.co/api/partners/cohere/models
|
|
5
|
+
*
|
|
6
|
+
* This is a publicly available mapping.
|
|
7
|
+
*
|
|
8
|
+
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
|
|
9
|
+
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
|
|
10
|
+
*
|
|
11
|
+
* - If you work at Cohere and want to update this mapping, please use the model mapping API we provide on huggingface.co
|
|
12
|
+
* - If you're a community member and want to add a new supported HF model to Cohere, please open an issue on the present repo
|
|
13
|
+
* and we will tag Cohere team members.
|
|
14
|
+
*
|
|
15
|
+
* Thanks!
|
|
16
|
+
*/
|
|
17
|
+
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
|
|
18
|
+
|
|
19
|
+
const COHERE_API_BASE_URL = "https://api.cohere.com";
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
/**
 * Builds the request payload for Cohere: a copy of the caller's arguments
 * with `model` set to the provider model ID (overriding any `model` in args).
 */
const makeBody = ({ args, model }: BodyParams): Record<string, unknown> => ({ ...args, model });
|
|
28
|
+
|
|
29
|
+
/**
 * Builds the authentication headers for Cohere.
 *
 * The token is sent as a `Bearer` authorization header. When no access token
 * is available, no header is produced, so "Bearer undefined" is never sent.
 *
 * @param params - access token (may be absent) and the resolved auth method
 * @returns the headers to merge into the outgoing request
 */
const makeHeaders = (params: HeaderParams): Record<string, string> => {
	return params.accessToken ? { Authorization: `Bearer ${params.accessToken}` } : {};
};
|
|
32
|
+
|
|
33
|
+
/**
 * Builds the request URL for Cohere: always the chat-completions route under
 * the `/compatibility/v1` prefix, regardless of model or task.
 */
const makeUrl = ({ baseUrl }: UrlParams): string => {
	const route = "compatibility/v1/chat/completions";
	return `${baseUrl}/${route}`;
};
|
|
36
|
+
|
|
37
|
+
/**
 * Provider config for Cohere: direct-API base URL plus the
 * body / headers / URL builders defined in this module.
 */
export const COHERE_CONFIG: ProviderConfig = {
	baseUrl: COHERE_API_BASE_URL,
	makeBody: makeBody,
	makeHeaders: makeHeaders,
	makeUrl: makeUrl,
};
|
package/src/providers/consts.ts
CHANGED
|
@@ -17,13 +17,14 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
|
|
|
17
17
|
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
|
|
18
18
|
*/
|
|
19
19
|
"black-forest-labs": {},
|
|
20
|
+
cohere: {},
|
|
20
21
|
"fal-ai": {},
|
|
21
22
|
"fireworks-ai": {},
|
|
22
23
|
"hf-inference": {},
|
|
23
24
|
hyperbolic: {},
|
|
24
25
|
nebius: {},
|
|
26
|
+
novita: {},
|
|
25
27
|
replicate: {},
|
|
26
28
|
sambanova: {},
|
|
27
29
|
together: {},
|
|
28
|
-
novita: {},
|
|
29
30
|
};
|
package/src/providers/fal-ai.ts
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
export const FAL_AI_API_BASE_URL = "https://fal.run";
|
|
2
|
-
|
|
3
1
|
/**
|
|
4
2
|
* See the registered mapping of HF model ID => Fal model ID here:
|
|
5
3
|
*
|
|
@@ -16,3 +14,27 @@ export const FAL_AI_API_BASE_URL = "https://fal.run";
|
|
|
16
14
|
*
|
|
17
15
|
* Thanks!
|
|
18
16
|
*/
|
|
17
|
+
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
|
|
18
|
+
|
|
19
|
+
const FAL_AI_API_BASE_URL = "https://fal.run";
|
|
20
|
+
|
|
21
|
+
/**
 * Builds the request payload for fal.ai.
 * The caller-supplied arguments are returned as-is (same object, nothing added).
 */
const makeBody = ({ args }: BodyParams): Record<string, unknown> => args;
|
|
24
|
+
|
|
25
|
+
/**
 * Builds the authentication headers for fal.ai.
 *
 * - With a provider key, the `Authorization` header uses the `Key` scheme.
 * - With any other auth method, it uses the `Bearer` scheme.
 * - Without an access token, no header is produced, so neither
 *   "Key undefined" nor "Bearer undefined" is ever sent.
 *
 * @param params - access token (may be absent) and the resolved auth method
 * @returns the headers to merge into the outgoing request
 */
const makeHeaders = (params: HeaderParams): Record<string, string> => {
	if (!params.accessToken) {
		return {};
	}
	const scheme = params.authMethod === "provider-key" ? "Key" : "Bearer";
	return { Authorization: `${scheme} ${params.accessToken}` };
};
|
|
30
|
+
|
|
31
|
+
/**
 * Builds the request URL for fal.ai: the model ID is appended to the
 * base URL as a path.
 */
const makeUrl = ({ baseUrl, model }: UrlParams): string => {
	const endpoint = `${baseUrl}/${model}`;
	return endpoint;
};
|
|
34
|
+
|
|
35
|
+
/**
 * Provider config for fal.ai: direct-API base URL plus the
 * body / headers / URL builders defined in this module.
 */
export const FAL_AI_CONFIG: ProviderConfig = {
	baseUrl: FAL_AI_API_BASE_URL,
	makeBody: makeBody,
	makeHeaders: makeHeaders,
	makeUrl: makeUrl,
};
|
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
export const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
|
|
2
|
-
|
|
3
1
|
/**
|
|
4
2
|
* See the registered mapping of HF model ID => Fireworks model ID here:
|
|
5
3
|
*
|
|
@@ -16,3 +14,31 @@ export const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
|
|
|
16
14
|
*
|
|
17
15
|
* Thanks!
|
|
18
16
|
*/
|
|
17
|
+
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
|
|
18
|
+
|
|
19
|
+
const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
|
|
20
|
+
|
|
21
|
+
/**
 * Builds the request payload for Fireworks AI: a copy of the caller's
 * arguments, with `model` added only for chat-completion requests.
 */
const makeBody = ({ args, chatCompletion, model }: BodyParams): Record<string, unknown> => {
	if (!chatCompletion) {
		return { ...args };
	}
	return { ...args, model };
};
|
|
27
|
+
|
|
28
|
+
/**
 * Builds the authentication headers for Fireworks AI.
 *
 * The token is sent as a `Bearer` authorization header. When no access token
 * is available, no header is produced, so "Bearer undefined" is never sent.
 *
 * @param params - access token (may be absent) and the resolved auth method
 * @returns the headers to merge into the outgoing request
 */
const makeHeaders = (params: HeaderParams): Record<string, string> => {
	return params.accessToken ? { Authorization: `Bearer ${params.accessToken}` } : {};
};
|
|
31
|
+
|
|
32
|
+
/**
 * Builds the request URL for Fireworks AI: the chat-completions route for
 * chat text-generation requests, otherwise the bare base URL.
 */
const makeUrl = ({ baseUrl, chatCompletion, task }: UrlParams): string => {
	const isChat = task === "text-generation" && chatCompletion;
	return isChat ? `${baseUrl}/v1/chat/completions` : baseUrl;
};
|
|
38
|
+
|
|
39
|
+
/**
 * Provider config for Fireworks AI: direct-API base URL plus the
 * body / headers / URL builders defined in this module.
 */
export const FIREWORKS_AI_CONFIG: ProviderConfig = {
	baseUrl: FIREWORKS_AI_API_BASE_URL,
	makeBody: makeBody,
	makeHeaders: makeHeaders,
	makeUrl: makeUrl,
};
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* HF-Inference does not have a mapping since all models use IDs from the Hub.
|
|
3
|
+
*
|
|
4
|
+
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
|
|
5
|
+
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
|
|
6
|
+
*
|
|
7
|
+
* - If you work at HF and want to update this mapping, please use the model mapping API we provide on huggingface.co
|
|
8
|
+
* - If you're a community member and want to add a new supported HF model to HF, please open an issue on the present repo
|
|
9
|
+
* and we will tag HF team members.
|
|
10
|
+
*
|
|
11
|
+
* Thanks!
|
|
12
|
+
*/
|
|
13
|
+
import { HF_ROUTER_URL } from "../config";
|
|
14
|
+
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
|
|
15
|
+
|
|
16
|
+
/**
 * Builds the request payload for HF Inference: a copy of the caller's
 * arguments, with `model` added only for chat-completion requests.
 */
const makeBody = ({ args, chatCompletion, model }: BodyParams): Record<string, unknown> => {
	if (!chatCompletion) {
		return { ...args };
	}
	return { ...args, model };
};
|
|
22
|
+
|
|
23
|
+
/**
 * Builds the authentication headers for HF Inference.
 *
 * The token is sent as a `Bearer` authorization header. HF Inference may be
 * called without an access token; in that case no `Authorization` header is
 * produced, instead of sending the literal value "Bearer undefined".
 *
 * @param params - access token (may be absent) and the resolved auth method
 * @returns the headers to merge into the outgoing request
 */
const makeHeaders = (params: HeaderParams): Record<string, string> => {
	return params.accessToken ? { Authorization: `Bearer ${params.accessToken}` } : {};
};
|
|
26
|
+
|
|
27
|
+
/**
 * Builds the request URL for HF Inference.
 *
 * - feature-extraction / sentence-similarity use the `/pipeline/<task>/<model>` route
 *   (when deployed on hf-inference, those two tasks are automatically compatible with one another);
 * - chat text-generation uses `/models/<model>/v1/chat/completions`;
 * - everything else uses `/models/<model>`.
 */
const makeUrl = ({ baseUrl, chatCompletion, model, task }: UrlParams): string => {
	switch (task) {
		case "feature-extraction":
		case "sentence-similarity":
			return `${baseUrl}/pipeline/${task}/${model}`;
		case "text-generation":
			if (chatCompletion) {
				return `${baseUrl}/models/${model}/v1/chat/completions`;
			}
			break;
	}
	return `${baseUrl}/models/${model}`;
};
|
|
37
|
+
|
|
38
|
+
/**
 * Provider config for HF Inference: base URL under the HF router plus the
 * body / headers / URL builders defined in this module.
 */
export const HF_INFERENCE_CONFIG: ProviderConfig = {
	baseUrl: `${HF_ROUTER_URL}/hf-inference`,
	makeBody: makeBody,
	makeHeaders: makeHeaders,
	makeUrl: makeUrl,
};
|
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
export const HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
|
|
2
|
-
|
|
3
1
|
/**
|
|
4
2
|
* See the registered mapping of HF model ID => Hyperbolic model ID here:
|
|
5
3
|
*
|
|
@@ -16,3 +14,31 @@ export const HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
|
|
|
16
14
|
*
|
|
17
15
|
* Thanks!
|
|
18
16
|
*/
|
|
17
|
+
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
|
|
18
|
+
|
|
19
|
+
const HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
|
|
20
|
+
|
|
21
|
+
/**
 * Builds the request payload for Hyperbolic: a copy of the caller's arguments
 * plus the model ID, keyed `model_name` for text-to-image and `model` otherwise.
 */
const makeBody = ({ args, model, task }: BodyParams): Record<string, unknown> => {
	const modelKey = task === "text-to-image" ? "model_name" : "model";
	return { ...args, [modelKey]: model };
};
|
|
27
|
+
|
|
28
|
+
/**
 * Builds the authentication headers for Hyperbolic.
 *
 * The token is sent as a `Bearer` authorization header. When no access token
 * is available, no header is produced, so "Bearer undefined" is never sent.
 *
 * @param params - access token (may be absent) and the resolved auth method
 * @returns the headers to merge into the outgoing request
 */
const makeHeaders = (params: HeaderParams): Record<string, string> => {
	return params.accessToken ? { Authorization: `Bearer ${params.accessToken}` } : {};
};
|
|
31
|
+
|
|
32
|
+
/**
 * Builds the request URL for Hyperbolic: the image-generation route for
 * text-to-image, the chat-completions route for every other task.
 */
const makeUrl = ({ baseUrl, task }: UrlParams): string => {
	const route = task === "text-to-image" ? "images/generations" : "chat/completions";
	return `${baseUrl}/v1/${route}`;
};
|
|
38
|
+
|
|
39
|
+
/**
 * Provider config for Hyperbolic: direct-API base URL plus the
 * body / headers / URL builders defined in this module.
 */
export const HYPERBOLIC_CONFIG: ProviderConfig = {
	baseUrl: HYPERBOLIC_API_BASE_URL,
	makeBody: makeBody,
	makeHeaders: makeHeaders,
	makeUrl: makeUrl,
};
|
package/src/providers/nebius.ts
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
export const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
|
|
2
|
-
|
|
3
1
|
/**
|
|
4
2
|
* See the registered mapping of HF model ID => Nebius model ID here:
|
|
5
3
|
*
|
|
@@ -16,3 +14,37 @@ export const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
|
|
|
16
14
|
*
|
|
17
15
|
* Thanks!
|
|
18
16
|
*/
|
|
17
|
+
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
|
|
18
|
+
|
|
19
|
+
const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
|
|
20
|
+
|
|
21
|
+
/**
 * Builds the request payload for Nebius: a copy of the caller's arguments
 * with `model` set to the provider model ID (overriding any `model` in args).
 */
const makeBody = ({ args, model }: BodyParams): Record<string, unknown> => ({ ...args, model });
|
|
27
|
+
|
|
28
|
+
/**
 * Builds the authentication headers for Nebius.
 *
 * The token is sent as a `Bearer` authorization header. When no access token
 * is available, no header is produced, so "Bearer undefined" is never sent.
 *
 * @param params - access token (may be absent) and the resolved auth method
 * @returns the headers to merge into the outgoing request
 */
const makeHeaders = (params: HeaderParams): Record<string, string> => {
	return params.accessToken ? { Authorization: `Bearer ${params.accessToken}` } : {};
};
|
|
31
|
+
|
|
32
|
+
/**
 * Builds the request URL for Nebius, selected by task:
 * image generation for text-to-image, chat or plain completions for
 * text-generation, and the bare base URL for anything else.
 */
const makeUrl = ({ baseUrl, chatCompletion, task }: UrlParams): string => {
	switch (task) {
		case "text-to-image":
			return `${baseUrl}/v1/images/generations`;
		case "text-generation":
			return chatCompletion ? `${baseUrl}/v1/chat/completions` : `${baseUrl}/v1/completions`;
		default:
			return baseUrl;
	}
};
|
|
44
|
+
|
|
45
|
+
/**
 * Provider config for Nebius: direct-API base URL plus the
 * body / headers / URL builders defined in this module.
 */
export const NEBIUS_CONFIG: ProviderConfig = {
	baseUrl: NEBIUS_API_BASE_URL,
	makeBody: makeBody,
	makeHeaders: makeHeaders,
	makeUrl: makeUrl,
};
|
package/src/providers/novita.ts
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
export const NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
|
|
2
|
-
|
|
3
1
|
/**
|
|
4
2
|
* See the registered mapping of HF model ID => Novita model ID here:
|
|
5
3
|
*
|
|
@@ -16,3 +14,34 @@ export const NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
|
|
|
16
14
|
*
|
|
17
15
|
* Thanks!
|
|
18
16
|
*/
|
|
17
|
+
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
|
|
18
|
+
|
|
19
|
+
const NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
|
|
20
|
+
|
|
21
|
+
/**
 * Builds the request payload for Novita: a copy of the caller's arguments,
 * with `model` added only for chat-completion requests.
 */
const makeBody = ({ args, chatCompletion, model }: BodyParams): Record<string, unknown> => {
	if (!chatCompletion) {
		return { ...args };
	}
	return { ...args, model };
};
|
|
27
|
+
|
|
28
|
+
/**
 * Builds the authentication headers for Novita.
 *
 * The token is sent as a `Bearer` authorization header. When no access token
 * is available, no header is produced, so "Bearer undefined" is never sent.
 *
 * @param params - access token (may be absent) and the resolved auth method
 * @returns the headers to merge into the outgoing request
 */
const makeHeaders = (params: HeaderParams): Record<string, string> => {
	return params.accessToken ? { Authorization: `Bearer ${params.accessToken}` } : {};
};
|
|
31
|
+
|
|
32
|
+
/**
 * Builds the request URL for Novita: chat or plain completions for
 * text-generation, the bare base URL for any other task.
 * Unlike some other providers, the routes carry no `/v1` prefix here.
 */
const makeUrl = ({ baseUrl, chatCompletion, task }: UrlParams): string => {
	if (task !== "text-generation") {
		return baseUrl;
	}
	return chatCompletion ? `${baseUrl}/chat/completions` : `${baseUrl}/completions`;
};
|
|
41
|
+
|
|
42
|
+
/**
 * Provider config for Novita: direct-API base URL plus the
 * body / headers / URL builders defined in this module.
 */
export const NOVITA_CONFIG: ProviderConfig = {
	baseUrl: NOVITA_API_BASE_URL,
	makeBody: makeBody,
	makeHeaders: makeHeaders,
	makeUrl: makeUrl,
};
|