@huggingface/inference 3.7.0 → 3.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.cjs +1369 -941
- package/dist/index.js +1371 -943
- package/dist/src/lib/getInferenceProviderMapping.d.ts +21 -0
- package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -0
- package/dist/src/lib/getProviderHelper.d.ts +37 -0
- package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
- package/dist/src/lib/makeRequestOptions.d.ts +5 -5
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/black-forest-labs.d.ts +14 -18
- package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
- package/dist/src/providers/cerebras.d.ts +4 -2
- package/dist/src/providers/cerebras.d.ts.map +1 -1
- package/dist/src/providers/cohere.d.ts +5 -2
- package/dist/src/providers/cohere.d.ts.map +1 -1
- package/dist/src/providers/consts.d.ts +2 -3
- package/dist/src/providers/consts.d.ts.map +1 -1
- package/dist/src/providers/fal-ai.d.ts +50 -3
- package/dist/src/providers/fal-ai.d.ts.map +1 -1
- package/dist/src/providers/fireworks-ai.d.ts +5 -2
- package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
- package/dist/src/providers/hf-inference.d.ts +126 -2
- package/dist/src/providers/hf-inference.d.ts.map +1 -1
- package/dist/src/providers/hyperbolic.d.ts +31 -2
- package/dist/src/providers/hyperbolic.d.ts.map +1 -1
- package/dist/src/providers/nebius.d.ts +20 -18
- package/dist/src/providers/nebius.d.ts.map +1 -1
- package/dist/src/providers/novita.d.ts +21 -18
- package/dist/src/providers/novita.d.ts.map +1 -1
- package/dist/src/providers/openai.d.ts +4 -2
- package/dist/src/providers/openai.d.ts.map +1 -1
- package/dist/src/providers/providerHelper.d.ts +182 -0
- package/dist/src/providers/providerHelper.d.ts.map +1 -0
- package/dist/src/providers/replicate.d.ts +23 -19
- package/dist/src/providers/replicate.d.ts.map +1 -1
- package/dist/src/providers/sambanova.d.ts +4 -2
- package/dist/src/providers/sambanova.d.ts.map +1 -1
- package/dist/src/providers/together.d.ts +32 -2
- package/dist/src/providers/together.d.ts.map +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts +2 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/audio/utils.d.ts +2 -1
- package/dist/src/tasks/audio/utils.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts +0 -2
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts +0 -2
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/index.d.ts +6 -6
- package/dist/src/tasks/index.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
- package/dist/src/types.d.ts +5 -13
- package/dist/src/types.d.ts.map +1 -1
- package/dist/src/utils/request.d.ts +3 -2
- package/dist/src/utils/request.d.ts.map +1 -1
- package/package.json +3 -3
- package/src/lib/getInferenceProviderMapping.ts +96 -0
- package/src/lib/getProviderHelper.ts +270 -0
- package/src/lib/makeRequestOptions.ts +78 -97
- package/src/providers/black-forest-labs.ts +73 -22
- package/src/providers/cerebras.ts +6 -27
- package/src/providers/cohere.ts +9 -28
- package/src/providers/consts.ts +5 -2
- package/src/providers/fal-ai.ts +224 -77
- package/src/providers/fireworks-ai.ts +8 -29
- package/src/providers/hf-inference.ts +557 -34
- package/src/providers/hyperbolic.ts +107 -29
- package/src/providers/nebius.ts +65 -29
- package/src/providers/novita.ts +68 -32
- package/src/providers/openai.ts +6 -32
- package/src/providers/providerHelper.ts +354 -0
- package/src/providers/replicate.ts +124 -34
- package/src/providers/sambanova.ts +5 -30
- package/src/providers/together.ts +92 -28
- package/src/snippets/getInferenceSnippets.ts +39 -14
- package/src/snippets/templates.exported.ts +25 -25
- package/src/tasks/audio/audioClassification.ts +5 -8
- package/src/tasks/audio/audioToAudio.ts +4 -27
- package/src/tasks/audio/automaticSpeechRecognition.ts +5 -4
- package/src/tasks/audio/textToSpeech.ts +5 -29
- package/src/tasks/audio/utils.ts +2 -1
- package/src/tasks/custom/request.ts +3 -3
- package/src/tasks/custom/streamingRequest.ts +4 -3
- package/src/tasks/cv/imageClassification.ts +4 -8
- package/src/tasks/cv/imageSegmentation.ts +4 -9
- package/src/tasks/cv/imageToImage.ts +4 -7
- package/src/tasks/cv/imageToText.ts +4 -7
- package/src/tasks/cv/objectDetection.ts +4 -19
- package/src/tasks/cv/textToImage.ts +9 -137
- package/src/tasks/cv/textToVideo.ts +17 -64
- package/src/tasks/cv/zeroShotImageClassification.ts +4 -8
- package/src/tasks/index.ts +6 -6
- package/src/tasks/multimodal/documentQuestionAnswering.ts +4 -19
- package/src/tasks/multimodal/visualQuestionAnswering.ts +4 -12
- package/src/tasks/nlp/chatCompletion.ts +5 -20
- package/src/tasks/nlp/chatCompletionStream.ts +4 -3
- package/src/tasks/nlp/featureExtraction.ts +4 -19
- package/src/tasks/nlp/fillMask.ts +4 -17
- package/src/tasks/nlp/questionAnswering.ts +11 -26
- package/src/tasks/nlp/sentenceSimilarity.ts +4 -8
- package/src/tasks/nlp/summarization.ts +4 -7
- package/src/tasks/nlp/tableQuestionAnswering.ts +10 -30
- package/src/tasks/nlp/textClassification.ts +4 -9
- package/src/tasks/nlp/textGeneration.ts +11 -79
- package/src/tasks/nlp/textGenerationStream.ts +3 -1
- package/src/tasks/nlp/tokenClassification.ts +11 -23
- package/src/tasks/nlp/translation.ts +4 -7
- package/src/tasks/nlp/zeroShotClassification.ts +11 -21
- package/src/tasks/tabular/tabularClassification.ts +4 -7
- package/src/tasks/tabular/tabularRegression.ts +4 -7
- package/src/types.ts +5 -14
- package/src/utils/request.ts +7 -4
- package/dist/src/lib/getProviderModelId.d.ts +0 -10
- package/dist/src/lib/getProviderModelId.d.ts.map +0 -1
- package/src/lib/getProviderModelId.ts +0 -74
|
@@ -0,0 +1,354 @@
|
|
|
1
|
+
import type {
	AudioClassificationInput,
	AudioClassificationOutput,
	AutomaticSpeechRecognitionInput,
	AutomaticSpeechRecognitionOutput,
	ChatCompletionInput,
	ChatCompletionOutput,
	DocumentQuestionAnsweringInput,
	DocumentQuestionAnsweringOutput,
	FeatureExtractionInput,
	FeatureExtractionOutput,
	FillMaskInput,
	FillMaskOutput,
	ImageClassificationInput,
	ImageClassificationOutput,
	ImageSegmentationInput,
	ImageSegmentationOutput,
	ImageToImageInput,
	ImageToTextInput,
	ImageToTextOutput,
	ObjectDetectionInput,
	ObjectDetectionOutput,
	QuestionAnsweringInput,
	QuestionAnsweringOutput,
	SentenceSimilarityInput,
	SentenceSimilarityOutput,
	SummarizationInput,
	SummarizationOutput,
	TableQuestionAnsweringInput,
	TableQuestionAnsweringOutput,
	TextClassificationInput,
	TextClassificationOutput,
	TextGenerationInput,
	TextGenerationOutput,
	TextToImageInput,
	TextToSpeechInput,
	TextToVideoInput,
	TokenClassificationInput,
	TokenClassificationOutput,
	TranslationInput,
	TranslationOutput,
	VisualQuestionAnsweringInput,
	VisualQuestionAnsweringOutput,
	ZeroShotClassificationInput,
	ZeroShotClassificationOutput,
	ZeroShotImageClassificationInput,
	ZeroShotImageClassificationOutput,
} from "@huggingface/tasks";
|
|
48
|
+
import { HF_ROUTER_URL } from "../config";
|
|
49
|
+
import { InferenceOutputError } from "../lib/InferenceOutputError";
|
|
50
|
+
import type { AudioToAudioOutput } from "../tasks/audio/audioToAudio";
|
|
51
|
+
import type { BaseArgs, BodyParams, HeaderParams, InferenceProvider, UrlParams } from "../types";
|
|
52
|
+
import { toArray } from "../utils/toArray";
|
|
53
|
+
|
|
54
|
+
/**
 * Base class for task-specific provider helpers.
 *
 * A helper encapsulates, for one (provider, task) pair, how to build the request
 * (URL, headers, body) and how to validate/convert the raw provider response.
 * Subclasses implement `makeRoute`, `preparePayload` and `getResponse`.
 */
