@huggingface/inference 3.6.2 → 3.7.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +0 -25
- package/dist/index.cjs +1232 -898
- package/dist/index.js +1234 -900
- package/dist/src/config.d.ts +1 -0
- package/dist/src/config.d.ts.map +1 -1
- package/dist/src/lib/getProviderHelper.d.ts +37 -0
- package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
- package/dist/src/lib/makeRequestOptions.d.ts +0 -2
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/black-forest-labs.d.ts +14 -18
- package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
- package/dist/src/providers/cerebras.d.ts +4 -2
- package/dist/src/providers/cerebras.d.ts.map +1 -1
- package/dist/src/providers/cohere.d.ts +5 -2
- package/dist/src/providers/cohere.d.ts.map +1 -1
- package/dist/src/providers/fal-ai.d.ts +50 -3
- package/dist/src/providers/fal-ai.d.ts.map +1 -1
- package/dist/src/providers/fireworks-ai.d.ts +5 -2
- package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
- package/dist/src/providers/hf-inference.d.ts +125 -2
- package/dist/src/providers/hf-inference.d.ts.map +1 -1
- package/dist/src/providers/hyperbolic.d.ts +31 -2
- package/dist/src/providers/hyperbolic.d.ts.map +1 -1
- package/dist/src/providers/nebius.d.ts +20 -18
- package/dist/src/providers/nebius.d.ts.map +1 -1
- package/dist/src/providers/novita.d.ts +21 -18
- package/dist/src/providers/novita.d.ts.map +1 -1
- package/dist/src/providers/openai.d.ts +4 -2
- package/dist/src/providers/openai.d.ts.map +1 -1
- package/dist/src/providers/providerHelper.d.ts +182 -0
- package/dist/src/providers/providerHelper.d.ts.map +1 -0
- package/dist/src/providers/replicate.d.ts +23 -19
- package/dist/src/providers/replicate.d.ts.map +1 -1
- package/dist/src/providers/sambanova.d.ts +4 -2
- package/dist/src/providers/sambanova.d.ts.map +1 -1
- package/dist/src/providers/together.d.ts +32 -2
- package/dist/src/providers/together.d.ts.map +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/audio/utils.d.ts +2 -1
- package/dist/src/tasks/audio/utils.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts +1 -2
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts +1 -2
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/index.d.ts +6 -6
- package/dist/src/tasks/index.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
- package/dist/src/types.d.ts +10 -13
- package/dist/src/types.d.ts.map +1 -1
- package/dist/src/utils/request.d.ts +27 -0
- package/dist/src/utils/request.d.ts.map +1 -0
- package/package.json +3 -3
- package/src/config.ts +1 -0
- package/src/lib/getProviderHelper.ts +270 -0
- package/src/lib/makeRequestOptions.ts +36 -90
- package/src/providers/black-forest-labs.ts +73 -22
- package/src/providers/cerebras.ts +6 -27
- package/src/providers/cohere.ts +9 -28
- package/src/providers/fal-ai.ts +195 -77
- package/src/providers/fireworks-ai.ts +8 -29
- package/src/providers/hf-inference.ts +555 -34
- package/src/providers/hyperbolic.ts +107 -29
- package/src/providers/nebius.ts +65 -29
- package/src/providers/novita.ts +68 -32
- package/src/providers/openai.ts +6 -32
- package/src/providers/providerHelper.ts +354 -0
- package/src/providers/replicate.ts +124 -34
- package/src/providers/sambanova.ts +5 -30
- package/src/providers/together.ts +92 -28
- package/src/snippets/getInferenceSnippets.ts +16 -9
- package/src/snippets/templates.exported.ts +2 -2
- package/src/tasks/audio/audioClassification.ts +6 -9
- package/src/tasks/audio/audioToAudio.ts +5 -28
- package/src/tasks/audio/automaticSpeechRecognition.ts +7 -6
- package/src/tasks/audio/textToSpeech.ts +6 -30
- package/src/tasks/audio/utils.ts +2 -1
- package/src/tasks/custom/request.ts +7 -34
- package/src/tasks/custom/streamingRequest.ts +5 -87
- package/src/tasks/cv/imageClassification.ts +5 -9
- package/src/tasks/cv/imageSegmentation.ts +5 -10
- package/src/tasks/cv/imageToImage.ts +5 -8
- package/src/tasks/cv/imageToText.ts +8 -13
- package/src/tasks/cv/objectDetection.ts +6 -21
- package/src/tasks/cv/textToImage.ts +10 -138
- package/src/tasks/cv/textToVideo.ts +11 -59
- package/src/tasks/cv/zeroShotImageClassification.ts +7 -12
- package/src/tasks/index.ts +6 -6
- package/src/tasks/multimodal/documentQuestionAnswering.ts +10 -26
- package/src/tasks/multimodal/visualQuestionAnswering.ts +6 -12
- package/src/tasks/nlp/chatCompletion.ts +7 -23
- package/src/tasks/nlp/chatCompletionStream.ts +4 -5
- package/src/tasks/nlp/featureExtraction.ts +5 -20
- package/src/tasks/nlp/fillMask.ts +5 -18
- package/src/tasks/nlp/questionAnswering.ts +5 -23
- package/src/tasks/nlp/sentenceSimilarity.ts +5 -18
- package/src/tasks/nlp/summarization.ts +5 -8
- package/src/tasks/nlp/tableQuestionAnswering.ts +5 -29
- package/src/tasks/nlp/textClassification.ts +8 -14
- package/src/tasks/nlp/textGeneration.ts +13 -80
- package/src/tasks/nlp/textGenerationStream.ts +2 -2
- package/src/tasks/nlp/tokenClassification.ts +8 -24
- package/src/tasks/nlp/translation.ts +5 -8
- package/src/tasks/nlp/zeroShotClassification.ts +8 -22
- package/src/tasks/tabular/tabularClassification.ts +5 -8
- package/src/tasks/tabular/tabularRegression.ts +5 -8
- package/src/types.ts +11 -14
- package/src/utils/request.ts +161 -0
|
@@ -10,38 +10,559 @@
|
|
|
10
10
|
*
|
|
11
11
|
* Thanks!
|
|
12
12
|
*/
|
|
13
|
+
import type {
|
|
14
|
+
AudioClassificationOutput,
|
|
15
|
+
AutomaticSpeechRecognitionOutput,
|
|
16
|
+
ChatCompletionOutput,
|
|
17
|
+
DocumentQuestionAnsweringOutput,
|
|
18
|
+
FeatureExtractionOutput,
|
|
19
|
+
FillMaskOutput,
|
|
20
|
+
ImageClassificationOutput,
|
|
21
|
+
ImageSegmentationOutput,
|
|
22
|
+
ImageToTextOutput,
|
|
23
|
+
ObjectDetectionOutput,
|
|
24
|
+
QuestionAnsweringOutput,
|
|
25
|
+
SentenceSimilarityOutput,
|
|
26
|
+
SummarizationOutput,
|
|
27
|
+
TableQuestionAnsweringOutput,
|
|
28
|
+
TextClassificationOutput,
|
|
29
|
+
TextGenerationOutput,
|
|
30
|
+
TokenClassificationOutput,
|
|
31
|
+
TranslationOutput,
|
|
32
|
+
VisualQuestionAnsweringOutput,
|
|
33
|
+
ZeroShotClassificationOutput,
|
|
34
|
+
ZeroShotImageClassificationOutput,
|
|
35
|
+
} from "@huggingface/tasks";
|
|
13
36
|
import { HF_ROUTER_URL } from "../config";
|
|
14
|
-
import
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
};
|
|
37
|
+
import { InferenceOutputError } from "../lib/InferenceOutputError";
|
|
38
|
+
import type { TabularClassificationOutput } from "../tasks/tabular/tabularClassification";
|
|
39
|
+
import type { BodyParams, UrlParams } from "../types";
|
|
40
|
+
import { toArray } from "../utils/toArray";
|
|
41
|
+
import type {
|
|
42
|
+
AudioClassificationTaskHelper,
|
|
43
|
+
AudioToAudioTaskHelper,
|
|
44
|
+
AutomaticSpeechRecognitionTaskHelper,
|
|
45
|
+
ConversationalTaskHelper,
|
|
46
|
+
DocumentQuestionAnsweringTaskHelper,
|
|
47
|
+
FeatureExtractionTaskHelper,
|
|
48
|
+
FillMaskTaskHelper,
|
|
49
|
+
ImageClassificationTaskHelper,
|
|
50
|
+
ImageSegmentationTaskHelper,
|
|
51
|
+
ImageToImageTaskHelper,
|
|
52
|
+
ImageToTextTaskHelper,
|
|
53
|
+
ObjectDetectionTaskHelper,
|
|
54
|
+
QuestionAnsweringTaskHelper,
|
|
55
|
+
SentenceSimilarityTaskHelper,
|
|
56
|
+
SummarizationTaskHelper,
|
|
57
|
+
TableQuestionAnsweringTaskHelper,
|
|
58
|
+
TabularClassificationTaskHelper,
|
|
59
|
+
TabularRegressionTaskHelper,
|
|
60
|
+
TextClassificationTaskHelper,
|
|
61
|
+
TextGenerationTaskHelper,
|
|
62
|
+
TextToAudioTaskHelper,
|
|
63
|
+
TextToImageTaskHelper,
|
|
64
|
+
TextToSpeechTaskHelper,
|
|
65
|
+
TokenClassificationTaskHelper,
|
|
66
|
+
TranslationTaskHelper,
|
|
67
|
+
VisualQuestionAnsweringTaskHelper,
|
|
68
|
+
ZeroShotClassificationTaskHelper,
|
|
69
|
+
ZeroShotImageClassificationTaskHelper,
|
|
70
|
+
} from "./providerHelper";
|
|
71
|
+
|
|
72
|
+
import { TaskProviderHelper } from "./providerHelper";
|
|
73
|
+
|
|
74
|
+
/**
 * Image-generation response carrying the image inline as base64 text
 * (shape: `{ data: [{ b64_json }] }`).
 */
