@huggingface/inference 3.7.0 → 3.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.cjs +1369 -941
- package/dist/index.js +1371 -943
- package/dist/src/lib/getInferenceProviderMapping.d.ts +21 -0
- package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -0
- package/dist/src/lib/getProviderHelper.d.ts +37 -0
- package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
- package/dist/src/lib/makeRequestOptions.d.ts +5 -5
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/black-forest-labs.d.ts +14 -18
- package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
- package/dist/src/providers/cerebras.d.ts +4 -2
- package/dist/src/providers/cerebras.d.ts.map +1 -1
- package/dist/src/providers/cohere.d.ts +5 -2
- package/dist/src/providers/cohere.d.ts.map +1 -1
- package/dist/src/providers/consts.d.ts +2 -3
- package/dist/src/providers/consts.d.ts.map +1 -1
- package/dist/src/providers/fal-ai.d.ts +50 -3
- package/dist/src/providers/fal-ai.d.ts.map +1 -1
- package/dist/src/providers/fireworks-ai.d.ts +5 -2
- package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
- package/dist/src/providers/hf-inference.d.ts +126 -2
- package/dist/src/providers/hf-inference.d.ts.map +1 -1
- package/dist/src/providers/hyperbolic.d.ts +31 -2
- package/dist/src/providers/hyperbolic.d.ts.map +1 -1
- package/dist/src/providers/nebius.d.ts +20 -18
- package/dist/src/providers/nebius.d.ts.map +1 -1
- package/dist/src/providers/novita.d.ts +21 -18
- package/dist/src/providers/novita.d.ts.map +1 -1
- package/dist/src/providers/openai.d.ts +4 -2
- package/dist/src/providers/openai.d.ts.map +1 -1
- package/dist/src/providers/providerHelper.d.ts +182 -0
- package/dist/src/providers/providerHelper.d.ts.map +1 -0
- package/dist/src/providers/replicate.d.ts +23 -19
- package/dist/src/providers/replicate.d.ts.map +1 -1
- package/dist/src/providers/sambanova.d.ts +4 -2
- package/dist/src/providers/sambanova.d.ts.map +1 -1
- package/dist/src/providers/together.d.ts +32 -2
- package/dist/src/providers/together.d.ts.map +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts +2 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/audio/utils.d.ts +2 -1
- package/dist/src/tasks/audio/utils.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts +0 -2
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts +0 -2
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/index.d.ts +6 -6
- package/dist/src/tasks/index.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
- package/dist/src/types.d.ts +5 -13
- package/dist/src/types.d.ts.map +1 -1
- package/dist/src/utils/request.d.ts +3 -2
- package/dist/src/utils/request.d.ts.map +1 -1
- package/package.json +3 -3
- package/src/lib/getInferenceProviderMapping.ts +96 -0
- package/src/lib/getProviderHelper.ts +270 -0
- package/src/lib/makeRequestOptions.ts +78 -97
- package/src/providers/black-forest-labs.ts +73 -22
- package/src/providers/cerebras.ts +6 -27
- package/src/providers/cohere.ts +9 -28
- package/src/providers/consts.ts +5 -2
- package/src/providers/fal-ai.ts +224 -77
- package/src/providers/fireworks-ai.ts +8 -29
- package/src/providers/hf-inference.ts +557 -34
- package/src/providers/hyperbolic.ts +107 -29
- package/src/providers/nebius.ts +65 -29
- package/src/providers/novita.ts +68 -32
- package/src/providers/openai.ts +6 -32
- package/src/providers/providerHelper.ts +354 -0
- package/src/providers/replicate.ts +124 -34
- package/src/providers/sambanova.ts +5 -30
- package/src/providers/together.ts +92 -28
- package/src/snippets/getInferenceSnippets.ts +39 -14
- package/src/snippets/templates.exported.ts +25 -25
- package/src/tasks/audio/audioClassification.ts +5 -8
- package/src/tasks/audio/audioToAudio.ts +4 -27
- package/src/tasks/audio/automaticSpeechRecognition.ts +5 -4
- package/src/tasks/audio/textToSpeech.ts +5 -29
- package/src/tasks/audio/utils.ts +2 -1
- package/src/tasks/custom/request.ts +3 -3
- package/src/tasks/custom/streamingRequest.ts +4 -3
- package/src/tasks/cv/imageClassification.ts +4 -8
- package/src/tasks/cv/imageSegmentation.ts +4 -9
- package/src/tasks/cv/imageToImage.ts +4 -7
- package/src/tasks/cv/imageToText.ts +4 -7
- package/src/tasks/cv/objectDetection.ts +4 -19
- package/src/tasks/cv/textToImage.ts +9 -137
- package/src/tasks/cv/textToVideo.ts +17 -64
- package/src/tasks/cv/zeroShotImageClassification.ts +4 -8
- package/src/tasks/index.ts +6 -6
- package/src/tasks/multimodal/documentQuestionAnswering.ts +4 -19
- package/src/tasks/multimodal/visualQuestionAnswering.ts +4 -12
- package/src/tasks/nlp/chatCompletion.ts +5 -20
- package/src/tasks/nlp/chatCompletionStream.ts +4 -3
- package/src/tasks/nlp/featureExtraction.ts +4 -19
- package/src/tasks/nlp/fillMask.ts +4 -17
- package/src/tasks/nlp/questionAnswering.ts +11 -26
- package/src/tasks/nlp/sentenceSimilarity.ts +4 -8
- package/src/tasks/nlp/summarization.ts +4 -7
- package/src/tasks/nlp/tableQuestionAnswering.ts +10 -30
- package/src/tasks/nlp/textClassification.ts +4 -9
- package/src/tasks/nlp/textGeneration.ts +11 -79
- package/src/tasks/nlp/textGenerationStream.ts +3 -1
- package/src/tasks/nlp/tokenClassification.ts +11 -23
- package/src/tasks/nlp/translation.ts +4 -7
- package/src/tasks/nlp/zeroShotClassification.ts +11 -21
- package/src/tasks/tabular/tabularClassification.ts +4 -7
- package/src/tasks/tabular/tabularRegression.ts +4 -7
- package/src/types.ts +5 -14
- package/src/utils/request.ts +7 -4
- package/dist/src/lib/getProviderModelId.d.ts +0 -10
- package/dist/src/lib/getProviderModelId.d.ts.map +0 -1
- package/src/lib/getProviderModelId.ts +0 -74
|
@@ -1,48 +1,15 @@
|
|
|
1
|
-
import type { TextToImageInput
|
|
2
|
-
import {
|
|
3
|
-
import
|
|
4
|
-
import {
|
|
5
|
-
import { omit } from "../../utils/omit";
|
|
1
|
+
import type { TextToImageInput } from "@huggingface/tasks";
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
|
+
import { makeRequestOptions } from "../../lib/makeRequestOptions";
|
|
4
|
+
import type { BaseArgs, Options } from "../../types";
|
|
6
5
|
import { innerRequest } from "../../utils/request";
|
|
7
6
|
|
|
8
7
|
export type TextToImageArgs = BaseArgs & TextToImageInput;
|
|
9
8
|
|
|
10
|
-
interface Base64ImageGeneration {
|
|
11
|
-
data: Array<{
|
|
12
|
-
b64_json: string;
|
|
13
|
-
}>;
|
|
14
|
-
}
|
|
15
|
-
interface OutputUrlImageGeneration {
|
|
16
|
-
output: string[];
|
|
17
|
-
}
|
|
18
|
-
interface HyperbolicTextToImageOutput {
|
|
19
|
-
images: Array<{ image: string }>;
|
|
20
|
-
}
|
|
21
|
-
|
|
22
|
-
interface BlackForestLabsResponse {
|
|
23
|
-
id: string;
|
|
24
|
-
polling_url: string;
|
|
25
|
-
}
|
|
26
|
-
|
|
27
9
|
interface TextToImageOptions extends Options {
|
|
28
10
|
outputType?: "url" | "blob";
|
|
29
11
|
}
|
|
30
12
|
|
|
31
|
-
function getResponseFormatArg(provider: InferenceProvider) {
|
|
32
|
-
switch (provider) {
|
|
33
|
-
case "fal-ai":
|
|
34
|
-
return { sync_mode: true };
|
|
35
|
-
case "nebius":
|
|
36
|
-
return { response_format: "b64_json" };
|
|
37
|
-
case "replicate":
|
|
38
|
-
return undefined;
|
|
39
|
-
case "together":
|
|
40
|
-
return { response_format: "base64" };
|
|
41
|
-
default:
|
|
42
|
-
return undefined;
|
|
43
|
-
}
|
|
44
|
-
}
|
|
45
|
-
|
|
46
13
|
/**
|
|
47
14
|
* This task reads some text input and outputs an image.
|
|
48
15
|
* Recommended model: stabilityai/stable-diffusion-2
|
|
@@ -56,108 +23,13 @@ export async function textToImage(
|
|
|
56
23
|
options?: TextToImageOptions & { outputType?: undefined | "blob" }
|
|
57
24
|
): Promise<Blob>;
|
|
58
25
|
export async function textToImage(args: TextToImageArgs, options?: TextToImageOptions): Promise<Blob | string> {
|
|
59
|
-
const
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
: {
|
|
63
|
-
...omit(args, ["inputs", "parameters"]),
|
|
64
|
-
...args.parameters,
|
|
65
|
-
...getResponseFormatArg(args.provider),
|
|
66
|
-
prompt: args.inputs,
|
|
67
|
-
};
|
|
68
|
-
const { data: res } = await innerRequest<
|
|
69
|
-
| TextToImageOutput
|
|
70
|
-
| Base64ImageGeneration
|
|
71
|
-
| OutputUrlImageGeneration
|
|
72
|
-
| BlackForestLabsResponse
|
|
73
|
-
| HyperbolicTextToImageOutput
|
|
74
|
-
>(payload, {
|
|
26
|
+
const provider = args.provider ?? "hf-inference";
|
|
27
|
+
const providerHelper = getProviderHelper(provider, "text-to-image");
|
|
28
|
+
const { data: res } = await innerRequest<Record<string, unknown>>(args, providerHelper, {
|
|
75
29
|
...options,
|
|
76
30
|
task: "text-to-image",
|
|
77
31
|
});
|
|
78
32
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
return await pollBflResponse(res.polling_url, options?.outputType);
|
|
82
|
-
}
|
|
83
|
-
if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
|
|
84
|
-
if (options?.outputType === "url") {
|
|
85
|
-
return res.images[0].url;
|
|
86
|
-
} else {
|
|
87
|
-
const image = await fetch(res.images[0].url);
|
|
88
|
-
return await image.blob();
|
|
89
|
-
}
|
|
90
|
-
}
|
|
91
|
-
if (
|
|
92
|
-
args.provider === "hyperbolic" &&
|
|
93
|
-
"images" in res &&
|
|
94
|
-
Array.isArray(res.images) &&
|
|
95
|
-
res.images[0] &&
|
|
96
|
-
typeof res.images[0].image === "string"
|
|
97
|
-
) {
|
|
98
|
-
if (options?.outputType === "url") {
|
|
99
|
-
return `data:image/jpeg;base64,${res.images[0].image}`;
|
|
100
|
-
}
|
|
101
|
-
const base64Response = await fetch(`data:image/jpeg;base64,${res.images[0].image}`);
|
|
102
|
-
return await base64Response.blob();
|
|
103
|
-
}
|
|
104
|
-
if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
|
|
105
|
-
const base64Data = res.data[0].b64_json;
|
|
106
|
-
if (options?.outputType === "url") {
|
|
107
|
-
return `data:image/jpeg;base64,${base64Data}`;
|
|
108
|
-
}
|
|
109
|
-
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
|
|
110
|
-
return await base64Response.blob();
|
|
111
|
-
}
|
|
112
|
-
if ("output" in res && Array.isArray(res.output)) {
|
|
113
|
-
if (options?.outputType === "url") {
|
|
114
|
-
return res.output[0];
|
|
115
|
-
}
|
|
116
|
-
const urlResponse = await fetch(res.output[0]);
|
|
117
|
-
const blob = await urlResponse.blob();
|
|
118
|
-
return blob;
|
|
119
|
-
}
|
|
120
|
-
}
|
|
121
|
-
const isValidOutput = res && res instanceof Blob;
|
|
122
|
-
if (!isValidOutput) {
|
|
123
|
-
throw new InferenceOutputError("Expected Blob");
|
|
124
|
-
}
|
|
125
|
-
if (options?.outputType === "url") {
|
|
126
|
-
const b64 = await res.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
|
|
127
|
-
return `data:image/jpeg;base64,${b64}`;
|
|
128
|
-
}
|
|
129
|
-
return res;
|
|
130
|
-
}
|
|
131
|
-
|
|
132
|
-
async function pollBflResponse(url: string, outputType?: "url" | "blob"): Promise<Blob> {
|
|
133
|
-
const urlObj = new URL(url);
|
|
134
|
-
for (let step = 0; step < 5; step++) {
|
|
135
|
-
await delay(1000);
|
|
136
|
-
console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
|
|
137
|
-
urlObj.searchParams.set("attempt", step.toString(10));
|
|
138
|
-
const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
|
|
139
|
-
if (!resp.ok) {
|
|
140
|
-
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
|
|
141
|
-
}
|
|
142
|
-
const payload = await resp.json();
|
|
143
|
-
if (
|
|
144
|
-
typeof payload === "object" &&
|
|
145
|
-
payload &&
|
|
146
|
-
"status" in payload &&
|
|
147
|
-
typeof payload.status === "string" &&
|
|
148
|
-
payload.status === "Ready" &&
|
|
149
|
-
"result" in payload &&
|
|
150
|
-
typeof payload.result === "object" &&
|
|
151
|
-
payload.result &&
|
|
152
|
-
"sample" in payload.result &&
|
|
153
|
-
typeof payload.result.sample === "string"
|
|
154
|
-
) {
|
|
155
|
-
if (outputType === "url") {
|
|
156
|
-
return payload.result.sample;
|
|
157
|
-
}
|
|
158
|
-
const image = await fetch(payload.result.sample);
|
|
159
|
-
return await image.blob();
|
|
160
|
-
}
|
|
161
|
-
}
|
|
162
|
-
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
|
|
33
|
+
const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "text-to-image" });
|
|
34
|
+
return providerHelper.getResponse(res, url, info.headers as Record<string, string>, options?.outputType);
|
|
163
35
|
}
|
|
@@ -1,74 +1,27 @@
|
|
|
1
1
|
import type { TextToVideoInput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
3
|
-
import {
|
|
4
|
-
import {
|
|
5
|
-
import type {
|
|
6
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
|
+
import { makeRequestOptions } from "../../lib/makeRequestOptions";
|
|
4
|
+
import type { FalAiQueueOutput } from "../../providers/fal-ai";
|
|
5
|
+
import type { NovitaOutput } from "../../providers/novita";
|
|
6
|
+
import type { ReplicateOutput } from "../../providers/replicate";
|
|
7
|
+
import type { BaseArgs, Options } from "../../types";
|
|
7
8
|
import { innerRequest } from "../../utils/request";
|
|
8
|
-
import { typedInclude } from "../../utils/typedInclude";
|
|
9
9
|
|
|
10
10
|
export type TextToVideoArgs = BaseArgs & TextToVideoInput;
|
|
11
11
|
|
|
12
12
|
export type TextToVideoOutput = Blob;
|
|
13
13
|
|
|
14
|
-
interface ReplicateOutput {
|
|
15
|
-
output: string;
|
|
16
|
-
}
|
|
17
|
-
|
|
18
|
-
interface NovitaOutput {
|
|
19
|
-
video: {
|
|
20
|
-
video_url: string;
|
|
21
|
-
};
|
|
22
|
-
}
|
|
23
|
-
|
|
24
|
-
const SUPPORTED_PROVIDERS = ["fal-ai", "novita", "replicate"] as const satisfies readonly InferenceProvider[];
|
|
25
|
-
|
|
26
14
|
export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise<TextToVideoOutput> {
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs }
|
|
36
|
-
: args;
|
|
37
|
-
const { data, requestContext } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(payload, {
|
|
38
|
-
...options,
|
|
39
|
-
task: "text-to-video",
|
|
40
|
-
});
|
|
41
|
-
|
|
42
|
-
if (args.provider === "fal-ai") {
|
|
43
|
-
return await pollFalResponse(
|
|
44
|
-
data as FalAiQueueOutput,
|
|
45
|
-
requestContext.url,
|
|
46
|
-
requestContext.info.headers as Record<string, string>
|
|
47
|
-
);
|
|
48
|
-
} else if (args.provider === "novita") {
|
|
49
|
-
const isValidOutput =
|
|
50
|
-
typeof data === "object" &&
|
|
51
|
-
!!data &&
|
|
52
|
-
"video" in data &&
|
|
53
|
-
typeof data.video === "object" &&
|
|
54
|
-
!!data.video &&
|
|
55
|
-
"video_url" in data.video &&
|
|
56
|
-
typeof data.video.video_url === "string" &&
|
|
57
|
-
isUrl(data.video.video_url);
|
|
58
|
-
if (!isValidOutput) {
|
|
59
|
-
throw new InferenceOutputError("Expected { video: { video_url: string } }");
|
|
60
|
-
}
|
|
61
|
-
const urlResponse = await fetch((data as NovitaOutput).video.video_url);
|
|
62
|
-
return await urlResponse.blob();
|
|
63
|
-
} else {
|
|
64
|
-
/// TODO: Replicate: handle the case where the generation request "times out" / is async (ie output is null)
|
|
65
|
-
/// https://replicate.com/docs/topics/predictions/create-a-prediction
|
|
66
|
-
const isValidOutput =
|
|
67
|
-
typeof data === "object" && !!data && "output" in data && typeof data.output === "string" && isUrl(data.output);
|
|
68
|
-
if (!isValidOutput) {
|
|
69
|
-
throw new InferenceOutputError("Expected { output: string }");
|
|
15
|
+
const provider = args.provider ?? "hf-inference";
|
|
16
|
+
const providerHelper = getProviderHelper(provider, "text-to-video");
|
|
17
|
+
const { data: response } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(
|
|
18
|
+
args,
|
|
19
|
+
providerHelper,
|
|
20
|
+
{
|
|
21
|
+
...options,
|
|
22
|
+
task: "text-to-video",
|
|
70
23
|
}
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
24
|
+
);
|
|
25
|
+
const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "text-to-video" });
|
|
26
|
+
return providerHelper.getResponse(response, url, info.headers as Record<string, string>);
|
|
74
27
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import type { ZeroShotImageClassificationInput, ZeroShotImageClassificationOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options, RequestArgs } from "../../types";
|
|
4
4
|
import { base64FromBytes } from "../../utils/base64FromBytes";
|
|
5
5
|
import { innerRequest } from "../../utils/request";
|
|
@@ -44,15 +44,11 @@ export async function zeroShotImageClassification(
|
|
|
44
44
|
args: ZeroShotImageClassificationArgs,
|
|
45
45
|
options?: Options
|
|
46
46
|
): Promise<ZeroShotImageClassificationOutput> {
|
|
47
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
|
|
47
48
|
const payload = await preparePayload(args);
|
|
48
|
-
const { data: res } = await innerRequest<ZeroShotImageClassificationOutput>(payload, {
|
|
49
|
+
const { data: res } = await innerRequest<ZeroShotImageClassificationOutput>(payload, providerHelper, {
|
|
49
50
|
...options,
|
|
50
51
|
task: "zero-shot-image-classification",
|
|
51
52
|
});
|
|
52
|
-
|
|
53
|
-
Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
|
|
54
|
-
if (!isValidOutput) {
|
|
55
|
-
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
56
|
-
}
|
|
57
|
-
return res;
|
|
53
|
+
return providerHelper.getResponse(res);
|
|
58
54
|
}
|
package/src/tasks/index.ts
CHANGED
|
@@ -4,21 +4,23 @@ export * from "./custom/streamingRequest";
|
|
|
4
4
|
|
|
5
5
|
// Audio tasks
|
|
6
6
|
export * from "./audio/audioClassification";
|
|
7
|
+
export * from "./audio/audioToAudio";
|
|
7
8
|
export * from "./audio/automaticSpeechRecognition";
|
|
8
9
|
export * from "./audio/textToSpeech";
|
|
9
|
-
export * from "./audio/audioToAudio";
|
|
10
10
|
|
|
11
11
|
// Computer Vision tasks
|
|
12
12
|
export * from "./cv/imageClassification";
|
|
13
13
|
export * from "./cv/imageSegmentation";
|
|
14
|
+
export * from "./cv/imageToImage";
|
|
14
15
|
export * from "./cv/imageToText";
|
|
15
16
|
export * from "./cv/objectDetection";
|
|
16
17
|
export * from "./cv/textToImage";
|
|
17
|
-
export * from "./cv/imageToImage";
|
|
18
|
-
export * from "./cv/zeroShotImageClassification";
|
|
19
18
|
export * from "./cv/textToVideo";
|
|
19
|
+
export * from "./cv/zeroShotImageClassification";
|
|
20
20
|
|
|
21
21
|
// Natural Language Processing tasks
|
|
22
|
+
export * from "./nlp/chatCompletion";
|
|
23
|
+
export * from "./nlp/chatCompletionStream";
|
|
22
24
|
export * from "./nlp/featureExtraction";
|
|
23
25
|
export * from "./nlp/fillMask";
|
|
24
26
|
export * from "./nlp/questionAnswering";
|
|
@@ -31,13 +33,11 @@ export * from "./nlp/textGenerationStream";
|
|
|
31
33
|
export * from "./nlp/tokenClassification";
|
|
32
34
|
export * from "./nlp/translation";
|
|
33
35
|
export * from "./nlp/zeroShotClassification";
|
|
34
|
-
export * from "./nlp/chatCompletion";
|
|
35
|
-
export * from "./nlp/chatCompletionStream";
|
|
36
36
|
|
|
37
37
|
// Multimodal tasks
|
|
38
38
|
export * from "./multimodal/documentQuestionAnswering";
|
|
39
39
|
export * from "./multimodal/visualQuestionAnswering";
|
|
40
40
|
|
|
41
41
|
// Tabular tasks
|
|
42
|
-
export * from "./tabular/tabularRegression";
|
|
43
42
|
export * from "./tabular/tabularClassification";
|
|
43
|
+
export * from "./tabular/tabularRegression";
|
|
@@ -3,11 +3,10 @@ import type {
|
|
|
3
3
|
DocumentQuestionAnsweringInputData,
|
|
4
4
|
DocumentQuestionAnsweringOutput,
|
|
5
5
|
} from "@huggingface/tasks";
|
|
6
|
-
import {
|
|
6
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
7
7
|
import type { BaseArgs, Options, RequestArgs } from "../../types";
|
|
8
8
|
import { base64FromBytes } from "../../utils/base64FromBytes";
|
|
9
9
|
import { innerRequest } from "../../utils/request";
|
|
10
|
-
import { toArray } from "../../utils/toArray";
|
|
11
10
|
|
|
12
11
|
/// Override the type to properly set inputs.image as Blob
|
|
13
12
|
export type DocumentQuestionAnsweringArgs = BaseArgs &
|
|
@@ -20,6 +19,7 @@ export async function documentQuestionAnswering(
|
|
|
20
19
|
args: DocumentQuestionAnsweringArgs,
|
|
21
20
|
options?: Options
|
|
22
21
|
): Promise<DocumentQuestionAnsweringOutput[number]> {
|
|
22
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "document-question-answering");
|
|
23
23
|
const reqArgs: RequestArgs = {
|
|
24
24
|
...args,
|
|
25
25
|
inputs: {
|
|
@@ -30,26 +30,11 @@ export async function documentQuestionAnswering(
|
|
|
30
30
|
} as RequestArgs;
|
|
31
31
|
const { data: res } = await innerRequest<DocumentQuestionAnsweringOutput | DocumentQuestionAnsweringOutput[number]>(
|
|
32
32
|
reqArgs,
|
|
33
|
+
providerHelper,
|
|
33
34
|
{
|
|
34
35
|
...options,
|
|
35
36
|
task: "document-question-answering",
|
|
36
37
|
}
|
|
37
38
|
);
|
|
38
|
-
|
|
39
|
-
const isValidOutput =
|
|
40
|
-
Array.isArray(output) &&
|
|
41
|
-
output.every(
|
|
42
|
-
(elem) =>
|
|
43
|
-
typeof elem === "object" &&
|
|
44
|
-
!!elem &&
|
|
45
|
-
typeof elem?.answer === "string" &&
|
|
46
|
-
(typeof elem.end === "number" || typeof elem.end === "undefined") &&
|
|
47
|
-
(typeof elem.score === "number" || typeof elem.score === "undefined") &&
|
|
48
|
-
(typeof elem.start === "number" || typeof elem.start === "undefined")
|
|
49
|
-
);
|
|
50
|
-
if (!isValidOutput) {
|
|
51
|
-
throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>");
|
|
52
|
-
}
|
|
53
|
-
|
|
54
|
-
return output[0];
|
|
39
|
+
return providerHelper.getResponse(res);
|
|
55
40
|
}
|
|
@@ -3,7 +3,7 @@ import type {
|
|
|
3
3
|
VisualQuestionAnsweringInputData,
|
|
4
4
|
VisualQuestionAnsweringOutput,
|
|
5
5
|
} from "@huggingface/tasks";
|
|
6
|
-
import {
|
|
6
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
7
7
|
import type { BaseArgs, Options, RequestArgs } from "../../types";
|
|
8
8
|
import { base64FromBytes } from "../../utils/base64FromBytes";
|
|
9
9
|
import { innerRequest } from "../../utils/request";
|
|
@@ -19,6 +19,7 @@ export async function visualQuestionAnswering(
|
|
|
19
19
|
args: VisualQuestionAnsweringArgs,
|
|
20
20
|
options?: Options
|
|
21
21
|
): Promise<VisualQuestionAnsweringOutput[number]> {
|
|
22
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "visual-question-answering");
|
|
22
23
|
const reqArgs: RequestArgs = {
|
|
23
24
|
...args,
|
|
24
25
|
inputs: {
|
|
@@ -28,18 +29,9 @@ export async function visualQuestionAnswering(
|
|
|
28
29
|
},
|
|
29
30
|
} as RequestArgs;
|
|
30
31
|
|
|
31
|
-
const { data: res } = await innerRequest<VisualQuestionAnsweringOutput>(reqArgs, {
|
|
32
|
+
const { data: res } = await innerRequest<VisualQuestionAnsweringOutput>(reqArgs, providerHelper, {
|
|
32
33
|
...options,
|
|
33
34
|
task: "visual-question-answering",
|
|
34
35
|
});
|
|
35
|
-
|
|
36
|
-
const isValidOutput =
|
|
37
|
-
Array.isArray(res) &&
|
|
38
|
-
res.every(
|
|
39
|
-
(elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
|
|
40
|
-
);
|
|
41
|
-
if (!isValidOutput) {
|
|
42
|
-
throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
|
|
43
|
-
}
|
|
44
|
-
return res[0];
|
|
36
|
+
return providerHelper.getResponse(res);
|
|
45
37
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import type { ChatCompletionInput, ChatCompletionOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
4
|
import { innerRequest } from "../../utils/request";
|
|
5
5
|
|
|
@@ -10,25 +10,10 @@ export async function chatCompletion(
|
|
|
10
10
|
args: BaseArgs & ChatCompletionInput,
|
|
11
11
|
options?: Options
|
|
12
12
|
): Promise<ChatCompletionOutput> {
|
|
13
|
-
const
|
|
13
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
|
|
14
|
+
const { data: response } = await innerRequest<ChatCompletionOutput>(args, providerHelper, {
|
|
14
15
|
...options,
|
|
15
|
-
task: "
|
|
16
|
-
chatCompletion: true,
|
|
16
|
+
task: "conversational",
|
|
17
17
|
});
|
|
18
|
-
|
|
19
|
-
typeof res === "object" &&
|
|
20
|
-
Array.isArray(res?.choices) &&
|
|
21
|
-
typeof res?.created === "number" &&
|
|
22
|
-
typeof res?.id === "string" &&
|
|
23
|
-
typeof res?.model === "string" &&
|
|
24
|
-
/// Together.ai and Nebius do not output a system_fingerprint
|
|
25
|
-
(res.system_fingerprint === undefined ||
|
|
26
|
-
res.system_fingerprint === null ||
|
|
27
|
-
typeof res.system_fingerprint === "string") &&
|
|
28
|
-
typeof res?.usage === "object";
|
|
29
|
-
|
|
30
|
-
if (!isValidOutput) {
|
|
31
|
-
throw new InferenceOutputError("Expected ChatCompletionOutput");
|
|
32
|
-
}
|
|
33
|
-
return res;
|
|
18
|
+
return providerHelper.getResponse(response);
|
|
34
19
|
}
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import type { ChatCompletionInput, ChatCompletionStreamOutput } from "@huggingface/tasks";
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
2
3
|
import type { BaseArgs, Options } from "../../types";
|
|
3
4
|
import { innerStreamingRequest } from "../../utils/request";
|
|
4
5
|
|
|
@@ -9,9 +10,9 @@ export async function* chatCompletionStream(
|
|
|
9
10
|
args: BaseArgs & ChatCompletionInput,
|
|
10
11
|
options?: Options
|
|
11
12
|
): AsyncGenerator<ChatCompletionStreamOutput> {
|
|
12
|
-
|
|
13
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
|
|
14
|
+
yield* innerStreamingRequest<ChatCompletionStreamOutput>(args, providerHelper, {
|
|
13
15
|
...options,
|
|
14
|
-
task: "
|
|
15
|
-
chatCompletion: true,
|
|
16
|
+
task: "conversational",
|
|
16
17
|
});
|
|
17
18
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import type { FeatureExtractionInput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
4
|
import { innerRequest } from "../../utils/request";
|
|
5
5
|
|
|
@@ -17,25 +17,10 @@ export async function featureExtraction(
|
|
|
17
17
|
args: FeatureExtractionArgs,
|
|
18
18
|
options?: Options
|
|
19
19
|
): Promise<FeatureExtractionOutput> {
|
|
20
|
-
const
|
|
20
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "feature-extraction");
|
|
21
|
+
const { data: res } = await innerRequest<FeatureExtractionOutput>(args, providerHelper, {
|
|
21
22
|
...options,
|
|
22
23
|
task: "feature-extraction",
|
|
23
24
|
});
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
const isNumArrayRec = (arr: unknown[], maxDepth: number, curDepth = 0): boolean => {
|
|
27
|
-
if (curDepth > maxDepth) return false;
|
|
28
|
-
if (arr.every((x) => Array.isArray(x))) {
|
|
29
|
-
return arr.every((x) => isNumArrayRec(x as unknown[], maxDepth, curDepth + 1));
|
|
30
|
-
} else {
|
|
31
|
-
return arr.every((x) => typeof x === "number");
|
|
32
|
-
}
|
|
33
|
-
};
|
|
34
|
-
|
|
35
|
-
isValidOutput = Array.isArray(res) && isNumArrayRec(res, 3, 0);
|
|
36
|
-
|
|
37
|
-
if (!isValidOutput) {
|
|
38
|
-
throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
|
|
39
|
-
}
|
|
40
|
-
return res;
|
|
25
|
+
return providerHelper.getResponse(res);
|
|
41
26
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import type { FillMaskInput, FillMaskOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
4
|
import { innerRequest } from "../../utils/request";
|
|
5
5
|
|
|
@@ -9,23 +9,10 @@ export type FillMaskArgs = BaseArgs & FillMaskInput;
|
|
|
9
9
|
* Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
|
|
10
10
|
*/
|
|
11
11
|
export async function fillMask(args: FillMaskArgs, options?: Options): Promise<FillMaskOutput> {
|
|
12
|
-
const
|
|
12
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "fill-mask");
|
|
13
|
+
const { data: res } = await innerRequest<FillMaskOutput>(args, providerHelper, {
|
|
13
14
|
...options,
|
|
14
15
|
task: "fill-mask",
|
|
15
16
|
});
|
|
16
|
-
|
|
17
|
-
Array.isArray(res) &&
|
|
18
|
-
res.every(
|
|
19
|
-
(x) =>
|
|
20
|
-
typeof x.score === "number" &&
|
|
21
|
-
typeof x.sequence === "string" &&
|
|
22
|
-
typeof x.token === "number" &&
|
|
23
|
-
typeof x.token_str === "string"
|
|
24
|
-
);
|
|
25
|
-
if (!isValidOutput) {
|
|
26
|
-
throw new InferenceOutputError(
|
|
27
|
-
"Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
|
|
28
|
-
);
|
|
29
|
-
}
|
|
30
|
-
return res;
|
|
17
|
+
return providerHelper.getResponse(res);
|
|
31
18
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import type { QuestionAnsweringInput, QuestionAnsweringOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
4
|
import { innerRequest } from "../../utils/request";
|
|
5
5
|
|
|
@@ -12,29 +12,14 @@ export async function questionAnswering(
|
|
|
12
12
|
args: QuestionAnsweringArgs,
|
|
13
13
|
options?: Options
|
|
14
14
|
): Promise<QuestionAnsweringOutput[number]> {
|
|
15
|
-
const
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
typeof elem.answer === "string" &&
|
|
26
|
-
typeof elem.end === "number" &&
|
|
27
|
-
typeof elem.score === "number" &&
|
|
28
|
-
typeof elem.start === "number"
|
|
29
|
-
)
|
|
30
|
-
: typeof res === "object" &&
|
|
31
|
-
!!res &&
|
|
32
|
-
typeof res.answer === "string" &&
|
|
33
|
-
typeof res.end === "number" &&
|
|
34
|
-
typeof res.score === "number" &&
|
|
35
|
-
typeof res.start === "number";
|
|
36
|
-
if (!isValidOutput) {
|
|
37
|
-
throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
|
|
38
|
-
}
|
|
39
|
-
return Array.isArray(res) ? res[0] : res;
|
|
15
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering");
|
|
16
|
+
const { data: res } = await innerRequest<QuestionAnsweringOutput | QuestionAnsweringOutput[number]>(
|
|
17
|
+
args,
|
|
18
|
+
providerHelper,
|
|
19
|
+
{
|
|
20
|
+
...options,
|
|
21
|
+
task: "question-answering",
|
|
22
|
+
}
|
|
23
|
+
);
|
|
24
|
+
return providerHelper.getResponse(res);
|
|
40
25
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import type { SentenceSimilarityInput, SentenceSimilarityOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
4
|
import { innerRequest } from "../../utils/request";
|
|
5
5
|
|
|
@@ -12,14 +12,10 @@ export async function sentenceSimilarity(
|
|
|
12
12
|
args: SentenceSimilarityArgs,
|
|
13
13
|
options?: Options
|
|
14
14
|
): Promise<SentenceSimilarityOutput> {
|
|
15
|
-
const
|
|
15
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "sentence-similarity");
|
|
16
|
+
const { data: res } = await innerRequest<SentenceSimilarityOutput>(args, providerHelper, {
|
|
16
17
|
...options,
|
|
17
18
|
task: "sentence-similarity",
|
|
18
19
|
});
|
|
19
|
-
|
|
20
|
-
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
|
|
21
|
-
if (!isValidOutput) {
|
|
22
|
-
throw new InferenceOutputError("Expected number[]");
|
|
23
|
-
}
|
|
24
|
-
return res;
|
|
20
|
+
return providerHelper.getResponse(res);
|
|
25
21
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import type { SummarizationInput, SummarizationOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
4
|
import { innerRequest } from "../../utils/request";
|
|
5
5
|
|
|
@@ -9,13 +9,10 @@ export type SummarizationArgs = BaseArgs & SummarizationInput;
|
|
|
9
9
|
* This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.
|
|
10
10
|
*/
|
|
11
11
|
export async function summarization(args: SummarizationArgs, options?: Options): Promise<SummarizationOutput> {
|
|
12
|
-
const
|
|
12
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization");
|
|
13
|
+
const { data: res } = await innerRequest<SummarizationOutput[]>(args, providerHelper, {
|
|
13
14
|
...options,
|
|
14
15
|
task: "summarization",
|
|
15
16
|
});
|
|
16
|
-
|
|
17
|
-
if (!isValidOutput) {
|
|
18
|
-
throw new InferenceOutputError("Expected Array<{summary_text: string}>");
|
|
19
|
-
}
|
|
20
|
-
return res?.[0];
|
|
17
|
+
return providerHelper.getResponse(res);
|
|
21
18
|
}
|