export abstract class TaskProviderHelper {
	/**
	 * @param provider - provider id; appended to the HF router URL when the request is routed through Hugging Face
	 * @param baseUrl - the provider's own API base URL; used only when authenticating with a provider key
	 * @param clientSideRoutingOnly - flag exposed to callers (NOTE(review): presumably marks providers that
	 *   cannot be reached through the HF router — confirm against getProviderHelper/makeRequestOptions)
	 */
	constructor(
		private readonly provider: InferenceProvider,
		private readonly baseUrl: string,
		readonly clientSideRoutingOnly: boolean = false
	) {}

	/**
	 * Return the response in the expected format.
	 * Needs to be implemented in the subclasses.
	 */
	abstract getResponse(
		response: unknown,
		url?: string,
		headers?: HeadersInit,
		outputType?: "url" | "blob"
	): Promise<unknown>;

	/**
	 * Prepare the route (path relative to the base URL) for the request.
	 * Needs to be implemented in the subclasses.
	 */
	abstract makeRoute(params: UrlParams): string;

	/**
	 * Prepare the JSON payload for the request.
	 * Needs to be implemented in the subclasses.
	 */
	abstract preparePayload(params: BodyParams): unknown;

	/**
	 * Prepare the base URL for the request: the provider's own API when a provider key is used,
	 * otherwise the Hugging Face router endpoint for this provider.
	 */
	makeBaseUrl(params: UrlParams): string {
		return params.authMethod !== "provider-key" ? `${HF_ROUTER_URL}/${this.provider}` : this.baseUrl;
	}