interface Base64ImageGeneration {
	/** Generated images; each entry holds its image bytes base64-encoded in `b64_json`. */
	data: { b64_json: string }[];
}
|
|
79
|
+
|
|
80
|
+
/** Image-generation response that references the generated images by URL. */
interface OutputUrlImageGeneration {
	/** URLs of the generated images. */
	output: Array<string>;
}
|
|
83
|
+
|
|
84
|
+
/**
 * One element of an hf-inference audio-to-audio response.
 * NOTE(review): `blob` is typed as a string (presumably base64-encoded audio) — confirm against the API.
 */
interface AudioToAudioOutput {
	/** Label attached to this output. */
	label: string;
	/** MIME type of the audio carried in `blob`. */
	"content-type": string;
	/** Audio payload as a string. */
	blob: string;
}
|
|
89
|
+
|
|
90
|
+
/**
 * Base helper for the "hf-inference" provider.
 *
 * The provider base URL is `${HF_ROUTER_URL}/hf-inference`, the request body is the
 * caller's `args` as-is, and by default responses are returned unchanged. Task-specific
 * subclasses override `getResponse` to validate and reshape the output.
 */
export class HFInferenceTask extends TaskProviderHelper {
	constructor() {
		super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
	}

	/** Uses the caller-supplied arguments verbatim as the request body. */
	preparePayload(params: BodyParams): Record<string, unknown> {
		return params.args;
	}

