@huggingface/inference 3.7.0 → 3.7.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.cjs +1152 -839
- package/dist/index.js +1154 -841
- package/dist/src/lib/getProviderHelper.d.ts +37 -0
- package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
- package/dist/src/lib/makeRequestOptions.d.ts +0 -2
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/black-forest-labs.d.ts +14 -18
- package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
- package/dist/src/providers/cerebras.d.ts +4 -2
- package/dist/src/providers/cerebras.d.ts.map +1 -1
- package/dist/src/providers/cohere.d.ts +5 -2
- package/dist/src/providers/cohere.d.ts.map +1 -1
- package/dist/src/providers/fal-ai.d.ts +50 -3
- package/dist/src/providers/fal-ai.d.ts.map +1 -1
- package/dist/src/providers/fireworks-ai.d.ts +5 -2
- package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
- package/dist/src/providers/hf-inference.d.ts +125 -2
- package/dist/src/providers/hf-inference.d.ts.map +1 -1
- package/dist/src/providers/hyperbolic.d.ts +31 -2
- package/dist/src/providers/hyperbolic.d.ts.map +1 -1
- package/dist/src/providers/nebius.d.ts +20 -18
- package/dist/src/providers/nebius.d.ts.map +1 -1
- package/dist/src/providers/novita.d.ts +21 -18
- package/dist/src/providers/novita.d.ts.map +1 -1
- package/dist/src/providers/openai.d.ts +4 -2
- package/dist/src/providers/openai.d.ts.map +1 -1
- package/dist/src/providers/providerHelper.d.ts +182 -0
- package/dist/src/providers/providerHelper.d.ts.map +1 -0
- package/dist/src/providers/replicate.d.ts +23 -19
- package/dist/src/providers/replicate.d.ts.map +1 -1
- package/dist/src/providers/sambanova.d.ts +4 -2
- package/dist/src/providers/sambanova.d.ts.map +1 -1
- package/dist/src/providers/together.d.ts +32 -2
- package/dist/src/providers/together.d.ts.map +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/audio/utils.d.ts +2 -1
- package/dist/src/tasks/audio/utils.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts +0 -2
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts +0 -2
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/index.d.ts +6 -6
- package/dist/src/tasks/index.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
- package/dist/src/types.d.ts +3 -13
- package/dist/src/types.d.ts.map +1 -1
- package/package.json +3 -3
- package/src/lib/getProviderHelper.ts +270 -0
- package/src/lib/makeRequestOptions.ts +34 -91
- package/src/providers/black-forest-labs.ts +73 -22
- package/src/providers/cerebras.ts +6 -27
- package/src/providers/cohere.ts +9 -28
- package/src/providers/fal-ai.ts +195 -77
- package/src/providers/fireworks-ai.ts +8 -29
- package/src/providers/hf-inference.ts +555 -34
- package/src/providers/hyperbolic.ts +107 -29
- package/src/providers/nebius.ts +65 -29
- package/src/providers/novita.ts +68 -32
- package/src/providers/openai.ts +6 -32
- package/src/providers/providerHelper.ts +354 -0
- package/src/providers/replicate.ts +124 -34
- package/src/providers/sambanova.ts +5 -30
- package/src/providers/together.ts +92 -28
- package/src/snippets/getInferenceSnippets.ts +16 -9
- package/src/snippets/templates.exported.ts +1 -1
- package/src/tasks/audio/audioClassification.ts +4 -7
- package/src/tasks/audio/audioToAudio.ts +3 -26
- package/src/tasks/audio/automaticSpeechRecognition.ts +4 -3
- package/src/tasks/audio/textToSpeech.ts +5 -29
- package/src/tasks/audio/utils.ts +2 -1
- package/src/tasks/custom/request.ts +0 -2
- package/src/tasks/custom/streamingRequest.ts +0 -2
- package/src/tasks/cv/imageClassification.ts +3 -7
- package/src/tasks/cv/imageSegmentation.ts +3 -8
- package/src/tasks/cv/imageToImage.ts +3 -6
- package/src/tasks/cv/imageToText.ts +3 -6
- package/src/tasks/cv/objectDetection.ts +3 -18
- package/src/tasks/cv/textToImage.ts +9 -137
- package/src/tasks/cv/textToVideo.ts +11 -62
- package/src/tasks/cv/zeroShotImageClassification.ts +3 -7
- package/src/tasks/index.ts +6 -6
- package/src/tasks/multimodal/documentQuestionAnswering.ts +3 -19
- package/src/tasks/multimodal/visualQuestionAnswering.ts +3 -11
- package/src/tasks/nlp/chatCompletion.ts +5 -20
- package/src/tasks/nlp/chatCompletionStream.ts +1 -2
- package/src/tasks/nlp/featureExtraction.ts +3 -18
- package/src/tasks/nlp/fillMask.ts +3 -16
- package/src/tasks/nlp/questionAnswering.ts +3 -22
- package/src/tasks/nlp/sentenceSimilarity.ts +3 -7
- package/src/tasks/nlp/summarization.ts +3 -6
- package/src/tasks/nlp/tableQuestionAnswering.ts +3 -27
- package/src/tasks/nlp/textClassification.ts +3 -8
- package/src/tasks/nlp/textGeneration.ts +12 -79
- package/src/tasks/nlp/tokenClassification.ts +3 -18
- package/src/tasks/nlp/translation.ts +3 -6
- package/src/tasks/nlp/zeroShotClassification.ts +3 -16
- package/src/tasks/tabular/tabularClassification.ts +3 -6
- package/src/tasks/tabular/tabularRegression.ts +3 -6
- package/src/types.ts +3 -14
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"tabularRegression.d.ts","sourceRoot":"","sources":["../../../../src/tasks/tabular/tabularRegression.ts"],"names":[],"mappings":"AACA,OAAO,KAAK,EAAE,QAAQ,EAAE,OAAO,EAAE,MAAM,aAAa,CAAC;AAGrD,MAAM,MAAM,qBAAqB,GAAG,QAAQ,GAAG;IAC9C,MAAM,EAAE;QACP;;WAEG;QACH,IAAI,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,EAAE,CAAC,CAAC;KAC/B,CAAC;CACF,CAAC;AAEF;;GAEG;AACH,MAAM,MAAM,uBAAuB,GAAG,MAAM,EAAE,CAAC;AAE/C;;;;GAIG;AACH,wBAAsB,iBAAiB,CACtC,IAAI,EAAE,qBAAqB,EAC3B,OAAO,CAAC,EAAE,OAAO,GACf,OAAO,CAAC,uBAAuB,CAAC,
|
|
1
|
+
{"version":3,"file":"tabularRegression.d.ts","sourceRoot":"","sources":["../../../../src/tasks/tabular/tabularRegression.ts"],"names":[],"mappings":"AACA,OAAO,KAAK,EAAE,QAAQ,EAAE,OAAO,EAAE,MAAM,aAAa,CAAC;AAGrD,MAAM,MAAM,qBAAqB,GAAG,QAAQ,GAAG;IAC9C,MAAM,EAAE;QACP;;WAEG;QACH,IAAI,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,EAAE,CAAC,CAAC;KAC/B,CAAC;CACF,CAAC;AAEF;;GAEG;AACH,MAAM,MAAM,uBAAuB,GAAG,MAAM,EAAE,CAAC;AAE/C;;;;GAIG;AACH,wBAAsB,iBAAiB,CACtC,IAAI,EAAE,qBAAqB,EAC3B,OAAO,CAAC,EAAE,OAAO,GACf,OAAO,CAAC,uBAAuB,CAAC,CAOlC"}
|
package/dist/src/types.d.ts
CHANGED
|
@@ -28,7 +28,7 @@ export interface Options {
|
|
|
28
28
|
*/
|
|
29
29
|
billTo?: string;
|
|
30
30
|
}
|
|
31
|
-
export type InferenceTask = Exclude<PipelineType, "other"
|
|
31
|
+
export type InferenceTask = Exclude<PipelineType, "other"> | "conversational";
|
|
32
32
|
export declare const INFERENCE_PROVIDERS: readonly ["black-forest-labs", "cerebras", "cohere", "fal-ai", "fireworks-ai", "hf-inference", "hyperbolic", "nebius", "novita", "openai", "replicate", "sambanova", "together"];
|
|
33
33
|
export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
|
|
34
34
|
export interface BaseArgs {
|
|
@@ -75,13 +75,6 @@ export type RequestArgs = BaseArgs & ({
|
|
|
75
75
|
} | ChatCompletionInput) & {
|
|
76
76
|
parameters?: Record<string, unknown>;
|
|
77
77
|
};
|
|
78
|
-
export interface ProviderConfig {
|
|
79
|
-
makeBaseUrl: ((task?: InferenceTask) => string) | (() => string);
|
|
80
|
-
makeBody: (params: BodyParams) => Record<string, unknown>;
|
|
81
|
-
makeHeaders: (params: HeaderParams) => Record<string, string>;
|
|
82
|
-
makeUrl: (params: UrlParams) => string;
|
|
83
|
-
clientSideRoutingOnly?: boolean;
|
|
84
|
-
}
|
|
85
78
|
export type AuthMethod = "none" | "hf-token" | "credentials-include" | "provider-key";
|
|
86
79
|
export interface HeaderParams {
|
|
87
80
|
accessToken?: string;
|
|
@@ -89,14 +82,11 @@ export interface HeaderParams {
|
|
|
89
82
|
}
|
|
90
83
|
export interface UrlParams {
|
|
91
84
|
authMethod: AuthMethod;
|
|
92
|
-
baseUrl: string;
|
|
93
85
|
model: string;
|
|
94
86
|
task?: InferenceTask;
|
|
95
|
-
chatCompletion?: boolean;
|
|
96
87
|
}
|
|
97
|
-
export interface BodyParams {
|
|
98
|
-
args:
|
|
99
|
-
chatCompletion?: boolean;
|
|
88
|
+
export interface BodyParams<T extends Record<string, unknown> = Record<string, unknown>> {
|
|
89
|
+
args: T;
|
|
100
90
|
model: string;
|
|
101
91
|
task?: InferenceTask;
|
|
102
92
|
}
|
package/dist/src/types.d.ts.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"types.d.ts","sourceRoot":"","sources":["../../src/types.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,mBAAmB,EAAE,YAAY,EAAE,MAAM,oBAAoB,CAAC;AAE5E;;GAEG;AACH,MAAM,MAAM,OAAO,GAAG,MAAM,CAAC;AAE7B,MAAM,WAAW,OAAO;IACvB;;OAEG;IACH,cAAc,CAAC,EAAE,OAAO,CAAC;IAEzB;;OAEG;IACH,KAAK,CAAC,EAAE,OAAO,KAAK,CAAC;IACrB;;OAEG;IACH,MAAM,CAAC,EAAE,WAAW,CAAC;IAErB;;OAEG;IACH,kBAAkB,CAAC,EAAE,MAAM,GAAG,OAAO,CAAC;IAEtC;;;;;OAKG;IACH,MAAM,CAAC,EAAE,MAAM,CAAC;CAChB;AAED,MAAM,MAAM,aAAa,GAAG,OAAO,CAAC,YAAY,EAAE,OAAO,CAAC,CAAC;
|
|
1
|
+
{"version":3,"file":"types.d.ts","sourceRoot":"","sources":["../../src/types.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,mBAAmB,EAAE,YAAY,EAAE,MAAM,oBAAoB,CAAC;AAE5E;;GAEG;AACH,MAAM,MAAM,OAAO,GAAG,MAAM,CAAC;AAE7B,MAAM,WAAW,OAAO;IACvB;;OAEG;IACH,cAAc,CAAC,EAAE,OAAO,CAAC;IAEzB;;OAEG;IACH,KAAK,CAAC,EAAE,OAAO,KAAK,CAAC;IACrB;;OAEG;IACH,MAAM,CAAC,EAAE,WAAW,CAAC;IAErB;;OAEG;IACH,kBAAkB,CAAC,EAAE,MAAM,GAAG,OAAO,CAAC;IAEtC;;;;;OAKG;IACH,MAAM,CAAC,EAAE,MAAM,CAAC;CAChB;AAED,MAAM,MAAM,aAAa,GAAG,OAAO,CAAC,YAAY,EAAE,OAAO,CAAC,GAAG,gBAAgB,CAAC;AAE9E,eAAO,MAAM,mBAAmB,kLActB,CAAC;AAEX,MAAM,MAAM,iBAAiB,GAAG,CAAC,OAAO,mBAAmB,CAAC,CAAC,MAAM,CAAC,CAAC;AAErE,MAAM,WAAW,QAAQ;IACxB;;;;;;OAMG;IACH,WAAW,CAAC,EAAE,MAAM,CAAC;IAErB;;;;;;;OAOG;IACH,KAAK,CAAC,EAAE,OAAO,CAAC;IAEhB;;;;OAIG;IACH,WAAW,CAAC,EAAE,MAAM,CAAC;IAErB;;;;OAIG;IACH,QAAQ,CAAC,EAAE,iBAAiB,CAAC;CAC7B;AAED,MAAM,MAAM,WAAW,GAAG,QAAQ,GACjC,CACG;IAAE,IAAI,EAAE,IAAI,GAAG,WAAW,CAAA;CAAE,GAC5B;IAAE,MAAM,EAAE,OAAO,CAAA;CAAE,GACnB;IAAE,MAAM,EAAE,MAAM,CAAA;CAAE,GAClB;IAAE,IAAI,EAAE,MAAM,CAAA;CAAE,GAChB;IAAE,SAAS,EAAE,MAAM,CAAA;CAAE,GACrB,mBAAmB,CACrB,GAAG;IACH,UAAU,CAAC,EAAE,MAAM,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;CACrC,CAAC;AAEH,MAAM,MAAM,UAAU,GAAG,MAAM,GAAG,UAAU,GAAG,qBAAqB,GAAG,cAAc,CAAC;AAEtF,MAAM,WAAW,YAAY;IAC5B,WAAW,CAAC,EAAE,MAAM,CAAC;IACrB,UAAU,EAAE,UAAU,CAAC;CACvB;AAED,MAAM,WAAW,SAAS;IACzB,UAAU,EAAE,UAAU,CAAC;IACvB,KAAK,EAAE,MAAM,CAAC;IACd,IAAI,CAAC,EAAE,aAAa,CAAC;CACrB;AAED,MAAM,WAAW,UAAU,CAAC,CAAC,SAAS,MAAM,CAAC,MAAM,EAAE,OAAO,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,OAAO,CAAC;IACtF,IAAI,EAAE,CAAC,CAAC;IACR,KAAK,EAAE,MAAM,CAAC;IACd,IAAI,CAAC,EAAE,aAAa,CAAC;CACrB"}
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@huggingface/inference",
|
|
3
|
-
"version": "3.7.
|
|
3
|
+
"version": "3.7.1",
|
|
4
4
|
"packageManager": "pnpm@8.10.5",
|
|
5
5
|
"license": "MIT",
|
|
6
6
|
"author": "Hugging Face and Tim Mikeladze <tim.mikeladze@gmail.com>",
|
|
@@ -40,8 +40,8 @@
|
|
|
40
40
|
},
|
|
41
41
|
"type": "module",
|
|
42
42
|
"dependencies": {
|
|
43
|
-
"@huggingface/
|
|
44
|
-
"@huggingface/
|
|
43
|
+
"@huggingface/jinja": "^0.3.3",
|
|
44
|
+
"@huggingface/tasks": "^0.18.7"
|
|
45
45
|
},
|
|
46
46
|
"devDependencies": {
|
|
47
47
|
"@types/node": "18.13.0"
|
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
import * as BlackForestLabs from "../providers/black-forest-labs";
|
|
2
|
+
import * as Cerebras from "../providers/cerebras";
|
|
3
|
+
import * as Cohere from "../providers/cohere";
|
|
4
|
+
import * as FalAI from "../providers/fal-ai";
|
|
5
|
+
import * as Fireworks from "../providers/fireworks-ai";
|
|
6
|
+
import * as HFInference from "../providers/hf-inference";
|
|
7
|
+
|
|
8
|
+
import * as Hyperbolic from "../providers/hyperbolic";
|
|
9
|
+
import * as Nebius from "../providers/nebius";
|
|
10
|
+
import * as Novita from "../providers/novita";
|
|
11
|
+
import * as OpenAI from "../providers/openai";
|
|
12
|
+
import type {
|
|
13
|
+
AudioClassificationTaskHelper,
|
|
14
|
+
AudioToAudioTaskHelper,
|
|
15
|
+
AutomaticSpeechRecognitionTaskHelper,
|
|
16
|
+
ConversationalTaskHelper,
|
|
17
|
+
DocumentQuestionAnsweringTaskHelper,
|
|
18
|
+
FeatureExtractionTaskHelper,
|
|
19
|
+
FillMaskTaskHelper,
|
|
20
|
+
ImageClassificationTaskHelper,
|
|
21
|
+
ImageSegmentationTaskHelper,
|
|
22
|
+
ImageToImageTaskHelper,
|
|
23
|
+
ImageToTextTaskHelper,
|
|
24
|
+
ObjectDetectionTaskHelper,
|
|
25
|
+
QuestionAnsweringTaskHelper,
|
|
26
|
+
SentenceSimilarityTaskHelper,
|
|
27
|
+
SummarizationTaskHelper,
|
|
28
|
+
TableQuestionAnsweringTaskHelper,
|
|
29
|
+
TabularClassificationTaskHelper,
|
|
30
|
+
TabularRegressionTaskHelper,
|
|
31
|
+
TaskProviderHelper,
|
|
32
|
+
TextClassificationTaskHelper,
|
|
33
|
+
TextGenerationTaskHelper,
|
|
34
|
+
TextToAudioTaskHelper,
|
|
35
|
+
TextToImageTaskHelper,
|
|
36
|
+
TextToSpeechTaskHelper,
|
|
37
|
+
TextToVideoTaskHelper,
|
|
38
|
+
TokenClassificationTaskHelper,
|
|
39
|
+
TranslationTaskHelper,
|
|
40
|
+
VisualQuestionAnsweringTaskHelper,
|
|
41
|
+
ZeroShotClassificationTaskHelper,
|
|
42
|
+
ZeroShotImageClassificationTaskHelper,
|
|
43
|
+
} from "../providers/providerHelper";
|
|
44
|
+
import * as Replicate from "../providers/replicate";
|
|
45
|
+
import * as Sambanova from "../providers/sambanova";
|
|
46
|
+
import * as Together from "../providers/together";
|
|
47
|
+
import type { InferenceProvider, InferenceTask } from "../types";
|
|
48
|
+
|
|
49
|
+
export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
|
|
50
|
+
"black-forest-labs": {
|
|
51
|
+
"text-to-image": new BlackForestLabs.BlackForestLabsTextToImageTask(),
|
|
52
|
+
},
|
|
53
|
+
cerebras: {
|
|
54
|
+
conversational: new Cerebras.CerebrasConversationalTask(),
|
|
55
|
+
},
|
|
56
|
+
cohere: {
|
|
57
|
+
conversational: new Cohere.CohereConversationalTask(),
|
|
58
|
+
},
|
|
59
|
+
"fal-ai": {
|
|
60
|
+
"text-to-image": new FalAI.FalAITextToImageTask(),
|
|
61
|
+
"text-to-speech": new FalAI.FalAITextToSpeechTask(),
|
|
62
|
+
"text-to-video": new FalAI.FalAITextToVideoTask(),
|
|
63
|
+
"automatic-speech-recognition": new FalAI.FalAIAutomaticSpeechRecognitionTask(),
|
|
64
|
+
},
|
|
65
|
+
"hf-inference": {
|
|
66
|
+
"text-to-image": new HFInference.HFInferenceTextToImageTask(),
|
|
67
|
+
conversational: new HFInference.HFInferenceConversationalTask(),
|
|
68
|
+
"text-generation": new HFInference.HFInferenceTextGenerationTask(),
|
|
69
|
+
"text-classification": new HFInference.HFInferenceTextClassificationTask(),
|
|
70
|
+
"question-answering": new HFInference.HFInferenceQuestionAnsweringTask(),
|
|
71
|
+
"audio-classification": new HFInference.HFInferenceAudioClassificationTask(),
|
|
72
|
+
"automatic-speech-recognition": new HFInference.HFInferenceAutomaticSpeechRecognitionTask(),
|
|
73
|
+
"fill-mask": new HFInference.HFInferenceFillMaskTask(),
|
|
74
|
+
"feature-extraction": new HFInference.HFInferenceFeatureExtractionTask(),
|
|
75
|
+
"image-classification": new HFInference.HFInferenceImageClassificationTask(),
|
|
76
|
+
"image-segmentation": new HFInference.HFInferenceImageSegmentationTask(),
|
|
77
|
+
"document-question-answering": new HFInference.HFInferenceDocumentQuestionAnsweringTask(),
|
|
78
|
+
"image-to-text": new HFInference.HFInferenceImageToTextTask(),
|
|
79
|
+
"object-detection": new HFInference.HFInferenceObjectDetectionTask(),
|
|
80
|
+
"audio-to-audio": new HFInference.HFInferenceAudioToAudioTask(),
|
|
81
|
+
"zero-shot-image-classification": new HFInference.HFInferenceZeroShotImageClassificationTask(),
|
|
82
|
+
"zero-shot-classification": new HFInference.HFInferenceZeroShotClassificationTask(),
|
|
83
|
+
"image-to-image": new HFInference.HFInferenceImageToImageTask(),
|
|
84
|
+
"sentence-similarity": new HFInference.HFInferenceSentenceSimilarityTask(),
|
|
85
|
+
"table-question-answering": new HFInference.HFInferenceTableQuestionAnsweringTask(),
|
|
86
|
+
"tabular-classification": new HFInference.HFInferenceTabularClassificationTask(),
|
|
87
|
+
"text-to-speech": new HFInference.HFInferenceTextToSpeechTask(),
|
|
88
|
+
"token-classification": new HFInference.HFInferenceTokenClassificationTask(),
|
|
89
|
+
translation: new HFInference.HFInferenceTranslationTask(),
|
|
90
|
+
summarization: new HFInference.HFInferenceSummarizationTask(),
|
|
91
|
+
"visual-question-answering": new HFInference.HFInferenceVisualQuestionAnsweringTask(),
|
|
92
|
+
"tabular-regression": new HFInference.HFInferenceTabularRegressionTask(),
|
|
93
|
+
"text-to-audio": new HFInference.HFInferenceTextToAudioTask(),
|
|
94
|
+
},
|
|
95
|
+
"fireworks-ai": {
|
|
96
|
+
conversational: new Fireworks.FireworksConversationalTask(),
|
|
97
|
+
},
|
|
98
|
+
hyperbolic: {
|
|
99
|
+
"text-to-image": new Hyperbolic.HyperbolicTextToImageTask(),
|
|
100
|
+
conversational: new Hyperbolic.HyperbolicConversationalTask(),
|
|
101
|
+
"text-generation": new Hyperbolic.HyperbolicTextGenerationTask(),
|
|
102
|
+
},
|
|
103
|
+
nebius: {
|
|
104
|
+
"text-to-image": new Nebius.NebiusTextToImageTask(),
|
|
105
|
+
conversational: new Nebius.NebiusConversationalTask(),
|
|
106
|
+
"text-generation": new Nebius.NebiusTextGenerationTask(),
|
|
107
|
+
},
|
|
108
|
+
novita: {
|
|
109
|
+
conversational: new Novita.NovitaConversationalTask(),
|
|
110
|
+
"text-generation": new Novita.NovitaTextGenerationTask(),
|
|
111
|
+
},
|
|
112
|
+
openai: {
|
|
113
|
+
conversational: new OpenAI.OpenAIConversationalTask(),
|
|
114
|
+
},
|
|
115
|
+
replicate: {
|
|
116
|
+
"text-to-image": new Replicate.ReplicateTextToImageTask(),
|
|
117
|
+
"text-to-speech": new Replicate.ReplicateTextToSpeechTask(),
|
|
118
|
+
"text-to-video": new Replicate.ReplicateTextToVideoTask(),
|
|
119
|
+
},
|
|
120
|
+
sambanova: {
|
|
121
|
+
conversational: new Sambanova.SambanovaConversationalTask(),
|
|
122
|
+
},
|
|
123
|
+
together: {
|
|
124
|
+
"text-to-image": new Together.TogetherTextToImageTask(),
|
|
125
|
+
conversational: new Together.TogetherConversationalTask(),
|
|
126
|
+
"text-generation": new Together.TogetherTextGenerationTask(),
|
|
127
|
+
},
|
|
128
|
+
};
|
|
129
|
+
|
|
130
|
+
/**
|
|
131
|
+
* Get provider helper instance by name and task
|
|
132
|
+
*/
|
|
133
|
+
export function getProviderHelper(
|
|
134
|
+
provider: InferenceProvider,
|
|
135
|
+
task: "text-to-image"
|
|
136
|
+
): TextToImageTaskHelper & TaskProviderHelper;
|
|
137
|
+
export function getProviderHelper(
|
|
138
|
+
provider: InferenceProvider,
|
|
139
|
+
task: "conversational"
|
|
140
|
+
): ConversationalTaskHelper & TaskProviderHelper;
|
|
141
|
+
export function getProviderHelper(
|
|
142
|
+
provider: InferenceProvider,
|
|
143
|
+
task: "text-generation"
|
|
144
|
+
): TextGenerationTaskHelper & TaskProviderHelper;
|
|
145
|
+
export function getProviderHelper(
|
|
146
|
+
provider: InferenceProvider,
|
|
147
|
+
task: "text-to-speech"
|
|
148
|
+
): TextToSpeechTaskHelper & TaskProviderHelper;
|
|
149
|
+
export function getProviderHelper(
|
|
150
|
+
provider: InferenceProvider,
|
|
151
|
+
task: "text-to-audio"
|
|
152
|
+
): TextToAudioTaskHelper & TaskProviderHelper;
|
|
153
|
+
export function getProviderHelper(
|
|
154
|
+
provider: InferenceProvider,
|
|
155
|
+
task: "automatic-speech-recognition"
|
|
156
|
+
): AutomaticSpeechRecognitionTaskHelper & TaskProviderHelper;
|
|
157
|
+
export function getProviderHelper(
|
|
158
|
+
provider: InferenceProvider,
|
|
159
|
+
task: "text-to-video"
|
|
160
|
+
): TextToVideoTaskHelper & TaskProviderHelper;
|
|
161
|
+
export function getProviderHelper(
|
|
162
|
+
provider: InferenceProvider,
|
|
163
|
+
task: "text-classification"
|
|
164
|
+
): TextClassificationTaskHelper & TaskProviderHelper;
|
|
165
|
+
export function getProviderHelper(
|
|
166
|
+
provider: InferenceProvider,
|
|
167
|
+
task: "question-answering"
|
|
168
|
+
): QuestionAnsweringTaskHelper & TaskProviderHelper;
|
|
169
|
+
export function getProviderHelper(
|
|
170
|
+
provider: InferenceProvider,
|
|
171
|
+
task: "audio-classification"
|
|
172
|
+
): AudioClassificationTaskHelper & TaskProviderHelper;
|
|
173
|
+
export function getProviderHelper(
|
|
174
|
+
provider: InferenceProvider,
|
|
175
|
+
task: "audio-to-audio"
|
|
176
|
+
): AudioToAudioTaskHelper & TaskProviderHelper;
|
|
177
|
+
export function getProviderHelper(
|
|
178
|
+
provider: InferenceProvider,
|
|
179
|
+
task: "fill-mask"
|
|
180
|
+
): FillMaskTaskHelper & TaskProviderHelper;
|
|
181
|
+
export function getProviderHelper(
|
|
182
|
+
provider: InferenceProvider,
|
|
183
|
+
task: "feature-extraction"
|
|
184
|
+
): FeatureExtractionTaskHelper & TaskProviderHelper;
|
|
185
|
+
export function getProviderHelper(
|
|
186
|
+
provider: InferenceProvider,
|
|
187
|
+
task: "image-classification"
|
|
188
|
+
): ImageClassificationTaskHelper & TaskProviderHelper;
|
|
189
|
+
export function getProviderHelper(
|
|
190
|
+
provider: InferenceProvider,
|
|
191
|
+
task: "image-segmentation"
|
|
192
|
+
): ImageSegmentationTaskHelper & TaskProviderHelper;
|
|
193
|
+
export function getProviderHelper(
|
|
194
|
+
provider: InferenceProvider,
|
|
195
|
+
task: "document-question-answering"
|
|
196
|
+
): DocumentQuestionAnsweringTaskHelper & TaskProviderHelper;
|
|
197
|
+
export function getProviderHelper(
|
|
198
|
+
provider: InferenceProvider,
|
|
199
|
+
task: "image-to-text"
|
|
200
|
+
): ImageToTextTaskHelper & TaskProviderHelper;
|
|
201
|
+
export function getProviderHelper(
|
|
202
|
+
provider: InferenceProvider,
|
|
203
|
+
task: "object-detection"
|
|
204
|
+
): ObjectDetectionTaskHelper & TaskProviderHelper;
|
|
205
|
+
export function getProviderHelper(
|
|
206
|
+
provider: InferenceProvider,
|
|
207
|
+
task: "zero-shot-image-classification"
|
|
208
|
+
): ZeroShotImageClassificationTaskHelper & TaskProviderHelper;
|
|
209
|
+
export function getProviderHelper(
|
|
210
|
+
provider: InferenceProvider,
|
|
211
|
+
task: "zero-shot-classification"
|
|
212
|
+
): ZeroShotClassificationTaskHelper & TaskProviderHelper;
|
|
213
|
+
export function getProviderHelper(
|
|
214
|
+
provider: InferenceProvider,
|
|
215
|
+
task: "image-to-image"
|
|
216
|
+
): ImageToImageTaskHelper & TaskProviderHelper;
|
|
217
|
+
export function getProviderHelper(
|
|
218
|
+
provider: InferenceProvider,
|
|
219
|
+
task: "sentence-similarity"
|
|
220
|
+
): SentenceSimilarityTaskHelper & TaskProviderHelper;
|
|
221
|
+
export function getProviderHelper(
|
|
222
|
+
provider: InferenceProvider,
|
|
223
|
+
task: "table-question-answering"
|
|
224
|
+
): TableQuestionAnsweringTaskHelper & TaskProviderHelper;
|
|
225
|
+
export function getProviderHelper(
|
|
226
|
+
provider: InferenceProvider,
|
|
227
|
+
task: "tabular-classification"
|
|
228
|
+
): TabularClassificationTaskHelper & TaskProviderHelper;
|
|
229
|
+
export function getProviderHelper(
|
|
230
|
+
provider: InferenceProvider,
|
|
231
|
+
task: "tabular-regression"
|
|
232
|
+
): TabularRegressionTaskHelper & TaskProviderHelper;
|
|
233
|
+
export function getProviderHelper(
|
|
234
|
+
provider: InferenceProvider,
|
|
235
|
+
task: "token-classification"
|
|
236
|
+
): TokenClassificationTaskHelper & TaskProviderHelper;
|
|
237
|
+
export function getProviderHelper(
|
|
238
|
+
provider: InferenceProvider,
|
|
239
|
+
task: "translation"
|
|
240
|
+
): TranslationTaskHelper & TaskProviderHelper;
|
|
241
|
+
export function getProviderHelper(
|
|
242
|
+
provider: InferenceProvider,
|
|
243
|
+
task: "summarization"
|
|
244
|
+
): SummarizationTaskHelper & TaskProviderHelper;
|
|
245
|
+
export function getProviderHelper(
|
|
246
|
+
provider: InferenceProvider,
|
|
247
|
+
task: "visual-question-answering"
|
|
248
|
+
): VisualQuestionAnsweringTaskHelper & TaskProviderHelper;
|
|
249
|
+
export function getProviderHelper(provider: InferenceProvider, task: InferenceTask | undefined): TaskProviderHelper;
|
|
250
|
+
|
|
251
|
+
export function getProviderHelper(provider: InferenceProvider, task: InferenceTask | undefined): TaskProviderHelper {
|
|
252
|
+
if (provider === "hf-inference") {
|
|
253
|
+
if (!task) {
|
|
254
|
+
return new HFInference.HFInferenceTask();
|
|
255
|
+
}
|
|
256
|
+
}
|
|
257
|
+
if (!task) {
|
|
258
|
+
throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'");
|
|
259
|
+
}
|
|
260
|
+
if (!(provider in PROVIDERS)) {
|
|
261
|
+
throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`);
|
|
262
|
+
}
|
|
263
|
+
const providerTasks = PROVIDERS[provider];
|
|
264
|
+
if (!providerTasks || !(task in providerTasks)) {
|
|
265
|
+
throw new Error(
|
|
266
|
+
`Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(providerTasks ?? {})}`
|
|
267
|
+
);
|
|
268
|
+
}
|
|
269
|
+
return providerTasks[task] as TaskProviderHelper;
|
|
270
|
+
}
|
|
@@ -1,23 +1,9 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import {
|
|
3
|
-
import {
|
|
4
|
-
import {
|
|
5
|
-
import { FAL_AI_CONFIG } from "../providers/fal-ai";
|
|
6
|
-
import { FIREWORKS_AI_CONFIG } from "../providers/fireworks-ai";
|
|
7
|
-
import { HF_INFERENCE_CONFIG } from "../providers/hf-inference";
|
|
8
|
-
import { HYPERBOLIC_CONFIG } from "../providers/hyperbolic";
|
|
9
|
-
import { NEBIUS_CONFIG } from "../providers/nebius";
|
|
10
|
-
import { NOVITA_CONFIG } from "../providers/novita";
|
|
11
|
-
import { REPLICATE_CONFIG } from "../providers/replicate";
|
|
12
|
-
import { SAMBANOVA_CONFIG } from "../providers/sambanova";
|
|
13
|
-
import { TOGETHER_CONFIG } from "../providers/together";
|
|
14
|
-
import { OPENAI_CONFIG } from "../providers/openai";
|
|
15
|
-
import type { InferenceProvider, InferenceTask, Options, ProviderConfig, RequestArgs } from "../types";
|
|
16
|
-
import { isUrl } from "./isUrl";
|
|
17
|
-
import { version as packageVersion, name as packageName } from "../../package.json";
|
|
1
|
+
import { name as packageName, version as packageVersion } from "../../package.json";
|
|
2
|
+
import { HF_HEADER_X_BILL_TO, HF_HUB_URL } from "../config";
|
|
3
|
+
import type { InferenceTask, Options, RequestArgs } from "../types";
|
|
4
|
+
import { getProviderHelper } from "./getProviderHelper";
|
|
18
5
|
import { getProviderModelId } from "./getProviderModelId";
|
|
19
|
-
|
|
20
|
-
const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
|
|
6
|
+
import { isUrl } from "./isUrl";
|
|
21
7
|
|
|
22
8
|
/**
|
|
23
9
|
* Lazy-loaded from huggingface.co/api/tasks when needed
|
|
@@ -25,25 +11,6 @@ const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
|
|
|
25
11
|
*/
|
|
26
12
|
let tasks: Record<string, { models: { id: string }[] }> | null = null;
|
|
27
13
|
|
|
28
|
-
/**
|
|
29
|
-
* Config to define how to serialize requests for each provider
|
|
30
|
-
*/
|
|
31
|
-
const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
|
|
32
|
-
"black-forest-labs": BLACK_FOREST_LABS_CONFIG,
|
|
33
|
-
cerebras: CEREBRAS_CONFIG,
|
|
34
|
-
cohere: COHERE_CONFIG,
|
|
35
|
-
"fal-ai": FAL_AI_CONFIG,
|
|
36
|
-
"fireworks-ai": FIREWORKS_AI_CONFIG,
|
|
37
|
-
"hf-inference": HF_INFERENCE_CONFIG,
|
|
38
|
-
hyperbolic: HYPERBOLIC_CONFIG,
|
|
39
|
-
openai: OPENAI_CONFIG,
|
|
40
|
-
nebius: NEBIUS_CONFIG,
|
|
41
|
-
novita: NOVITA_CONFIG,
|
|
42
|
-
replicate: REPLICATE_CONFIG,
|
|
43
|
-
sambanova: SAMBANOVA_CONFIG,
|
|
44
|
-
together: TOGETHER_CONFIG,
|
|
45
|
-
};
|
|
46
|
-
|
|
47
14
|
/**
|
|
48
15
|
* Helper that prepares request arguments.
|
|
49
16
|
* This async version handle the model ID resolution step.
|
|
@@ -56,14 +23,11 @@ export async function makeRequestOptions(
|
|
|
56
23
|
options?: Options & {
|
|
57
24
|
/** In most cases (unless we pass a endpointUrl) we know the task */
|
|
58
25
|
task?: InferenceTask;
|
|
59
|
-
chatCompletion?: boolean;
|
|
60
26
|
}
|
|
61
27
|
): Promise<{ url: string; info: RequestInit }> {
|
|
62
28
|
const { provider: maybeProvider, model: maybeModel } = args;
|
|
63
29
|
const provider = maybeProvider ?? "hf-inference";
|
|
64
|
-
const
|
|
65
|
-
const { task, chatCompletion } = options ?? {};
|
|
66
|
-
|
|
30
|
+
const { task } = options ?? {};
|
|
67
31
|
// Validate inputs
|
|
68
32
|
if (args.endpointUrl && provider !== "hf-inference") {
|
|
69
33
|
throw new Error(`Cannot use endpointUrl with a third-party provider.`);
|
|
@@ -74,21 +38,20 @@ export async function makeRequestOptions(
|
|
|
74
38
|
if (!maybeModel && !task) {
|
|
75
39
|
throw new Error("No model provided, and no task has been specified.");
|
|
76
40
|
}
|
|
77
|
-
if (!providerConfig) {
|
|
78
|
-
throw new Error(`No provider config found for provider ${provider}`);
|
|
79
|
-
}
|
|
80
|
-
if (providerConfig.clientSideRoutingOnly && !maybeModel) {
|
|
81
|
-
throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
|
|
82
|
-
}
|
|
83
41
|
|
|
84
42
|
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
|
85
43
|
const hfModel = maybeModel ?? (await loadDefaultModel(task!));
|
|
86
|
-
const
|
|
44
|
+
const providerHelper = getProviderHelper(provider, task);
|
|
45
|
+
|
|
46
|
+
if (providerHelper.clientSideRoutingOnly && !maybeModel) {
|
|
47
|
+
throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
const resolvedModel = providerHelper.clientSideRoutingOnly
|
|
87
51
|
? // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
|
88
52
|
removeProviderPrefix(maybeModel!, provider)
|
|
89
53
|
: await getProviderModelId({ model: hfModel, provider }, args, {
|
|
90
54
|
task,
|
|
91
|
-
chatCompletion,
|
|
92
55
|
fetch: options?.fetch,
|
|
93
56
|
});
|
|
94
57
|
|
|
@@ -108,19 +71,17 @@ export function makeRequestOptionsFromResolvedModel(
|
|
|
108
71
|
},
|
|
109
72
|
options?: Options & {
|
|
110
73
|
task?: InferenceTask;
|
|
111
|
-
chatCompletion?: boolean;
|
|
112
74
|
}
|
|
113
75
|
): { url: string; info: RequestInit } {
|
|
114
76
|
const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
|
|
115
77
|
void model;
|
|
116
78
|
|
|
117
79
|
const provider = maybeProvider ?? "hf-inference";
|
|
118
|
-
const providerConfig = providerConfigs[provider];
|
|
119
|
-
|
|
120
|
-
const { includeCredentials, task, chatCompletion, signal, billTo } = options ?? {};
|
|
121
80
|
|
|
81
|
+
const { includeCredentials, task, signal, billTo } = options ?? {};
|
|
82
|
+
const providerHelper = getProviderHelper(provider, task);
|
|
122
83
|
const authMethod = (() => {
|
|
123
|
-
if (
|
|
84
|
+
if (providerHelper.clientSideRoutingOnly) {
|
|
124
85
|
// Closed-source providers require an accessToken (cannot be routed).
|
|
125
86
|
if (accessToken && accessToken.startsWith("hf_")) {
|
|
126
87
|
throw new Error(`Provider ${provider} is closed-source and does not support HF tokens.`);
|
|
@@ -138,36 +99,25 @@ export function makeRequestOptionsFromResolvedModel(
|
|
|
138
99
|
})();
|
|
139
100
|
|
|
140
101
|
// Make URL
|
|
141
|
-
const url = endpointUrl
|
|
142
|
-
? chatCompletion
|
|
143
|
-
? endpointUrl + `/v1/chat/completions`
|
|
144
|
-
: endpointUrl
|
|
145
|
-
: providerConfig.makeUrl({
|
|
146
|
-
authMethod,
|
|
147
|
-
baseUrl:
|
|
148
|
-
authMethod !== "provider-key"
|
|
149
|
-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider)
|
|
150
|
-
: providerConfig.makeBaseUrl(task),
|
|
151
|
-
model: resolvedModel,
|
|
152
|
-
chatCompletion,
|
|
153
|
-
task,
|
|
154
|
-
});
|
|
155
102
|
|
|
156
|
-
|
|
157
|
-
const
|
|
158
|
-
const headers = providerConfig.makeHeaders({
|
|
159
|
-
accessToken,
|
|
103
|
+
const modelId = endpointUrl ?? resolvedModel;
|
|
104
|
+
const url = providerHelper.makeUrl({
|
|
160
105
|
authMethod,
|
|
106
|
+
model: modelId,
|
|
107
|
+
task,
|
|
161
108
|
});
|
|
109
|
+
// Make headers
|
|
110
|
+
const headers = providerHelper.prepareHeaders(
|
|
111
|
+
{
|
|
112
|
+
accessToken,
|
|
113
|
+
authMethod,
|
|
114
|
+
},
|
|
115
|
+
"data" in args && !!args.data
|
|
116
|
+
);
|
|
162
117
|
if (billTo) {
|
|
163
118
|
headers[HF_HEADER_X_BILL_TO] = billTo;
|
|
164
119
|
}
|
|
165
120
|
|
|
166
|
-
// Add content-type to headers
|
|
167
|
-
if (!binary) {
|
|
168
|
-
headers["Content-Type"] = "application/json";
|
|
169
|
-
}
|
|
170
|
-
|
|
171
121
|
// Add user-agent to headers
|
|
172
122
|
// e.g. @huggingface/inference/3.1.3
|
|
173
123
|
const ownUserAgent = `${packageName}/${packageVersion}`;
|
|
@@ -177,17 +127,11 @@ export function makeRequestOptionsFromResolvedModel(
|
|
|
177
127
|
headers["User-Agent"] = userAgent;
|
|
178
128
|
|
|
179
129
|
// Make body
|
|
180
|
-
const body =
|
|
181
|
-
|
|
182
|
-
:
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
model: resolvedModel,
|
|
186
|
-
task,
|
|
187
|
-
chatCompletion,
|
|
188
|
-
})
|
|
189
|
-
);
|
|
190
|
-
|
|
130
|
+
const body = providerHelper.makeBody({
|
|
131
|
+
args: remainingArgs as Record<string, unknown>,
|
|
132
|
+
model: resolvedModel,
|
|
133
|
+
task,
|
|
134
|
+
});
|
|
191
135
|
/**
|
|
192
136
|
* For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
|
|
193
137
|
*/
|
|
@@ -201,11 +145,10 @@ export function makeRequestOptionsFromResolvedModel(
|
|
|
201
145
|
const info: RequestInit = {
|
|
202
146
|
headers,
|
|
203
147
|
method: "POST",
|
|
204
|
-
body,
|
|
148
|
+
body: body,
|
|
205
149
|
...(credentials ? { credentials } : undefined),
|
|
206
150
|
signal,
|
|
207
151
|
};
|
|
208
|
-
|
|
209
152
|
return { url, info };
|
|
210
153
|
}
|
|
211
154
|
|
|
@@ -14,33 +14,84 @@
|
|
|
14
14
|
*
|
|
15
15
|
* Thanks!
|
|
16
16
|
*/
|
|
17
|
-
import
|
|
17
|
+
import { InferenceOutputError } from "../lib/InferenceOutputError";
|
|
18
|
+
import type { BodyParams, HeaderParams, UrlParams } from "../types";
|
|
19
|
+
import { delay } from "../utils/delay";
|
|
20
|
+
import { omit } from "../utils/omit";
|
|
21
|
+
import { TaskProviderHelper, type TextToImageTaskHelper } from "./providerHelper";
|
|
18
22
|
|
|
19
23
|
const BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";

/** Number of times the polling URL is queried before giving up. */
const BLACK_FOREST_LABS_POLL_ATTEMPTS = 5;
/** Delay in milliseconds awaited before each poll request. */
const BLACK_FOREST_LABS_POLL_DELAY_MS = 1000;

/**
 * Initial response of the Black Forest Labs API to a generation request.
 * The image itself is not in this response: `polling_url` has to be queried
 * until the task reports a ready result (see `getResponse`).
 */
interface BlackForestLabsResponse {
	/** Identifier of the submitted generation task. */
	id: string;
	/** URL to poll for the task status and, once ready, the generated sample. */
	polling_url: string;
}

/**
 * Text-to-image task helper for the "black-forest-labs" provider.
 *
 * Requests are POSTed to `/v1/{model}`. The API answers with a polling URL, which
 * `getResponse` queries until the generated image is available.
 */
export class BlackForestLabsTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
	constructor() {
		super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
	}

	/**
	 * Build the JSON payload sent to the provider.
	 *
	 * Entries of `args.parameters` are flattened into the top level, and `args.inputs` is sent
	 * as `prompt`. Every other argument is forwarded unchanged.
	 */
	preparePayload(params: BodyParams): Record<string, unknown> {
		return {
			...omit(params.args, ["inputs", "parameters"]),
			...(params.args.parameters as Record<string, unknown>),
			prompt: params.args.inputs,
		};
	}

	/**
	 * Build the request headers.
	 *
	 * - Routed through Hugging Face (any auth method other than "provider-key"): standard
	 *   `Authorization: Bearer <token>` header.
	 * - User-supplied Black Forest Labs key ("provider-key"): the key goes in the dedicated
	 *   `X-Key` header. Placing it inside `Authorization` would not authenticate the request.
	 *
	 * `Content-Type: application/json` is added unless the body is binary.
	 */
	override prepareHeaders(params: HeaderParams, binary: boolean): Record<string, string> {
		const headers: Record<string, string> =
			params.authMethod === "provider-key"
				? { "X-Key": `${params.accessToken}` }
				: { Authorization: `Bearer ${params.accessToken}` };
		if (!binary) {
			headers["Content-Type"] = "application/json";
		}
		return headers;
	}

	/**
	 * Route for a generation request: the model id is used directly as the endpoint path.
	 *
	 * @throws Error if `params` is missing.
	 */
	makeRoute(params: UrlParams): string {
		if (!params) {
			throw new Error("Params are required");
		}
		return `/v1/${params.model}`;
	}

	/**
	 * Poll the task's `polling_url` until the result is ready, then return the generated image.
	 *
	 * Waits `BLACK_FOREST_LABS_POLL_DELAY_MS` before each attempt and makes at most
	 * `BLACK_FOREST_LABS_POLL_ATTEMPTS` attempts. The attempt index is sent in the `attempt`
	 * query parameter.
	 *
	 * @param response - initial provider response containing `polling_url`.
	 * @param url - unused here; part of the shared task-helper signature.
	 * @param headers - unused here; part of the shared task-helper signature.
	 * @param outputType - "url" returns the sample URL; anything else downloads it as a Blob.
	 * @throws InferenceOutputError if a poll request fails, or if no ready result is obtained
	 *   within the allowed attempts.
	 */
	async getResponse(
		response: BlackForestLabsResponse,
		url?: string,
		headers?: HeadersInit,
		outputType?: "url" | "blob"
	): Promise<string | Blob> {
		const urlObj = new URL(response.polling_url);
		for (let step = 0; step < BLACK_FOREST_LABS_POLL_ATTEMPTS; step++) {
			await delay(BLACK_FOREST_LABS_POLL_DELAY_MS);
			console.debug(
				`Polling Black Forest Labs API for the result... ${step + 1}/${BLACK_FOREST_LABS_POLL_ATTEMPTS}`
			);
			urlObj.searchParams.set("attempt", step.toString(10));
			const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
			if (!resp.ok) {
				throw new InferenceOutputError("Failed to fetch result from black forest labs API");
			}
			const payload = await resp.json();
			// Only a payload shaped like { status: "Ready", result: { sample: string } } is a final
			// result. Anything else is treated as "not ready yet" and polled again.
			if (
				typeof payload === "object" &&
				payload &&
				"status" in payload &&
				typeof payload.status === "string" &&
				payload.status === "Ready" &&
				"result" in payload &&
				typeof payload.result === "object" &&
				payload.result &&
				"sample" in payload.result &&
				typeof payload.result.sample === "string"
			) {
				if (outputType === "url") {
					return payload.result.sample;
				}
				const image = await fetch(payload.result.sample);
				return await image.blob();
			}
		}
		throw new InferenceOutputError("Failed to fetch result from black forest labs API");
	}
}
|