	/**
	 * Prepare the body for the request. Raw `data` (binary input) is sent as-is;
	 * otherwise the result of `preparePayload` is JSON-encoded.
	 */
	makeBody(params: BodyParams): BodyInit {
		if ("data" in params.args && !!params.args.data) {
			return params.args.data as BodyInit;
		}
		return JSON.stringify(this.preparePayload(params));
	}

	/**
	 * Prepare the full URL for the request by joining the base URL and the route
	 * (leading slashes on the route are stripped to avoid `//`).
	 */
	makeUrl(params: UrlParams): string {
		const baseUrl = this.makeBaseUrl(params);
		const route = this.makeRoute(params).replace(/^\/+/, "");
		return `${baseUrl}/${route}`;
	}

	/**
	 * Prepare the headers for the request.
	 *
	 * @param params - carries the access token (may be absent, e.g. anonymous or cookie-based calls)
	 * @param isBinary - when true the body is raw data, so no JSON `Content-Type` is set
	 */
	prepareHeaders(params: HeaderParams, isBinary: boolean): Record<string, string> {
		const headers: Record<string, string> = {};
		// Only send a bearer token when one exists; otherwise the literal "Bearer undefined" would be sent.
		if (params.accessToken) {
			headers["Authorization"] = `Bearer ${params.accessToken}`;
		}
		if (!isBinary) {
			headers["Content-Type"] = "application/json";
		}
		return headers;
	}
}
|
|
123
|
+
|
|
124
|
+
// PER-TASK PROVIDER HELPER INTERFACES
|
|
125
|
+
|
|
126
|
+
// CV Tasks

/**
 * Contract for `text-to-image` helpers. `outputType` lets the caller request a URL string
 * instead of a `Blob`.
 */
export interface TextToImageTaskHelper {
	preparePayload(params: BodyParams<TextToImageInput & BaseArgs>): Record<string, unknown>;
	getResponse(
		response: unknown,
		url?: string,
		headers?: HeadersInit,
		outputType?: "url" | "blob"
	): Promise<string | Blob>;
}

/** Contract for `text-to-video` helpers; the video is returned as a `Blob`. */
export interface TextToVideoTaskHelper {
	preparePayload(params: BodyParams<TextToVideoInput & BaseArgs>): Record<string, unknown>;
	getResponse(response: unknown, url?: string, headers?: Record<string, string>): Promise<Blob>;
}

/** Contract for `image-to-image` helpers; the output image is returned as a `Blob`. */
export interface ImageToImageTaskHelper {
	preparePayload(params: BodyParams<ImageToImageInput & BaseArgs>): Record<string, unknown>;
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<Blob>;
}

/** Contract for `image-segmentation` helpers. */
export interface ImageSegmentationTaskHelper {
	preparePayload(params: BodyParams<ImageSegmentationInput & BaseArgs>): Record<string, unknown> | BodyInit;
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ImageSegmentationOutput>;
}

/** Contract for `image-classification` helpers. */
export interface ImageClassificationTaskHelper {
	preparePayload(params: BodyParams<ImageClassificationInput & BaseArgs>): Record<string, unknown> | BodyInit;
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ImageClassificationOutput>;
}

/** Contract for `object-detection` helpers. */
export interface ObjectDetectionTaskHelper {
	preparePayload(params: BodyParams<ObjectDetectionInput & BaseArgs>): Record<string, unknown> | BodyInit;
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ObjectDetectionOutput>;
}