	/**
	 * When `model` is already an absolute http(s) URL it is used as the endpoint directly;
	 * otherwise URL construction is delegated to the base helper.
	 */
	override makeUrl(params: UrlParams): string {
		const isAbsoluteUrl = /^https?:\/\//.test(params.model);
		return isAbsoluteUrl ? params.model : super.makeUrl(params);
	}

	/**
	 * Route segment for the request (presumably joined onto the base URL by the base helper).
	 * feature-extraction and sentence-similarity go through `pipeline/<task>/<model>` — on
	 * hf-inference those two tasks are served interchangeably; all others use `models/<model>`.
	 */
	makeRoute(params: UrlParams): string {
		const { task, model } = params;
		switch (task) {
			case "feature-extraction":
			case "sentence-similarity":
				return `pipeline/${task}/${model}`;
			default:
				return `models/${model}`;
		}
	}

	/** Default response handling: no validation, the response is passed through. */
	override async getResponse(response: unknown): Promise<unknown> {
		return response;
	}
}
|
|
116
|
+
|
|
117
|
+
/**
 * Base64-encodes a Blob's bytes with `btoa`, which exists in browsers and in Node.js >= 16,
 * unlike the Node-only `Buffer` global.
 */
async function blobToBase64(blob: Blob): Promise<string> {
	const bytes = new Uint8Array(await blob.arrayBuffer());
	// Convert in chunks so `String.fromCharCode` never receives more arguments than the
	// engine allows in a single call.
	const CHUNK_SIZE = 0x8000;
	let binary = "";
	for (let offset = 0; offset < bytes.length; offset += CHUNK_SIZE) {
		binary += String.fromCharCode(...Array.from(bytes.subarray(offset, offset + CHUNK_SIZE)));
	}
	return btoa(binary);
}

/**
 * Text-to-image on hf-inference.
 *
 * Accepted response shapes:
 * - `{ data: [{ b64_json }] }` — inline base64 image (treated as JPEG),
 * - `{ output: [url, ...] }` — hosted image, first URL used,
 * - a raw `Blob`.
 *
 * With `outputType === "url"` a string is returned (a data URL or the hosted URL);
 * otherwise the image is returned as a `Blob`.
 *
 * @throws InferenceOutputError when the response is missing or matches none of the shapes
 *   (including empty `data` / `output` arrays).
 */
