@huggingface/inference 3.7.0 → 3.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.cjs +1369 -941
- package/dist/index.js +1371 -943
- package/dist/src/lib/getInferenceProviderMapping.d.ts +21 -0
- package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -0
- package/dist/src/lib/getProviderHelper.d.ts +37 -0
- package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
- package/dist/src/lib/makeRequestOptions.d.ts +5 -5
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/black-forest-labs.d.ts +14 -18
- package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
- package/dist/src/providers/cerebras.d.ts +4 -2
- package/dist/src/providers/cerebras.d.ts.map +1 -1
- package/dist/src/providers/cohere.d.ts +5 -2
- package/dist/src/providers/cohere.d.ts.map +1 -1
- package/dist/src/providers/consts.d.ts +2 -3
- package/dist/src/providers/consts.d.ts.map +1 -1
- package/dist/src/providers/fal-ai.d.ts +50 -3
- package/dist/src/providers/fal-ai.d.ts.map +1 -1
- package/dist/src/providers/fireworks-ai.d.ts +5 -2
- package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
- package/dist/src/providers/hf-inference.d.ts +126 -2
- package/dist/src/providers/hf-inference.d.ts.map +1 -1
- package/dist/src/providers/hyperbolic.d.ts +31 -2
- package/dist/src/providers/hyperbolic.d.ts.map +1 -1
- package/dist/src/providers/nebius.d.ts +20 -18
- package/dist/src/providers/nebius.d.ts.map +1 -1
- package/dist/src/providers/novita.d.ts +21 -18
- package/dist/src/providers/novita.d.ts.map +1 -1
- package/dist/src/providers/openai.d.ts +4 -2
- package/dist/src/providers/openai.d.ts.map +1 -1
- package/dist/src/providers/providerHelper.d.ts +182 -0
- package/dist/src/providers/providerHelper.d.ts.map +1 -0
- package/dist/src/providers/replicate.d.ts +23 -19
- package/dist/src/providers/replicate.d.ts.map +1 -1
- package/dist/src/providers/sambanova.d.ts +4 -2
- package/dist/src/providers/sambanova.d.ts.map +1 -1
- package/dist/src/providers/together.d.ts +32 -2
- package/dist/src/providers/together.d.ts.map +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts +2 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/audio/utils.d.ts +2 -1
- package/dist/src/tasks/audio/utils.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts +0 -2
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts +0 -2
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/index.d.ts +6 -6
- package/dist/src/tasks/index.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
- package/dist/src/types.d.ts +5 -13
- package/dist/src/types.d.ts.map +1 -1
- package/dist/src/utils/request.d.ts +3 -2
- package/dist/src/utils/request.d.ts.map +1 -1
- package/package.json +3 -3
- package/src/lib/getInferenceProviderMapping.ts +96 -0
- package/src/lib/getProviderHelper.ts +270 -0
- package/src/lib/makeRequestOptions.ts +78 -97
- package/src/providers/black-forest-labs.ts +73 -22
- package/src/providers/cerebras.ts +6 -27
- package/src/providers/cohere.ts +9 -28
- package/src/providers/consts.ts +5 -2
- package/src/providers/fal-ai.ts +224 -77
- package/src/providers/fireworks-ai.ts +8 -29
- package/src/providers/hf-inference.ts +557 -34
- package/src/providers/hyperbolic.ts +107 -29
- package/src/providers/nebius.ts +65 -29
- package/src/providers/novita.ts +68 -32
- package/src/providers/openai.ts +6 -32
- package/src/providers/providerHelper.ts +354 -0
- package/src/providers/replicate.ts +124 -34
- package/src/providers/sambanova.ts +5 -30
- package/src/providers/together.ts +92 -28
- package/src/snippets/getInferenceSnippets.ts +39 -14
- package/src/snippets/templates.exported.ts +25 -25
- package/src/tasks/audio/audioClassification.ts +5 -8
- package/src/tasks/audio/audioToAudio.ts +4 -27
- package/src/tasks/audio/automaticSpeechRecognition.ts +5 -4
- package/src/tasks/audio/textToSpeech.ts +5 -29
- package/src/tasks/audio/utils.ts +2 -1
- package/src/tasks/custom/request.ts +3 -3
- package/src/tasks/custom/streamingRequest.ts +4 -3
- package/src/tasks/cv/imageClassification.ts +4 -8
- package/src/tasks/cv/imageSegmentation.ts +4 -9
- package/src/tasks/cv/imageToImage.ts +4 -7
- package/src/tasks/cv/imageToText.ts +4 -7
- package/src/tasks/cv/objectDetection.ts +4 -19
- package/src/tasks/cv/textToImage.ts +9 -137
- package/src/tasks/cv/textToVideo.ts +17 -64
- package/src/tasks/cv/zeroShotImageClassification.ts +4 -8
- package/src/tasks/index.ts +6 -6
- package/src/tasks/multimodal/documentQuestionAnswering.ts +4 -19
- package/src/tasks/multimodal/visualQuestionAnswering.ts +4 -12
- package/src/tasks/nlp/chatCompletion.ts +5 -20
- package/src/tasks/nlp/chatCompletionStream.ts +4 -3
- package/src/tasks/nlp/featureExtraction.ts +4 -19
- package/src/tasks/nlp/fillMask.ts +4 -17
- package/src/tasks/nlp/questionAnswering.ts +11 -26
- package/src/tasks/nlp/sentenceSimilarity.ts +4 -8
- package/src/tasks/nlp/summarization.ts +4 -7
- package/src/tasks/nlp/tableQuestionAnswering.ts +10 -30
- package/src/tasks/nlp/textClassification.ts +4 -9
- package/src/tasks/nlp/textGeneration.ts +11 -79
- package/src/tasks/nlp/textGenerationStream.ts +3 -1
- package/src/tasks/nlp/tokenClassification.ts +11 -23
- package/src/tasks/nlp/translation.ts +4 -7
- package/src/tasks/nlp/zeroShotClassification.ts +11 -21
- package/src/tasks/tabular/tabularClassification.ts +4 -7
- package/src/tasks/tabular/tabularRegression.ts +4 -7
- package/src/types.ts +5 -14
- package/src/utils/request.ts +7 -4
- package/dist/src/lib/getProviderModelId.d.ts +0 -10
- package/dist/src/lib/getProviderModelId.d.ts.map +0 -1
- package/src/lib/getProviderModelId.ts +0 -74
|
@@ -1,23 +1,10 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import {
|
|
3
|
-
import {
|
|
4
|
-
import {
|
|
5
|
-
import {
|
|
6
|
-
import {
|
|
7
|
-
import { HF_INFERENCE_CONFIG } from "../providers/hf-inference";
|
|
8
|
-
import { HYPERBOLIC_CONFIG } from "../providers/hyperbolic";
|
|
9
|
-
import { NEBIUS_CONFIG } from "../providers/nebius";
|
|
10
|
-
import { NOVITA_CONFIG } from "../providers/novita";
|
|
11
|
-
import { REPLICATE_CONFIG } from "../providers/replicate";
|
|
12
|
-
import { SAMBANOVA_CONFIG } from "../providers/sambanova";
|
|
13
|
-
import { TOGETHER_CONFIG } from "../providers/together";
|
|
14
|
-
import { OPENAI_CONFIG } from "../providers/openai";
|
|
15
|
-
import type { InferenceProvider, InferenceTask, Options, ProviderConfig, RequestArgs } from "../types";
|
|
1
|
+
import { name as packageName, version as packageVersion } from "../../package.json";
|
|
2
|
+
import { HF_HEADER_X_BILL_TO, HF_HUB_URL } from "../config";
|
|
3
|
+
import type { InferenceTask, Options, RequestArgs } from "../types";
|
|
4
|
+
import type { InferenceProviderModelMapping } from "./getInferenceProviderMapping";
|
|
5
|
+
import { getInferenceProviderMapping } from "./getInferenceProviderMapping";
|
|
6
|
+
import type { getProviderHelper } from "./getProviderHelper";
|
|
16
7
|
import { isUrl } from "./isUrl";
|
|
17
|
-
import { version as packageVersion, name as packageName } from "../../package.json";
|
|
18
|
-
import { getProviderModelId } from "./getProviderModelId";
|
|
19
|
-
|
|
20
|
-
const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
|
|
21
8
|
|
|
22
9
|
/**
|
|
23
10
|
* Lazy-loaded from huggingface.co/api/tasks when needed
|
|
@@ -25,25 +12,6 @@ const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
|
|
|
25
12
|
*/
|
|
26
13
|
let tasks: Record<string, { models: { id: string }[] }> | null = null;
|
|
27
14
|
|
|
28
|
-
/**
|
|
29
|
-
* Config to define how to serialize requests for each provider
|
|
30
|
-
*/
|
|
31
|
-
const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
|
|
32
|
-
"black-forest-labs": BLACK_FOREST_LABS_CONFIG,
|
|
33
|
-
cerebras: CEREBRAS_CONFIG,
|
|
34
|
-
cohere: COHERE_CONFIG,
|
|
35
|
-
"fal-ai": FAL_AI_CONFIG,
|
|
36
|
-
"fireworks-ai": FIREWORKS_AI_CONFIG,
|
|
37
|
-
"hf-inference": HF_INFERENCE_CONFIG,
|
|
38
|
-
hyperbolic: HYPERBOLIC_CONFIG,
|
|
39
|
-
openai: OPENAI_CONFIG,
|
|
40
|
-
nebius: NEBIUS_CONFIG,
|
|
41
|
-
novita: NOVITA_CONFIG,
|
|
42
|
-
replicate: REPLICATE_CONFIG,
|
|
43
|
-
sambanova: SAMBANOVA_CONFIG,
|
|
44
|
-
together: TOGETHER_CONFIG,
|
|
45
|
-
};
|
|
46
|
-
|
|
47
15
|
/**
|
|
48
16
|
* Helper that prepares request arguments.
|
|
49
17
|
* This async version handle the model ID resolution step.
|
|
@@ -53,16 +21,15 @@ export async function makeRequestOptions(
|
|
|
53
21
|
data?: Blob | ArrayBuffer;
|
|
54
22
|
stream?: boolean;
|
|
55
23
|
},
|
|
24
|
+
providerHelper: ReturnType<typeof getProviderHelper>,
|
|
56
25
|
options?: Options & {
|
|
57
26
|
/** In most cases (unless we pass a endpointUrl) we know the task */
|
|
58
27
|
task?: InferenceTask;
|
|
59
|
-
chatCompletion?: boolean;
|
|
60
28
|
}
|
|
61
29
|
): Promise<{ url: string; info: RequestInit }> {
|
|
62
30
|
const { provider: maybeProvider, model: maybeModel } = args;
|
|
63
31
|
const provider = maybeProvider ?? "hf-inference";
|
|
64
|
-
const
|
|
65
|
-
const { task, chatCompletion } = options ?? {};
|
|
32
|
+
const { task } = options ?? {};
|
|
66
33
|
|
|
67
34
|
// Validate inputs
|
|
68
35
|
if (args.endpointUrl && provider !== "hf-inference") {
|
|
@@ -71,29 +38,61 @@ export async function makeRequestOptions(
|
|
|
71
38
|
if (maybeModel && isUrl(maybeModel)) {
|
|
72
39
|
throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
|
|
73
40
|
}
|
|
41
|
+
|
|
42
|
+
if (args.endpointUrl) {
|
|
43
|
+
// No need to have maybeModel, or to load default model for a task
|
|
44
|
+
return makeRequestOptionsFromResolvedModel(
|
|
45
|
+
maybeModel ?? args.endpointUrl,
|
|
46
|
+
providerHelper,
|
|
47
|
+
args,
|
|
48
|
+
undefined,
|
|
49
|
+
options
|
|
50
|
+
);
|
|
51
|
+
}
|
|
52
|
+
|
|
74
53
|
if (!maybeModel && !task) {
|
|
75
54
|
throw new Error("No model provided, and no task has been specified.");
|
|
76
55
|
}
|
|
77
|
-
if (!providerConfig) {
|
|
78
|
-
throw new Error(`No provider config found for provider ${provider}`);
|
|
79
|
-
}
|
|
80
|
-
if (providerConfig.clientSideRoutingOnly && !maybeModel) {
|
|
81
|
-
throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
|
|
82
|
-
}
|
|
83
56
|
|
|
84
57
|
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
|
85
58
|
const hfModel = maybeModel ?? (await loadDefaultModel(task!));
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
59
|
+
|
|
60
|
+
if (providerHelper.clientSideRoutingOnly && !maybeModel) {
|
|
61
|
+
throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
const inferenceProviderMapping = providerHelper.clientSideRoutingOnly
|
|
65
|
+
? ({
|
|
66
|
+
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
|
67
|
+
providerId: removeProviderPrefix(maybeModel!, provider),
|
|
68
|
+
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
|
69
|
+
hfModelId: maybeModel!,
|
|
70
|
+
status: "live",
|
|
71
|
+
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
|
72
|
+
task: task!,
|
|
73
|
+
} satisfies InferenceProviderModelMapping)
|
|
74
|
+
: await getInferenceProviderMapping(
|
|
75
|
+
{
|
|
76
|
+
modelId: hfModel,
|
|
77
|
+
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
|
78
|
+
task: task!,
|
|
79
|
+
provider,
|
|
80
|
+
accessToken: args.accessToken,
|
|
81
|
+
},
|
|
82
|
+
{ fetch: options?.fetch }
|
|
83
|
+
);
|
|
84
|
+
if (!inferenceProviderMapping) {
|
|
85
|
+
throw new Error(`We have not been able to find inference provider information for model ${hfModel}.`);
|
|
86
|
+
}
|
|
94
87
|
|
|
95
88
|
// Use the sync version with the resolved model
|
|
96
|
-
return makeRequestOptionsFromResolvedModel(
|
|
89
|
+
return makeRequestOptionsFromResolvedModel(
|
|
90
|
+
inferenceProviderMapping.providerId,
|
|
91
|
+
providerHelper,
|
|
92
|
+
args,
|
|
93
|
+
inferenceProviderMapping,
|
|
94
|
+
options
|
|
95
|
+
);
|
|
97
96
|
}
|
|
98
97
|
|
|
99
98
|
/**
|
|
@@ -102,25 +101,24 @@ export async function makeRequestOptions(
|
|
|
102
101
|
*/
|
|
103
102
|
export function makeRequestOptionsFromResolvedModel(
|
|
104
103
|
resolvedModel: string,
|
|
104
|
+
providerHelper: ReturnType<typeof getProviderHelper>,
|
|
105
105
|
args: RequestArgs & {
|
|
106
106
|
data?: Blob | ArrayBuffer;
|
|
107
107
|
stream?: boolean;
|
|
108
108
|
},
|
|
109
|
+
mapping: InferenceProviderModelMapping | undefined,
|
|
109
110
|
options?: Options & {
|
|
110
111
|
task?: InferenceTask;
|
|
111
|
-
chatCompletion?: boolean;
|
|
112
112
|
}
|
|
113
113
|
): { url: string; info: RequestInit } {
|
|
114
114
|
const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
|
|
115
115
|
void model;
|
|
116
116
|
|
|
117
117
|
const provider = maybeProvider ?? "hf-inference";
|
|
118
|
-
const providerConfig = providerConfigs[provider];
|
|
119
|
-
|
|
120
|
-
const { includeCredentials, task, chatCompletion, signal, billTo } = options ?? {};
|
|
121
118
|
|
|
119
|
+
const { includeCredentials, task, signal, billTo } = options ?? {};
|
|
122
120
|
const authMethod = (() => {
|
|
123
|
-
if (
|
|
121
|
+
if (providerHelper.clientSideRoutingOnly) {
|
|
124
122
|
// Closed-source providers require an accessToken (cannot be routed).
|
|
125
123
|
if (accessToken && accessToken.startsWith("hf_")) {
|
|
126
124
|
throw new Error(`Provider ${provider} is closed-source and does not support HF tokens.`);
|
|
@@ -138,36 +136,25 @@ export function makeRequestOptionsFromResolvedModel(
|
|
|
138
136
|
})();
|
|
139
137
|
|
|
140
138
|
// Make URL
|
|
141
|
-
const url = endpointUrl
|
|
142
|
-
? chatCompletion
|
|
143
|
-
? endpointUrl + `/v1/chat/completions`
|
|
144
|
-
: endpointUrl
|
|
145
|
-
: providerConfig.makeUrl({
|
|
146
|
-
authMethod,
|
|
147
|
-
baseUrl:
|
|
148
|
-
authMethod !== "provider-key"
|
|
149
|
-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider)
|
|
150
|
-
: providerConfig.makeBaseUrl(task),
|
|
151
|
-
model: resolvedModel,
|
|
152
|
-
chatCompletion,
|
|
153
|
-
task,
|
|
154
|
-
});
|
|
155
139
|
|
|
156
|
-
|
|
157
|
-
const
|
|
158
|
-
const headers = providerConfig.makeHeaders({
|
|
159
|
-
accessToken,
|
|
140
|
+
const modelId = endpointUrl ?? resolvedModel;
|
|
141
|
+
const url = providerHelper.makeUrl({
|
|
160
142
|
authMethod,
|
|
143
|
+
model: modelId,
|
|
144
|
+
task,
|
|
161
145
|
});
|
|
146
|
+
// Make headers
|
|
147
|
+
const headers = providerHelper.prepareHeaders(
|
|
148
|
+
{
|
|
149
|
+
accessToken,
|
|
150
|
+
authMethod,
|
|
151
|
+
},
|
|
152
|
+
"data" in args && !!args.data
|
|
153
|
+
);
|
|
162
154
|
if (billTo) {
|
|
163
155
|
headers[HF_HEADER_X_BILL_TO] = billTo;
|
|
164
156
|
}
|
|
165
157
|
|
|
166
|
-
// Add content-type to headers
|
|
167
|
-
if (!binary) {
|
|
168
|
-
headers["Content-Type"] = "application/json";
|
|
169
|
-
}
|
|
170
|
-
|
|
171
158
|
// Add user-agent to headers
|
|
172
159
|
// e.g. @huggingface/inference/3.1.3
|
|
173
160
|
const ownUserAgent = `${packageName}/${packageVersion}`;
|
|
@@ -177,17 +164,12 @@ export function makeRequestOptionsFromResolvedModel(
|
|
|
177
164
|
headers["User-Agent"] = userAgent;
|
|
178
165
|
|
|
179
166
|
// Make body
|
|
180
|
-
const body =
|
|
181
|
-
|
|
182
|
-
:
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
task,
|
|
187
|
-
chatCompletion,
|
|
188
|
-
})
|
|
189
|
-
);
|
|
190
|
-
|
|
167
|
+
const body = providerHelper.makeBody({
|
|
168
|
+
args: remainingArgs as Record<string, unknown>,
|
|
169
|
+
model: resolvedModel,
|
|
170
|
+
task,
|
|
171
|
+
mapping,
|
|
172
|
+
});
|
|
191
173
|
/**
|
|
192
174
|
* For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
|
|
193
175
|
*/
|
|
@@ -201,11 +183,10 @@ export function makeRequestOptionsFromResolvedModel(
|
|
|
201
183
|
const info: RequestInit = {
|
|
202
184
|
headers,
|
|
203
185
|
method: "POST",
|
|
204
|
-
body,
|
|
186
|
+
body: body,
|
|
205
187
|
...(credentials ? { credentials } : undefined),
|
|
206
188
|
signal,
|
|
207
189
|
};
|
|
208
|
-
|
|
209
190
|
return { url, info };
|
|
210
191
|
}
|
|
211
192
|
|
|
@@ -14,33 +14,84 @@
|
|
|
14
14
|
*
|
|
15
15
|
* Thanks!
|
|
16
16
|
*/
|
|
17
|
-
import
|
|
17
|
+
import { InferenceOutputError } from "../lib/InferenceOutputError";
|
|
18
|
+
import type { BodyParams, HeaderParams, UrlParams } from "../types";
|
|
19
|
+
import { delay } from "../utils/delay";
|
|
20
|
+
import { omit } from "../utils/omit";
|
|
21
|
+
import { TaskProviderHelper, type TextToImageTaskHelper } from "./providerHelper";
|
|
18
22
|
|
|
19
23
|
const BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
|
|
24
|
+
interface BlackForestLabsResponse {
|
|
25
|
+
id: string;
|
|
26
|
+
polling_url: string;
|
|
27
|
+
}
|
|
20
28
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
29
|
+
export class BlackForestLabsTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
|
|
30
|
+
constructor() {
|
|
31
|
+
super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
|
|
32
|
+
}
|
|
24
33
|
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
34
|
+
preparePayload(params: BodyParams): Record<string, unknown> {
|
|
35
|
+
return {
|
|
36
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
37
|
+
...(params.args.parameters as Record<string, unknown>),
|
|
38
|
+
prompt: params.args.inputs,
|
|
39
|
+
};
|
|
40
|
+
}
|
|
28
41
|
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
42
|
+
override prepareHeaders(params: HeaderParams, binary: boolean): Record<string, string> {
|
|
43
|
+
const headers: Record<string, string> = {
|
|
44
|
+
Authorization:
|
|
45
|
+
params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`,
|
|
46
|
+
};
|
|
47
|
+
if (!binary) {
|
|
48
|
+
headers["Content-Type"] = "application/json";
|
|
49
|
+
}
|
|
50
|
+
return headers;
|
|
34
51
|
}
|
|
35
|
-
};
|
|
36
52
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
53
|
+
makeRoute(params: UrlParams): string {
|
|
54
|
+
if (!params) {
|
|
55
|
+
throw new Error("Params are required");
|
|
56
|
+
}
|
|
57
|
+
return `/v1/${params.model}`;
|
|
58
|
+
}
|
|
40
59
|
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
60
|
+
async getResponse(
|
|
61
|
+
response: BlackForestLabsResponse,
|
|
62
|
+
url?: string,
|
|
63
|
+
headers?: HeadersInit,
|
|
64
|
+
outputType?: "url" | "blob"
|
|
65
|
+
): Promise<string | Blob> {
|
|
66
|
+
const urlObj = new URL(response.polling_url);
|
|
67
|
+
for (let step = 0; step < 5; step++) {
|
|
68
|
+
await delay(1000);
|
|
69
|
+
console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
|
|
70
|
+
urlObj.searchParams.set("attempt", step.toString(10));
|
|
71
|
+
const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
|
|
72
|
+
if (!resp.ok) {
|
|
73
|
+
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
|
|
74
|
+
}
|
|
75
|
+
const payload = await resp.json();
|
|
76
|
+
if (
|
|
77
|
+
typeof payload === "object" &&
|
|
78
|
+
payload &&
|
|
79
|
+
"status" in payload &&
|
|
80
|
+
typeof payload.status === "string" &&
|
|
81
|
+
payload.status === "Ready" &&
|
|
82
|
+
"result" in payload &&
|
|
83
|
+
typeof payload.result === "object" &&
|
|
84
|
+
payload.result &&
|
|
85
|
+
"sample" in payload.result &&
|
|
86
|
+
typeof payload.result.sample === "string"
|
|
87
|
+
) {
|
|
88
|
+
if (outputType === "url") {
|
|
89
|
+
return payload.result.sample;
|
|
90
|
+
}
|
|
91
|
+
const image = await fetch(payload.result.sample);
|
|
92
|
+
return await image.blob();
|
|
93
|
+
}
|
|
94
|
+
}
|
|
95
|
+
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
|
|
96
|
+
}
|
|
97
|
+
}
|
|
@@ -14,32 +14,11 @@
|
|
|
14
14
|
*
|
|
15
15
|
* Thanks!
|
|
16
16
|
*/
|
|
17
|
-
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
|
|
18
17
|
|
|
19
|
-
|
|
18
|
+
import { BaseConversationalTask } from "./providerHelper";
|
|
20
19
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
return {
|
|
27
|
-
...params.args,
|
|
28
|
-
model: params.model,
|
|
29
|
-
};
|
|
30
|
-
};
|
|
31
|
-
|
|
32
|
-
const makeHeaders = (params: HeaderParams): Record<string, string> => {
|
|
33
|
-
return { Authorization: `Bearer ${params.accessToken}` };
|
|
34
|
-
};
|
|
35
|
-
|
|
36
|
-
const makeUrl = (params: UrlParams): string => {
|
|
37
|
-
return `${params.baseUrl}/v1/chat/completions`;
|
|
38
|
-
};
|
|
39
|
-
|
|
40
|
-
export const CEREBRAS_CONFIG: ProviderConfig = {
|
|
41
|
-
makeBaseUrl,
|
|
42
|
-
makeBody,
|
|
43
|
-
makeHeaders,
|
|
44
|
-
makeUrl,
|
|
45
|
-
};
|
|
20
|
+
export class CerebrasConversationalTask extends BaseConversationalTask {
|
|
21
|
+
constructor() {
|
|
22
|
+
super("cerebras", "https://api.cerebras.ai");
|
|
23
|
+
}
|
|
24
|
+
}
|
package/src/providers/cohere.ts
CHANGED
|
@@ -14,32 +14,13 @@
|
|
|
14
14
|
*
|
|
15
15
|
* Thanks!
|
|
16
16
|
*/
|
|
17
|
-
import
|
|
17
|
+
import { BaseConversationalTask } from "./providerHelper";
|
|
18
18
|
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
...params.args,
|
|
28
|
-
model: params.model,
|
|
29
|
-
};
|
|
30
|
-
};
|
|
31
|
-
|
|
32
|
-
const makeHeaders = (params: HeaderParams): Record<string, string> => {
|
|
33
|
-
return { Authorization: `Bearer ${params.accessToken}` };
|
|
34
|
-
};
|
|
35
|
-
|
|
36
|
-
const makeUrl = (params: UrlParams): string => {
|
|
37
|
-
return `${params.baseUrl}/compatibility/v1/chat/completions`;
|
|
38
|
-
};
|
|
39
|
-
|
|
40
|
-
export const COHERE_CONFIG: ProviderConfig = {
|
|
41
|
-
makeBaseUrl,
|
|
42
|
-
makeBody,
|
|
43
|
-
makeHeaders,
|
|
44
|
-
makeUrl,
|
|
45
|
-
};
|
|
19
|
+
export class CohereConversationalTask extends BaseConversationalTask {
|
|
20
|
+
constructor() {
|
|
21
|
+
super("cohere", "https://api.cohere.com");
|
|
22
|
+
}
|
|
23
|
+
override makeRoute(): string {
|
|
24
|
+
return "/compatibility/v1/chat/completions";
|
|
25
|
+
}
|
|
26
|
+
}
|
package/src/providers/consts.ts
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
|
+
import type { InferenceProviderModelMapping } from "../lib/getInferenceProviderMapping";
|
|
1
2
|
import type { InferenceProvider } from "../types";
|
|
2
3
|
import { type ModelId } from "../types";
|
|
3
4
|
|
|
4
|
-
type ProviderId = string;
|
|
5
5
|
/**
|
|
6
6
|
* If you want to try to run inference for a new model locally before it's registered on huggingface.co
|
|
7
7
|
* for a given Inference Provider,
|
|
@@ -9,7 +9,10 @@ type ProviderId = string;
|
|
|
9
9
|
*
|
|
10
10
|
* We also inject into this dictionary from tests.
|
|
11
11
|
*/
|
|
12
|
-
export const
|
|
12
|
+
export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
|
|
13
|
+
InferenceProvider,
|
|
14
|
+
Record<ModelId, InferenceProviderModelMapping>
|
|
15
|
+
> = {
|
|
13
16
|
/**
|
|
14
17
|
* "HF model ID" => "Model ID on Inference Provider's side"
|
|
15
18
|
*
|