/** Contract for `image-to-text` helpers. */
export interface ImageToTextTaskHelper {
	preparePayload(params: BodyParams<ImageToTextInput & BaseArgs>): Record<string, unknown> | BodyInit;
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ImageToTextOutput>;
}

/** Contract for `zero-shot-image-classification` helpers. */
export interface ZeroShotImageClassificationTaskHelper {
	preparePayload(params: BodyParams<ZeroShotImageClassificationInput & BaseArgs>): Record<string, unknown> | BodyInit;
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ZeroShotImageClassificationOutput>;
}
|
|
171
|
+
|
|
172
|
+
// NLP Tasks

/** Contract for non-streaming `text-generation` helpers. */
export interface TextGenerationTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<TextGenerationOutput>;
	preparePayload(params: BodyParams<TextGenerationInput & BaseArgs>): Record<string, unknown>;
}

/** Contract for chat-completion (`conversational`) helpers. */
export interface ConversationalTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ChatCompletionOutput>;
	preparePayload(params: BodyParams<ChatCompletionInput & BaseArgs>): Record<string, unknown>;
}

/** Contract for `text-classification` helpers. */
export interface TextClassificationTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<TextClassificationOutput>;
	// Typed with the text-classification input (was mistakenly the zero-shot classification input).
	preparePayload(params: BodyParams<TextClassificationInput & BaseArgs>): Record<string, unknown>;
}

/** Contract for `question-answering` helpers; resolves to a single answer. */
export interface QuestionAnsweringTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<QuestionAnsweringOutput[number]>;
	preparePayload(params: BodyParams<QuestionAnsweringInput & BaseArgs>): Record<string, unknown>;
}

/** Contract for `fill-mask` helpers. */
export interface FillMaskTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<FillMaskOutput>;
	preparePayload(params: BodyParams<FillMaskInput & BaseArgs>): Record<string, unknown>;
}

/** Contract for `zero-shot-classification` helpers. */
export interface ZeroShotClassificationTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ZeroShotClassificationOutput>;
	preparePayload(params: BodyParams<ZeroShotClassificationInput & BaseArgs>): Record<string, unknown>;
}

/** Contract for `sentence-similarity` helpers. */
export interface SentenceSimilarityTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<SentenceSimilarityOutput>;
	preparePayload(params: BodyParams<SentenceSimilarityInput & BaseArgs>): Record<string, unknown>;
}

/** Contract for `table-question-answering` helpers; resolves to a single answer. */
export interface TableQuestionAnsweringTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<TableQuestionAnsweringOutput[number]>;
	preparePayload(params: BodyParams<TableQuestionAnsweringInput & BaseArgs>): Record<string, unknown>;
}

/** Contract for `token-classification` helpers. */
export interface TokenClassificationTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<TokenClassificationOutput>;
	preparePayload(params: BodyParams<TokenClassificationInput & BaseArgs>): Record<string, unknown>;
}

/** Contract for `translation` helpers. */
export interface TranslationTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<TranslationOutput>;
	preparePayload(params: BodyParams<TranslationInput & BaseArgs>): Record<string, unknown>;
}

/** Contract for `summarization` helpers. */
export interface SummarizationTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<SummarizationOutput>;
	preparePayload(params: BodyParams<SummarizationInput & BaseArgs>): Record<string, unknown>;
}
|
|
227
|
+
|
|
228
|
+
// Audio Tasks

/** Contract for `text-to-speech` helpers; audio is returned as a `Blob`. */
export interface TextToSpeechTaskHelper {
	preparePayload(params: BodyParams<TextToSpeechInput & BaseArgs>): Record<string, unknown>;
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<Blob>;
}

/** Contract for `text-to-audio` helpers; inputs are untyped, audio is returned as a `Blob`. */
export interface TextToAudioTaskHelper {
	preparePayload(params: BodyParams<Record<string, unknown> & BaseArgs>): Record<string, unknown>;
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<Blob>;
}

/** Contract for `audio-to-audio` helpers; the input audio arrives as a `Blob`. */
export interface AudioToAudioTaskHelper {
	preparePayload(
		params: BodyParams<BaseArgs & { inputs: Blob } & Record<string, unknown>>
	): Record<string, unknown> | BodyInit;
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<AudioToAudioOutput[]>;
}

/** Contract for `automatic-speech-recognition` helpers. */
export interface AutomaticSpeechRecognitionTaskHelper {
	preparePayload(params: BodyParams<AutomaticSpeechRecognitionInput & BaseArgs>): Record<string, unknown> | BodyInit;
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<AutomaticSpeechRecognitionOutput>;
}