export class HFInferenceTextToImageTask extends HFInferenceTask implements TextToImageTaskHelper {
	override async getResponse(
		response: Base64ImageGeneration | OutputUrlImageGeneration,
		url?: string,
		headers?: HeadersInit,
		outputType?: "url" | "blob"
	): Promise<string | Blob> {
		if (!response) {
			throw new InferenceOutputError("response is undefined");
		}
		if (typeof response === "object") {
			// Read the first element with `?.` so an empty array falls through to the
			// InferenceOutputError below instead of crashing with a TypeError.
			const base64Data =
				"data" in response && Array.isArray(response.data) ? response.data[0]?.b64_json : undefined;
			if (base64Data) {
				const dataUrl = `data:image/jpeg;base64,${base64Data}`;
				if (outputType === "url") {
					return dataUrl;
				}
				const decoded = await fetch(dataUrl);
				return await decoded.blob();
			}
			const firstOutputUrl =
				"output" in response && Array.isArray(response.output) ? response.output[0] : undefined;
			if (typeof firstOutputUrl === "string") {
				if (outputType === "url") {
					return firstOutputUrl;
				}
				const downloaded = await fetch(firstOutputUrl);
				return await downloaded.blob();
			}
		}
		if (response instanceof Blob) {
			if (outputType === "url") {
				// Prefer the Blob's own MIME type; fall back to JPEG when it is unknown ("").
				const mimeType = response.type || "image/jpeg";
				const b64 = await blobToBase64(response);
				return `data:${mimeType};base64,${b64}`;
			}
			return response;
		}
		throw new InferenceOutputError("Expected a Blob ");
	}
}
|
|
155
|
+
|
|
156
|
+
/**
 * Chat completion on hf-inference.
 *
 * The endpoint is normalized so it ends in `/v1/chat/completions`. The model id is added
 * to the request body next to the caller's arguments.
 */
export class HFInferenceConversationalTask extends HFInferenceTask implements ConversationalTaskHelper {
	override makeUrl(params: UrlParams): string {
		const endpoint = /^https?:\/\//.test(params.model)
			? params.model.trim()
			: `${this.makeBaseUrl(params)}/models/${params.model}`;
		// Strip trailing slashes so the suffix checks below are reliable.
		const base = endpoint.replace(/\/+$/, "");

		if (base.endsWith("/chat/completions")) {
			return base;
		}
		return base.endsWith("/v1") ? `${base}/chat/completions` : `${base}/v1/chat/completions`;
	}

	/** Request body: the caller's arguments with `model` set (overriding any `model` in args). */
	override preparePayload(params: BodyParams): Record<string, unknown> {
		const { args, model } = params;
		return { ...args, model };
	}

	/** The chat-completion response is returned without further processing. */
	override async getResponse(response: ChatCompletionOutput): Promise<ChatCompletionOutput> {
		return response;
	}
}
|
|
186
|
+
|
|
187
|
+
/**
 * Text generation on hf-inference.
 *
 * The response may be a single `{ generated_text }` object or an array of them; it is
 * normalized to an array and the first generation is returned.
 *
 * @throws InferenceOutputError if the response is empty or any element is not an object
 *   with a string `generated_text`.
 */
export class HFInferenceTextGenerationTask extends HFInferenceTask implements TextGenerationTaskHelper {
	override async getResponse(response: TextGenerationOutput | TextGenerationOutput[]): Promise<TextGenerationOutput> {
		const res = toArray(response);
		// Check each element is a non-null object before reading from it, so a malformed
		// payload surfaces as InferenceOutputError rather than a TypeError; an empty array is
		// rejected instead of yielding `undefined`.
		if (
			res.length > 0 &&
			res.every((x) => typeof x === "object" && x !== null && typeof x.generated_text === "string")
		) {
			return res[0];
		}
		throw new InferenceOutputError("Expected Array<{generated_text: string}>");
	}
}
|
|
196
|
+
|
|
197
|
+
/**
 * Audio classification on hf-inference.
 *
 * The untyped response is checked at runtime to be an array of `{ label: string, score: number }`.
 */