/** Contract for `audio-classification` helpers. */
export interface AudioClassificationTaskHelper {
	preparePayload(params: BodyParams<AudioClassificationInput & BaseArgs>): Record<string, unknown> | BodyInit;
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<AudioClassificationOutput>;
}
|
|
254
|
+
|
|
255
|
+
// Multimodal Tasks

/** Contract for `document-question-answering` helpers; resolves to a single answer. */
export interface DocumentQuestionAnsweringTaskHelper {
	preparePayload(params: BodyParams<DocumentQuestionAnsweringInput & BaseArgs>): Record<string, unknown> | BodyInit;
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<DocumentQuestionAnsweringOutput[number]>;
}

/** Contract for `feature-extraction` helpers. */
export interface FeatureExtractionTaskHelper {
	preparePayload(params: BodyParams<FeatureExtractionInput & BaseArgs>): Record<string, unknown>;
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<FeatureExtractionOutput>;
}

/** Contract for `visual-question-answering` helpers; resolves to a single answer. */
export interface VisualQuestionAnsweringTaskHelper {
	preparePayload(params: BodyParams<VisualQuestionAnsweringInput & BaseArgs>): Record<string, unknown> | BodyInit;
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<VisualQuestionAnsweringOutput[number]>;
}

// Tabular Tasks

/** Contract for `tabular-classification` helpers; inputs are column-name → values tables. */
export interface TabularClassificationTaskHelper {
	preparePayload(
		params: BodyParams<BaseArgs & { inputs: { data: Record<string, string[]> } } & Record<string, unknown>>
	): Record<string, unknown> | BodyInit;
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<number[]>;
}

/** Contract for `tabular-regression` helpers; inputs are column-name → values tables. */
export interface TabularRegressionTaskHelper {
	preparePayload(
		params: BodyParams<BaseArgs & { inputs: { data: Record<string, string[]> } } & Record<string, unknown>>
	): Record<string, unknown> | BodyInit;
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<number[]>;
}
|
|
284
|
+
|
|
285
|
+
// BASE IMPLEMENTATIONS FOR COMMON PATTERNS
|
|
286
|
+
|
|
287
|
+
/**
 * Generic OpenAI-compatible chat-completion helper.
 *
 * Posts the caller's arguments, plus the resolved `model` id, to `v1/chat/completions`
 * and validates that the reply looks like a `ChatCompletionOutput`.
 */
export class BaseConversationalTask extends TaskProviderHelper implements ConversationalTaskHelper {
	constructor(provider: InferenceProvider, baseUrl: string, clientSideRoutingOnly: boolean = false) {
		super(provider, baseUrl, clientSideRoutingOnly);
	}

	makeRoute(): string {
		return "v1/chat/completions";
	}

	/** Forward all arguments unchanged; `model` is set last so the resolved id always wins. */
	preparePayload(params: BodyParams): Record<string, unknown> {
		const { args, model } = params;
		return { ...args, model };
	}

	/**
	 * @throws InferenceOutputError when the reply lacks the expected chat-completion fields
	 */
	async getResponse(response: ChatCompletionOutput): Promise<ChatCompletionOutput> {
		const hasCoreFields =
			typeof response === "object" &&
			Array.isArray(response?.choices) &&
			typeof response?.created === "number" &&
			typeof response?.id === "string" &&
			typeof response?.model === "string";
		if (!hasCoreFields) {
			throw new InferenceOutputError("Expected ChatCompletionOutput");
		}

		// Together.ai and Nebius do not output a system_fingerprint, so absent/null is accepted.
		const fingerprint = response.system_fingerprint;
		const fingerprintOk = fingerprint === undefined || fingerprint === null || typeof fingerprint === "string";
		if (!fingerprintOk || typeof response.usage !== "object") {
			throw new InferenceOutputError("Expected ChatCompletionOutput");
		}
		return response;
	}
}
|
|
322
|
+
|
|
323
|
+
/**
 * Generic OpenAI-compatible text-completion helper targeting `v1/completions`.
 *
 * The default `getResponse` accepts one `{ generated_text }` object or an array of them,
 * and returns the first entry.
 */
export class BaseTextGenerationTask extends TaskProviderHelper implements TextGenerationTaskHelper {
	constructor(provider: InferenceProvider, baseUrl: string, clientSideRoutingOnly: boolean = false) {
		super(provider, baseUrl, clientSideRoutingOnly);
	}

	/** Forward all arguments unchanged; `model` is set last so the resolved id always wins. */
	preparePayload(params: BodyParams): Record<string, unknown> {
		const { args, model } = params;
		return { ...args, model };
	}

	makeRoute(): string {
		return "v1/completions";
	}