export class HFInferenceAudioClassificationTask extends HFInferenceTask implements AudioClassificationTaskHelper {
	override async getResponse(response: unknown): Promise<AudioClassificationOutput> {
		if (
			!Array.isArray(response) ||
			response.some(
				(entry) =>
					typeof entry !== "object" ||
					entry === null ||
					typeof entry.label !== "string" ||
					typeof entry.score !== "number"
			)
		) {
			throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
		}
		return response;
	}
}
|
|
214
|
+
|
|
215
|
+
/** Automatic speech recognition on hf-inference; this helper performs no shape validation. */
export class HFInferenceAutomaticSpeechRecognitionTask extends HFInferenceTask implements AutomaticSpeechRecognitionTaskHelper {
	/** Passes the response through unchanged. */
	override async getResponse(response: AutomaticSpeechRecognitionOutput): Promise<AutomaticSpeechRecognitionOutput> {
		return response;
	}
}
|
|
223
|
+
|
|
224
|
+
/**
 * Audio-to-audio on hf-inference.
 *
 * The response must be an array whose elements each carry string `label`, `content-type`
 * and `blob` fields; it is returned unchanged once validated.
 *
 * @throws InferenceOutputError if the response is not an array or any element has the wrong shape.
 */
export class HFInferenceAudioToAudioTask extends HFInferenceTask implements AudioToAudioTaskHelper {
	override async getResponse(response: AudioToAudioOutput[]): Promise<AudioToAudioOutput[]> {
		if (!Array.isArray(response)) {
			throw new InferenceOutputError("Expected Array");
		}
		if (
			!response.every((elem): elem is AudioToAudioOutput => {
				return (
					typeof elem === "object" &&
					elem &&
					"label" in elem &&
					typeof elem.label === "string" &&
					"content-type" in elem &&
					typeof elem["content-type"] === "string" &&
					"blob" in elem &&
					typeof elem.blob === "string"
				);
			})
		) {
			// The message describes the shape actually validated above.
			throw new InferenceOutputError("Expected Array<{label: string, content-type: string, blob: string}>");
		}
		return response;
	}
}
|
|
248
|
+
|
|
249
|
+
/**
 * Document question answering on hf-inference.
 *
 * Expects a non-empty array of answers. Each has a string `answer`; `end`, `score` and
 * `start` are numbers when present. Returns the first answer.
 *
 * @throws InferenceOutputError if the response is not a non-empty array of such answers.
 */
export class HFInferenceDocumentQuestionAnsweringTask
	extends HFInferenceTask
	implements DocumentQuestionAnsweringTaskHelper
{
	override async getResponse(
		response: DocumentQuestionAnsweringOutput
	): Promise<DocumentQuestionAnsweringOutput[number]> {
		if (
			Array.isArray(response) &&
			// `every` is vacuously true for []; require an answer so callers never receive `undefined`.
			response.length > 0 &&
			response.every(
				(elem) =>
					typeof elem === "object" &&
					!!elem &&
					typeof elem.answer === "string" &&
					(typeof elem.end === "number" || typeof elem.end === "undefined") &&
					(typeof elem.score === "number" || typeof elem.score === "undefined") &&
					(typeof elem.start === "number" || typeof elem.start === "undefined")
			)
		) {
			return response[0];
		}
		throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
	}
}
|
|
273
|
+
|
|
274
|
+
/**
 * Feature extraction on hf-inference.
 *
 * Accepts nested numeric arrays up to four levels deep (`number[]` … `number[][][][]`).
 * At every level, all elements must be arrays or all must be numbers.
 */
export class HFInferenceFeatureExtractionTask extends HFInferenceTask implements FeatureExtractionTaskHelper {
	override async getResponse(response: FeatureExtractionOutput): Promise<FeatureExtractionOutput> {
		// `depthLeft` counts how many further levels of nesting are still permitted.
		const isNumericTensor = (value: unknown[], depthLeft: number): boolean => {
			if (depthLeft < 0) {
				return false;
			}
			const allNested = value.every((item) => Array.isArray(item));
			return allNested
				? value.every((item) => isNumericTensor(item as unknown[], depthLeft - 1))
				: value.every((item) => typeof item === "number");
		};

		if (!Array.isArray(response) || !isNumericTensor(response, 3)) {
			throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
		}
		return response;
	}
}
|
|
290
|
+
|
|
291
|
+
/**
 * Image classification on hf-inference.
 *
 * @throws InferenceOutputError unless the response is an array of `{ label: string, score: number }`.
 */
export class HFInferenceImageClassificationTask extends HFInferenceTask implements ImageClassificationTaskHelper {
	override async getResponse(response: ImageClassificationOutput): Promise<ImageClassificationOutput> {
		// `x?.` keeps a null element from throwing a TypeError before validation can reject it.
		if (Array.isArray(response) && response.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
			return response;
		}
		throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
	}
}
|
|
299
|
+
|
|
300
|
+
/**
 * Image segmentation on hf-inference.
 *
 * @throws InferenceOutputError unless the response is an array of
 *   `{ label: string, mask: string, score: number }`.
 */
export class HFInferenceImageSegmentationTask extends HFInferenceTask implements ImageSegmentationTaskHelper {
	override async getResponse(response: ImageSegmentationOutput): Promise<ImageSegmentationOutput> {
		if (
			Array.isArray(response) &&
			// `x?.` keeps a null element from throwing a TypeError before validation can reject it.
			response.every((x) => typeof x?.label === "string" && typeof x.mask === "string" && typeof x.score === "number")
		) {
			return response;
		}
		throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
	}
}
|
|
311
|
+
|
|
312
|
+
/**
 * Image-to-text on hf-inference: accepts any response whose `generated_text` is a string.
 */
export class HFInferenceImageToTextTask extends HFInferenceTask implements ImageToTextTaskHelper {
	override async getResponse(response: ImageToTextOutput): Promise<ImageToTextOutput> {
		const caption = response?.generated_text;
		if (typeof caption === "string") {
			return response;
		}
		throw new InferenceOutputError("Expected {generated_text: string}");
	}
}
|
|
320
|
+
|
|
321
|
+
/** Image-to-image on hf-inference: the response must be a `Blob`. */
export class HFInferenceImageToImageTask extends HFInferenceTask implements ImageToImageTaskHelper {
	override async getResponse(response: Blob): Promise<Blob> {
		if (!(response instanceof Blob)) {
			throw new InferenceOutputError("Expected Blob");
		}
		return response;
	}
}
|
|
329
|
+
|
|
330
|
+
/**
 * Object detection on hf-inference.
 *
 * @throws InferenceOutputError unless every element has a string `label`, a numeric `score`
 *   and a `box` with numeric `xmin`, `ymin`, `xmax`, `ymax`.
 */
export class HFInferenceObjectDetectionTask extends HFInferenceTask implements ObjectDetectionTaskHelper {
	override async getResponse(response: ObjectDetectionOutput): Promise<ObjectDetectionOutput> {
		if (
			Array.isArray(response) &&
			response.every(
				// `?.` on the element and on `box` keeps a null element or a missing box from
				// throwing a TypeError before validation can reject it.
				(x) =>
					typeof x?.label === "string" &&
					typeof x.score === "number" &&
					typeof x.box?.xmin === "number" &&
					typeof x.box.ymin === "number" &&
					typeof x.box.xmax === "number" &&
					typeof x.box.ymax === "number"
			)
		) {
			return response;
		}
		throw new InferenceOutputError(
			"Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
		);
	}
}
|
|
351
|
+
|
|
352
|
+
/**
 * Zero-shot image classification on hf-inference.
 *
 * @throws InferenceOutputError unless the response is an array of `{ label: string, score: number }`.
 */
export class HFInferenceZeroShotImageClassificationTask
	extends HFInferenceTask
	implements ZeroShotImageClassificationTaskHelper
{
	override async getResponse(response: ZeroShotImageClassificationOutput): Promise<ZeroShotImageClassificationOutput> {
		// `x?.` keeps a null element from throwing a TypeError before validation can reject it.
		if (Array.isArray(response) && response.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
			return response;
		}
		throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
	}
}
|
|
363
|
+
|
|
364
|
+
/**
 * Text classification on hf-inference.
 *
 * The raw response wraps the results in an outer array. The first inner array of
 * `{ label, score }` entries is validated and returned.
 */
export class HFInferenceTextClassificationTask extends HFInferenceTask implements TextClassificationTaskHelper {
	override async getResponse(response: TextClassificationOutput): Promise<TextClassificationOutput> {
		const candidates = response?.[0];
		if (
			!Array.isArray(candidates) ||
			!candidates.every((entry) => typeof entry?.label === "string" && typeof entry.score === "number")
		) {
			throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
		}
		return candidates;
	}
}
|
|
373
|
+
|
|
374
|
+
/**
 * Extractive question answering on hf-inference.
 *
 * The response may be a single answer object or a non-empty array of them. Every answer
 * must have a string `answer` and numeric `end`, `score`, `start`. Returns the (first) answer.
 *
 * @throws InferenceOutputError if the response is an empty array or has the wrong shape.
 */