	/**
	 * @throws InferenceOutputError unless every entry carries a string `generated_text`
	 */
	async getResponse(response: unknown): Promise<TextGenerationOutput> {
		const candidates = toArray(response);
		const hasGeneratedText = (item: unknown): item is { generated_text: string } =>
			typeof item === "object" && item !== null && "generated_text" in item && typeof item.generated_text === "string";

		if (Array.isArray(candidates) && candidates.length > 0 && candidates.every(hasGeneratedText)) {
			return candidates[0];
		}
		throw new InferenceOutputError("Expected Array<{generated_text: string}>");
	}
}
|
|
@@ -14,37 +14,127 @@
|
|
|
14
14
|
*
|
|
15
15
|
* Thanks!
|
|
16
16
|
*/
|
|
17
|
-
import
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
}
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
return
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
17
|
+
import { InferenceOutputError } from "../lib/InferenceOutputError";
|
|
18
|
+
import { isUrl } from "../lib/isUrl";
|
|
19
|
+
import type { BodyParams, HeaderParams, UrlParams } from "../types";
|
|
20
|
+
import { omit } from "../utils/omit";
|
|
21
|
+
import { TaskProviderHelper, type TextToImageTaskHelper, type TextToVideoTaskHelper } from "./providerHelper";
|
|
22
|
+
|
|
23
|
+
/** Subset of a Replicate prediction used here: `output` is one URL or a list of URLs. */
export interface ReplicateOutput {
	output?: string[] | string;
}
|
|
26
|
+
|
|
27
|
+
/**
 * Shared request plumbing for all Replicate task helpers.
 *
 * Model ids containing a `:` (`owner/name:version`) are sent to the generic `v1/predictions`
 * endpoint with an explicit `version`; other ids go to `v1/models/{model}/predictions`.
 */
abstract class ReplicateTask extends TaskProviderHelper {
	/**
	 * @param url - optional override of the Replicate API base URL (defaults to https://api.replicate.com)
	 */
	constructor(url?: string) {
		super("replicate", url || "https://api.replicate.com");
	}

	makeRoute(params: UrlParams): string {
		if (params.model.includes(":")) {
			return "v1/predictions";
		}
		return `v1/models/${params.model}/predictions`;
	}

	/**
	 * Wrap the arguments in Replicate's `{ input, version }` envelope: extra top-level args and
	 * `parameters` are flattened into `input`, and `inputs` is sent as `input.prompt`.
	 */
	preparePayload(params: BodyParams): Record<string, unknown> {
		return {
			input: {
				...omit(params.args, ["inputs", "parameters"]),
				...(params.args.parameters as Record<string, unknown>),
				prompt: params.args.inputs,
			},
			version: params.model.includes(":") ? params.model.split(":")[1] : undefined,
		};
	}

	/**
	 * Base headers plus `Prefer: wait`, which asks Replicate to hold the request open until the
	 * prediction completes instead of returning immediately.
	 */
	override prepareHeaders(params: HeaderParams, binary: boolean): Record<string, string> {
		return { ...super.prepareHeaders(params, binary), Prefer: "wait" };
	}

	// `makeUrl` is deliberately not overridden: the base implementation joins `makeBaseUrl()` and
	// `makeRoute()`, keeping the route selection in a single place (a duplicated copy here would
	// silently ignore any subclass override of `makeRoute`).
}
|
|
64
|
+
|
|
65
|
+
/**
 * Replicate `text-to-image`: the prediction's `output` holds the image URL (either a single
 * string or a list whose first entry is used).
 */
export class ReplicateTextToImageTask extends ReplicateTask implements TextToImageTaskHelper {
	/**
	 * @param outputType - `"url"` returns the image URL as-is; anything else downloads it as a `Blob`
	 * @throws InferenceOutputError when no image URL can be found in the response
	 */
	override async getResponse(
		res: ReplicateOutput | Blob,
		url?: string,
		headers?: Record<string, string>,
		outputType?: "url" | "blob"
	): Promise<string | Blob> {
		void url;
		void headers;
		// `res !== null` first: `typeof null === "object"`, and `"output" in null` would throw a TypeError.
		if (res !== null && typeof res === "object" && "output" in res) {
			const output = res.output;
			const imageUrl =
				typeof output === "string"
					? output
					: Array.isArray(output) && output.length > 0 && typeof output[0] === "string"
					  ? output[0]
					  : undefined;
			if (imageUrl !== undefined) {
				if (outputType === "url") {
					return imageUrl;
				}
				const urlResponse = await fetch(imageUrl);
				return await urlResponse.blob();
			}
		}

		throw new InferenceOutputError("Expected Replicate text-to-image response format");
	}
}
|
|
91
|
+
|
|
92
|
+
/**
 * Replicate `text-to-speech`: the text is sent as `input.text` instead of `input.prompt`,
 * and the audio is either returned directly as a `Blob` or downloaded from `output`.
 */