export class HFInferenceQuestionAnsweringTask extends HFInferenceTask implements QuestionAnsweringTaskHelper {
	override async getResponse(
		response: QuestionAnsweringOutput | QuestionAnsweringOutput[number]
	): Promise<QuestionAnsweringOutput[number]> {
		const isAnswer = (elem: QuestionAnsweringOutput[number]): boolean =>
			typeof elem === "object" &&
			!!elem &&
			typeof elem.answer === "string" &&
			typeof elem.end === "number" &&
			typeof elem.score === "number" &&
			typeof elem.start === "number";

		// An empty array used to pass `every` and return `undefined`; require at least one answer.
		const valid = Array.isArray(response) ? response.length > 0 && response.every(isAnswer) : isAnswer(response);
		if (valid) {
			return Array.isArray(response) ? response[0] : response;
		}
		throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
	}
}
|
|
401
|
+
|
|
402
|
+
/**
 * Fill-mask on hf-inference.
 *
 * @throws InferenceOutputError unless every element has numeric `score` and `token` and
 *   string `sequence` and `token_str`.
 */
export class HFInferenceFillMaskTask extends HFInferenceTask implements FillMaskTaskHelper {
	override async getResponse(response: FillMaskOutput): Promise<FillMaskOutput> {
		if (
			Array.isArray(response) &&
			response.every(
				// `x?.` keeps a null element from throwing a TypeError before validation can reject it.
				(x) =>
					typeof x?.score === "number" &&
					typeof x.sequence === "string" &&
					typeof x.token === "number" &&
					typeof x.token_str === "string"
			)
		) {
			return response;
		}
		throw new InferenceOutputError(
			"Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
		);
	}
}
|
|
421
|
+
|
|
422
|
+
/**
 * Zero-shot text classification on hf-inference.
 *
 * @throws InferenceOutputError unless every element has a `labels` string array, a `scores`
 *   number array and a string `sequence`.
 */
export class HFInferenceZeroShotClassificationTask extends HFInferenceTask implements ZeroShotClassificationTaskHelper {
	override async getResponse(response: ZeroShotClassificationOutput): Promise<ZeroShotClassificationOutput> {
		if (
			Array.isArray(response) &&
			response.every(
				// `x?.` keeps a null element from throwing a TypeError before validation can reject it.
				(x) =>
					Array.isArray(x?.labels) &&
					x.labels.every((_label) => typeof _label === "string") &&
					Array.isArray(x.scores) &&
					x.scores.every((_score) => typeof _score === "number") &&
					typeof x.sequence === "string"
			)
		) {
			return response;
		}
		throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
	}
}
|
|
440
|
+
|
|
441
|
+
/** Sentence similarity on hf-inference: the response must be an array of numbers. */
export class HFInferenceSentenceSimilarityTask extends HFInferenceTask implements SentenceSimilarityTaskHelper {
	override async getResponse(response: SentenceSimilarityOutput): Promise<SentenceSimilarityOutput> {
		if (!Array.isArray(response) || response.some((value) => typeof value !== "number")) {
			throw new InferenceOutputError("Expected Array<number>");
		}
		return response;
	}
}
|
|
449
|
+
|
|
450
|
+
/**
 * Table question answering on hf-inference.
 *
 * The response may be a single answer object or a non-empty array of them. The (first)
 * answer is returned after validation.
 */
export class HFInferenceTableQuestionAnsweringTask extends HFInferenceTask implements TableQuestionAnsweringTaskHelper {
	/**
	 * Runtime check for one answer: string `aggregator` and `answer`, string[] `cells`,
	 * and number[][] `coordinates`.
	 */
	static validate(elem: unknown): elem is TableQuestionAnsweringOutput[number] {
		return (
			typeof elem === "object" &&
			!!elem &&
			"aggregator" in elem &&
			typeof elem.aggregator === "string" &&
			"answer" in elem &&
			typeof elem.answer === "string" &&
			"cells" in elem &&
			Array.isArray(elem.cells) &&
			elem.cells.every((x: unknown): x is string => typeof x === "string") &&
			"coordinates" in elem &&
			Array.isArray(elem.coordinates) &&
			elem.coordinates.every(
				(coord: unknown): coord is number[] => Array.isArray(coord) && coord.every((x) => typeof x === "number")
			)
		);
	}

	/**
	 * @throws InferenceOutputError if the response is an empty array or fails `validate`.
	 */
	override async getResponse(response: TableQuestionAnsweringOutput): Promise<TableQuestionAnsweringOutput[number]> {
		if (
			Array.isArray(response)
				? // An empty array used to pass `every` and return `undefined`; require an answer.
				  response.length > 0 && response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem))
				: HFInferenceTableQuestionAnsweringTask.validate(response)
		) {
			return Array.isArray(response) ? response[0] : response;
		}
		throw new InferenceOutputError(
			"Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
		);
	}
}
|
|
482
|
+
|
|
483
|
+
/**
 * Token classification on hf-inference.
 *
 * @throws InferenceOutputError unless every element has numeric `end`, `score`, `start`
 *   and string `entity_group`, `word`.
 */
export class HFInferenceTokenClassificationTask extends HFInferenceTask implements TokenClassificationTaskHelper {
	override async getResponse(response: TokenClassificationOutput): Promise<TokenClassificationOutput> {
		if (
			Array.isArray(response) &&
			response.every(
				// `x?.` keeps a null element from throwing a TypeError before validation can reject it.
				(x) =>
					typeof x?.end === "number" &&
					typeof x.entity_group === "string" &&
					typeof x.score === "number" &&
					typeof x.start === "number" &&
					typeof x.word === "string"
			)
		) {
			return response;
		}
		throw new InferenceOutputError(
			"Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
		);
	}
}
|
|
503
|
+
|
|
504
|
+
/**
 * Translation on hf-inference.
 *
 * The response must be an array of `{ translation_text: string }`. A single-element array
 * is unwrapped to its element; otherwise the array is returned as-is.
 */
export class HFInferenceTranslationTask extends HFInferenceTask implements TranslationTaskHelper {
	override async getResponse(response: TranslationOutput): Promise<TranslationOutput> {
		if (!Array.isArray(response) || !response.every((item) => typeof item?.translation_text === "string")) {
			throw new InferenceOutputError("Expected Array<{translation_text: string}>");
		}
		if (response.length === 1) {
			return response[0];
		}
		return response;
	}
}
|
|
512
|
+
|
|
513
|
+
/**
 * Summarization on hf-inference: returns the first `{ summary_text }` of the response array.
 *
 * @throws InferenceOutputError if the response is not a non-empty array of `{ summary_text: string }`.
 */
export class HFInferenceSummarizationTask extends HFInferenceTask implements SummarizationTaskHelper {
	override async getResponse(response: SummarizationOutput): Promise<SummarizationOutput> {
		if (
			Array.isArray(response) &&
			// `every` is vacuously true for []; require a summary so callers never receive `undefined`.
			response.length > 0 &&
			response.every((x) => typeof x?.summary_text === "string")
		) {
			return response[0];
		}
		throw new InferenceOutputError("Expected Array<{summary_text: string}>");
	}
}
|
|
521
|
+
|
|
522
|
+
/** Text-to-speech on hf-inference; the audio response is not validated here. */
export class HFInferenceTextToSpeechTask extends HFInferenceTask implements TextToSpeechTaskHelper {
	/** Returns the response Blob unchanged. */
	override async getResponse(response: Blob): Promise<Blob> {
		return response;
	}
}
|
|
527
|
+
|
|
528
|
+
/** Tabular classification on hf-inference: the response must be an array of numbers. */
export class HFInferenceTabularClassificationTask extends HFInferenceTask implements TabularClassificationTaskHelper {
	override async getResponse(response: TabularClassificationOutput): Promise<TabularClassificationOutput> {
		if (!Array.isArray(response) || response.some((value) => typeof value !== "number")) {
			throw new InferenceOutputError("Expected Array<number>");
		}
		return response;
	}
}
|
|
536
|
+
|
|
537
|
+
/**
 * Visual question answering on hf-inference: returns the first answer of the response array.
 *
 * @throws InferenceOutputError if the response is not a non-empty array of
 *   `{ answer: string, score: number }`.
 */
export class HFInferenceVisualQuestionAnsweringTask
	extends HFInferenceTask
	implements VisualQuestionAnsweringTaskHelper
{
	override async getResponse(response: VisualQuestionAnsweringOutput): Promise<VisualQuestionAnsweringOutput[number]> {
		if (
			Array.isArray(response) &&
			// `every` is vacuously true for []; require an answer so callers never receive `undefined`.
			response.length > 0 &&
			response.every(
				(elem) =>
					typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
			)
		) {
			return response[0];
		}
		throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
	}
}
|
|
554
|
+
|
|
555
|
+
/** Tabular regression on hf-inference: the response must be an array of numbers. */
export class HFInferenceTabularRegressionTask extends HFInferenceTask implements TabularRegressionTaskHelper {
	override async getResponse(response: number[]): Promise<number[]> {
		if (!Array.isArray(response) || response.some((value) => typeof value !== "number")) {
			throw new InferenceOutputError("Expected Array<number>");
		}
		return response;
	}
}
|
|
563
|
+
|
|
564
|
+
/** Text-to-audio on hf-inference; the audio response is not validated here. */
export class HFInferenceTextToAudioTask extends HFInferenceTask implements TextToAudioTaskHelper {
	/** Returns the response Blob unchanged. */
	override async getResponse(response: Blob): Promise<Blob> {
		return response;
	}
}
|