export class ReplicateTextToSpeechTask extends ReplicateTask {
	override preparePayload(params: BodyParams): Record<string, unknown> {
		const payload = super.preparePayload(params);

		// Rename `input.prompt` (set by the base payload) to `input.text`.
		const input = payload["input"];
		if (typeof input === "object" && input !== null && "prompt" in input) {
			const inputObj = input as Record<string, unknown>;
			inputObj["text"] = inputObj["prompt"];
			delete inputObj["prompt"];
		}

		return payload;
	}

	/**
	 * @throws InferenceOutputError when the response is neither a Blob nor carries an audio URL
	 */
	override async getResponse(response: ReplicateOutput | Blob): Promise<Blob> {
		if (response instanceof Blob) {
			return response;
		}
		if (response && typeof response === "object" && "output" in response) {
			const output = response.output;
			// Guard against an empty list, which would otherwise be fetched as `fetch(undefined)`.
			const audioUrl =
				typeof output === "string"
					? output
					: Array.isArray(output) && output.length > 0 && typeof output[0] === "string"
					  ? output[0]
					  : undefined;
			if (audioUrl !== undefined) {
				const urlResponse = await fetch(audioUrl);
				return await urlResponse.blob();
			}
		}
		throw new InferenceOutputError("Expected Blob or object with output");
	}
}
|
|
124
|
+
|
|
125
|
+
/** Replicate `text-to-video`: `output` must be a single URL, which is downloaded as a `Blob`. */
export class ReplicateTextToVideoTask extends ReplicateTask implements TextToVideoTaskHelper {
	/**
	 * @throws InferenceOutputError unless `output` is a URL string
	 */
	override async getResponse(response: ReplicateOutput): Promise<Blob> {
		const output =
			typeof response === "object" && !!response && "output" in response ? response.output : undefined;
		if (typeof output !== "string" || !isUrl(output)) {
			throw new InferenceOutputError("Expected { output: string }");
		}

		const videoResponse = await fetch(output);
		return await videoResponse.blob();
	}
}
|
|
@@ -14,35 +14,10 @@
|
|
|
14
14
|
*
|
|
15
15
|
* Thanks!
|
|
16
16
|
*/
|
|
17
|
-
import
|
|
17
|
+
import { BaseConversationalTask } from "./providerHelper";
|
|
18
18
|
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
return SAMBANOVA_API_BASE_URL;
|
|
23
|
-
};
|
|
24
|
-
|
|
25
|
-
const makeBody = (params: BodyParams): Record<string, unknown> => {
|
|
26
|
-
return {
|
|
27
|
-
...params.args,
|
|
28
|
-
...(params.chatCompletion ? { model: params.model } : undefined),
|
|
29
|
-
};
|
|
30
|
-
};
|
|
31
|
-
|
|
32
|
-
const makeHeaders = (params: HeaderParams): Record<string, string> => {
|
|
33
|
-
return { Authorization: `Bearer ${params.accessToken}` };
|
|
34
|
-
};
|
|
35
|
-
|
|
36
|
-
const makeUrl = (params: UrlParams): string => {
|
|
37
|
-
if (params.chatCompletion) {
|
|
38
|
-
return `${params.baseUrl}/v1/chat/completions`;
|
|
19
|
+
/** SambaNova's public API endpoint, used when calling the provider directly with its own key. */
const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";

/** SambaNova chat completion: a plain OpenAI-compatible `BaseConversationalTask`. */
export class SambanovaConversationalTask extends BaseConversationalTask {
	constructor() {
		super("sambanova", SAMBANOVA_API_BASE_URL);
	}
}
|
|
@@ -14,41 +14,105 @@
|
|
|
14
14
|
*
|
|
15
15
|
* Thanks!
|
|
16
16
|
*/
|
|
17
|
-
import type {
|
|
17
|
+
import type { ChatCompletionOutput, TextGenerationOutput, TextGenerationOutputFinishReason } from "@huggingface/tasks";
|
|
18
|
+
import { InferenceOutputError } from "../lib/InferenceOutputError";
|
|
19
|
+
import type { BodyParams } from "../types";
|
|
20
|
+
import { omit } from "../utils/omit";
|
|
21
|
+
import {
|
|
22
|
+
BaseConversationalTask,
|
|
23
|
+
BaseTextGenerationTask,
|
|
24
|
+
TaskProviderHelper,
|
|
25
|
+
type TextToImageTaskHelper,
|
|
26
|
+
} from "./providerHelper";
|
|
18
27
|
|
|
19
28
|
// Together's own API endpoint, shared by every Together task helper below. TaskProviderHelper.makeBaseUrl
// only uses it when authenticating with a provider key; otherwise requests go through the HF router.
const TOGETHER_API_BASE_URL = "https://api.together.xyz";
|
|
20
29
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
30
|
+
/**
 * Together `v1/completions` reply: the chat-completion envelope, but each choice carries
 * the generated `text` directly.
 */
interface TogetherTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
	choices: {
		index: number;
		text: string;
		finish_reason: TextGenerationOutputFinishReason;
		seed: number;
		logprobs: unknown;
	}[];
}
|
|
24
39
|
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
};
|
|
40
|
+
/** Together `v1/images/generations` reply when `response_format` is `"base64"`. */
interface TogetherBase64ImageGeneration {
	data: { b64_json: string }[];
}
|
|
31
45
|
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
46
|
+
/** Together chat completion: OpenAI-compatible, so the base conversational helper is used unchanged. */
export class TogetherConversationalTask extends BaseConversationalTask {
	constructor() {
		// Third argument spelled out: Together is not restricted to client-side routing.
		super("together", TOGETHER_API_BASE_URL, false);
	}
}
|
|
51
|
+
|
|
52
|
+
/**
 * Together `text-generation` via the `v1/completions` endpoint: `inputs` is sent as `prompt`,
 * and the first choice's `text` becomes `generated_text`.
 */
export class TogetherTextGenerationTask extends BaseTextGenerationTask {
	constructor() {
		super("together", TOGETHER_API_BASE_URL);
	}

	override preparePayload(params: BodyParams): Record<string, unknown> {
		return {
			model: params.model,
			...params.args,
			prompt: params.args.inputs,
		};
	}

	/**
	 * @throws InferenceOutputError when the reply is null, has no choices, or the first choice has no text
	 */
	override async getResponse(response: TogetherTextCompletionOutput): Promise<TextGenerationOutput> {
		// `response !== null` first: `typeof null === "object"` and `"choices" in null` would throw a TypeError.
		if (
			typeof response === "object" &&
			response !== null &&
			Array.isArray(response.choices) &&
			typeof response.model === "string"
		) {
			const completion = response.choices[0];
			// An empty `choices` array would otherwise crash on `completion.text`.
			if (completion !== undefined && completion !== null && typeof completion.text === "string") {
				return {
					generated_text: completion.text,
				};
			}
		}
		throw new InferenceOutputError("Expected Together text generation response format");
	}
}
|
|
80
|
+
|
|
81
|
+
/**
 * Together `text-to-image` via `v1/images/generations`. The image is requested base64-encoded
 * and returned either as a `data:` URL or as a `Blob`.
 */
export class TogetherTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
	constructor() {
		super("together", TOGETHER_API_BASE_URL);
	}

	makeRoute(): string {
		return "v1/images/generations";
	}

	/** Flatten `parameters` into the body, send `inputs` as `prompt`, and request base64 output. */
	preparePayload(params: BodyParams): Record<string, unknown> {
		return {
			...omit(params.args, ["inputs", "parameters"]),
			...(params.args.parameters as Record<string, unknown>),
			prompt: params.args.inputs,
			response_format: "base64",
			model: params.model,
		};
	}

	/**
	 * Positional parameters follow `TextToImageTaskHelper.getResponse`:
	 * `outputType` is the 4th argument. It was previously the 2nd, so callers' `"url"` was
	 * silently ignored and a Blob was always returned.
	 *
	 * @param outputType - `"url"` returns a `data:image/jpeg;base64,...` URL; otherwise a `Blob`
	 * @throws InferenceOutputError when no base64 image is present in the reply
	 */
	async getResponse(
		response: TogetherBase64ImageGeneration,
		url?: string,
		headers?: HeadersInit,
		outputType?: "url" | "blob"
	): Promise<string | Blob> {
		void url;
		void headers;
		if (typeof response === "object" && response !== null && Array.isArray(response.data) && response.data.length > 0) {
			const first: unknown = response.data[0];
			if (
				typeof first === "object" &&
				first !== null &&
				"b64_json" in first &&
				typeof first.b64_json === "string"
			) {
				const dataUrl = `data:image/jpeg;base64,${first.b64_json}`;
				if (outputType === "url") {
					return dataUrl;
				}
				const blobResponse = await fetch(dataUrl);
				return await blobResponse.blob();
			}
		}

		throw new InferenceOutputError("Expected Together text-to-image response format");
	}
}
|