@huggingface/inference 3.7.0 → 3.7.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.cjs +1152 -839
- package/dist/index.js +1154 -841
- package/dist/src/lib/getProviderHelper.d.ts +37 -0
- package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
- package/dist/src/lib/makeRequestOptions.d.ts +0 -2
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/black-forest-labs.d.ts +14 -18
- package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
- package/dist/src/providers/cerebras.d.ts +4 -2
- package/dist/src/providers/cerebras.d.ts.map +1 -1
- package/dist/src/providers/cohere.d.ts +5 -2
- package/dist/src/providers/cohere.d.ts.map +1 -1
- package/dist/src/providers/fal-ai.d.ts +50 -3
- package/dist/src/providers/fal-ai.d.ts.map +1 -1
- package/dist/src/providers/fireworks-ai.d.ts +5 -2
- package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
- package/dist/src/providers/hf-inference.d.ts +125 -2
- package/dist/src/providers/hf-inference.d.ts.map +1 -1
- package/dist/src/providers/hyperbolic.d.ts +31 -2
- package/dist/src/providers/hyperbolic.d.ts.map +1 -1
- package/dist/src/providers/nebius.d.ts +20 -18
- package/dist/src/providers/nebius.d.ts.map +1 -1
- package/dist/src/providers/novita.d.ts +21 -18
- package/dist/src/providers/novita.d.ts.map +1 -1
- package/dist/src/providers/openai.d.ts +4 -2
- package/dist/src/providers/openai.d.ts.map +1 -1
- package/dist/src/providers/providerHelper.d.ts +182 -0
- package/dist/src/providers/providerHelper.d.ts.map +1 -0
- package/dist/src/providers/replicate.d.ts +23 -19
- package/dist/src/providers/replicate.d.ts.map +1 -1
- package/dist/src/providers/sambanova.d.ts +4 -2
- package/dist/src/providers/sambanova.d.ts.map +1 -1
- package/dist/src/providers/together.d.ts +32 -2
- package/dist/src/providers/together.d.ts.map +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/audio/utils.d.ts +2 -1
- package/dist/src/tasks/audio/utils.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts +0 -2
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts +0 -2
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/index.d.ts +6 -6
- package/dist/src/tasks/index.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
- package/dist/src/types.d.ts +3 -13
- package/dist/src/types.d.ts.map +1 -1
- package/package.json +3 -3
- package/src/lib/getProviderHelper.ts +270 -0
- package/src/lib/makeRequestOptions.ts +34 -91
- package/src/providers/black-forest-labs.ts +73 -22
- package/src/providers/cerebras.ts +6 -27
- package/src/providers/cohere.ts +9 -28
- package/src/providers/fal-ai.ts +195 -77
- package/src/providers/fireworks-ai.ts +8 -29
- package/src/providers/hf-inference.ts +555 -34
- package/src/providers/hyperbolic.ts +107 -29
- package/src/providers/nebius.ts +65 -29
- package/src/providers/novita.ts +68 -32
- package/src/providers/openai.ts +6 -32
- package/src/providers/providerHelper.ts +354 -0
- package/src/providers/replicate.ts +124 -34
- package/src/providers/sambanova.ts +5 -30
- package/src/providers/together.ts +92 -28
- package/src/snippets/getInferenceSnippets.ts +16 -9
- package/src/snippets/templates.exported.ts +1 -1
- package/src/tasks/audio/audioClassification.ts +4 -7
- package/src/tasks/audio/audioToAudio.ts +3 -26
- package/src/tasks/audio/automaticSpeechRecognition.ts +4 -3
- package/src/tasks/audio/textToSpeech.ts +5 -29
- package/src/tasks/audio/utils.ts +2 -1
- package/src/tasks/custom/request.ts +0 -2
- package/src/tasks/custom/streamingRequest.ts +0 -2
- package/src/tasks/cv/imageClassification.ts +3 -7
- package/src/tasks/cv/imageSegmentation.ts +3 -8
- package/src/tasks/cv/imageToImage.ts +3 -6
- package/src/tasks/cv/imageToText.ts +3 -6
- package/src/tasks/cv/objectDetection.ts +3 -18
- package/src/tasks/cv/textToImage.ts +9 -137
- package/src/tasks/cv/textToVideo.ts +11 -62
- package/src/tasks/cv/zeroShotImageClassification.ts +3 -7
- package/src/tasks/index.ts +6 -6
- package/src/tasks/multimodal/documentQuestionAnswering.ts +3 -19
- package/src/tasks/multimodal/visualQuestionAnswering.ts +3 -11
- package/src/tasks/nlp/chatCompletion.ts +5 -20
- package/src/tasks/nlp/chatCompletionStream.ts +1 -2
- package/src/tasks/nlp/featureExtraction.ts +3 -18
- package/src/tasks/nlp/fillMask.ts +3 -16
- package/src/tasks/nlp/questionAnswering.ts +3 -22
- package/src/tasks/nlp/sentenceSimilarity.ts +3 -7
- package/src/tasks/nlp/summarization.ts +3 -6
- package/src/tasks/nlp/tableQuestionAnswering.ts +3 -27
- package/src/tasks/nlp/textClassification.ts +3 -8
- package/src/tasks/nlp/textGeneration.ts +12 -79
- package/src/tasks/nlp/tokenClassification.ts +3 -18
- package/src/tasks/nlp/translation.ts +3 -6
- package/src/tasks/nlp/zeroShotClassification.ts +3 -16
- package/src/tasks/tabular/tabularClassification.ts +3 -6
- package/src/tasks/tabular/tabularRegression.ts +3 -6
- package/src/types.ts +3 -14
|
@@ -1,15 +1,15 @@
|
|
|
1
|
-
import
|
|
2
|
-
import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingface/tasks/src/tasks/index.js";
|
|
1
|
+
import { Template } from "@huggingface/jinja";
|
|
3
2
|
import {
|
|
4
3
|
type InferenceSnippet,
|
|
5
4
|
type InferenceSnippetLanguage,
|
|
6
5
|
type ModelDataMinimal,
|
|
7
|
-
inferenceSnippetLanguages,
|
|
8
6
|
getModelInputSnippet,
|
|
7
|
+
inferenceSnippetLanguages,
|
|
9
8
|
} from "@huggingface/tasks";
|
|
10
|
-
import type {
|
|
11
|
-
import {
|
|
9
|
+
import type { PipelineType, WidgetType } from "@huggingface/tasks/src/pipelines.js";
|
|
10
|
+
import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingface/tasks/src/tasks/index.js";
|
|
12
11
|
import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions";
|
|
12
|
+
import type { InferenceProvider, InferenceTask, RequestArgs } from "../types";
|
|
13
13
|
import { templates } from "./templates.exported";
|
|
14
14
|
|
|
15
15
|
const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const;
|
|
@@ -120,6 +120,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
|
|
|
120
120
|
opts?: Record<string, unknown>
|
|
121
121
|
): InferenceSnippet[] => {
|
|
122
122
|
/// Hacky: hard-code conversational templates here
|
|
123
|
+
let task = model.pipeline_tag as InferenceTask;
|
|
123
124
|
if (
|
|
124
125
|
model.pipeline_tag &&
|
|
125
126
|
["text-generation", "image-text-to-text"].includes(model.pipeline_tag) &&
|
|
@@ -127,14 +128,20 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
|
|
|
127
128
|
) {
|
|
128
129
|
templateName = opts?.streaming ? "conversationalStream" : "conversational";
|
|
129
130
|
inputPreparationFn = prepareConversationalInput;
|
|
131
|
+
task = "conversational";
|
|
130
132
|
}
|
|
131
|
-
|
|
132
133
|
/// Prepare inputs + make request
|
|
133
134
|
const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) };
|
|
134
135
|
const request = makeRequestOptionsFromResolvedModel(
|
|
135
136
|
providerModelId ?? model.id,
|
|
136
|
-
{
|
|
137
|
-
|
|
137
|
+
{
|
|
138
|
+
accessToken: accessToken,
|
|
139
|
+
provider: provider,
|
|
140
|
+
...inputs,
|
|
141
|
+
} as RequestArgs,
|
|
142
|
+
{
|
|
143
|
+
task: task,
|
|
144
|
+
}
|
|
138
145
|
);
|
|
139
146
|
|
|
140
147
|
/// Parse request.info.body if not a binary.
|
|
@@ -247,7 +254,7 @@ const prepareConversationalInput = (
|
|
|
247
254
|
return {
|
|
248
255
|
messages: opts?.messages ?? getModelInputSnippet(model),
|
|
249
256
|
...(opts?.temperature ? { temperature: opts?.temperature } : undefined),
|
|
250
|
-
max_tokens: opts?.max_tokens ??
|
|
257
|
+
max_tokens: opts?.max_tokens ?? 512,
|
|
251
258
|
...(opts?.top_p ? { top_p: opts?.top_p } : undefined),
|
|
252
259
|
};
|
|
253
260
|
};
|
|
@@ -20,7 +20,7 @@ export const templates: Record<string, Record<string, Record<string, string>>> =
|
|
|
20
20
|
},
|
|
21
21
|
"openai": {
|
|
22
22
|
"conversational": "import { OpenAI } from \"openai\";\n\nconst client = new OpenAI({\n\tbaseURL: \"{{ baseUrl }}\",\n\tapiKey: \"{{ accessToken }}\",\n});\n\nconst chatCompletion = await client.chat.completions.create({\n\tmodel: \"{{ providerModelId }}\",\n{{ inputs.asTsString }}\n});\n\nconsole.log(chatCompletion.choices[0].message);",
|
|
23
|
-
"conversationalStream": "import { OpenAI } from \"openai\";\n\nconst client = new OpenAI({\n\tbaseURL: \"{{ baseUrl }}\",\n\tapiKey: \"{{ accessToken }}\",\n});\n\
|
|
23
|
+
"conversationalStream": "import { OpenAI } from \"openai\";\n\nconst client = new OpenAI({\n\tbaseURL: \"{{ baseUrl }}\",\n\tapiKey: \"{{ accessToken }}\",\n});\n\nconst stream = await client.chat.completions.create({\n model: \"{{ providerModelId }}\",\n{{ inputs.asTsString }}\n stream: true,\n});\n\nfor await (const chunk of stream) {\n process.stdout.write(chunk.choices[0]?.delta?.content || \"\");\n}"
|
|
24
24
|
}
|
|
25
25
|
},
|
|
26
26
|
"python": {
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import type { AudioClassificationInput, AudioClassificationOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
4
|
import { innerRequest } from "../../utils/request";
|
|
5
5
|
import type { LegacyAudioInput } from "./utils";
|
|
@@ -15,15 +15,12 @@ export async function audioClassification(
|
|
|
15
15
|
args: AudioClassificationArgs,
|
|
16
16
|
options?: Options
|
|
17
17
|
): Promise<AudioClassificationOutput> {
|
|
18
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
|
|
18
19
|
const payload = preparePayload(args);
|
|
19
20
|
const { data: res } = await innerRequest<AudioClassificationOutput>(payload, {
|
|
20
21
|
...options,
|
|
21
22
|
task: "audio-classification",
|
|
22
23
|
});
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
if (!isValidOutput) {
|
|
26
|
-
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
27
|
-
}
|
|
28
|
-
return res;
|
|
24
|
+
|
|
25
|
+
return providerHelper.getResponse(res);
|
|
29
26
|
}
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
2
2
|
import type { BaseArgs, Options } from "../../types";
|
|
3
3
|
import { innerRequest } from "../../utils/request";
|
|
4
4
|
import type { LegacyAudioInput } from "./utils";
|
|
@@ -36,34 +36,11 @@ export interface AudioToAudioOutput {
|
|
|
36
36
|
* Example model: speechbrain/sepformer-wham does audio source separation.
|
|
37
37
|
*/
|
|
38
38
|
export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
|
|
39
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
|
|
39
40
|
const payload = preparePayload(args);
|
|
40
41
|
const { data: res } = await innerRequest<AudioToAudioOutput>(payload, {
|
|
41
42
|
...options,
|
|
42
43
|
task: "audio-to-audio",
|
|
43
44
|
});
|
|
44
|
-
|
|
45
|
-
return validateOutput(res);
|
|
46
|
-
}
|
|
47
|
-
|
|
48
|
-
function validateOutput(output: unknown): AudioToAudioOutput[] {
|
|
49
|
-
if (!Array.isArray(output)) {
|
|
50
|
-
throw new InferenceOutputError("Expected Array");
|
|
51
|
-
}
|
|
52
|
-
if (
|
|
53
|
-
!output.every((elem): elem is AudioToAudioOutput => {
|
|
54
|
-
return (
|
|
55
|
-
typeof elem === "object" &&
|
|
56
|
-
elem &&
|
|
57
|
-
"label" in elem &&
|
|
58
|
-
typeof elem.label === "string" &&
|
|
59
|
-
"content-type" in elem &&
|
|
60
|
-
typeof elem["content-type"] === "string" &&
|
|
61
|
-
"blob" in elem &&
|
|
62
|
-
typeof elem.blob === "string"
|
|
63
|
-
);
|
|
64
|
-
})
|
|
65
|
-
) {
|
|
66
|
-
throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
|
|
67
|
-
}
|
|
68
|
-
return output;
|
|
45
|
+
return providerHelper.getResponse(res);
|
|
69
46
|
}
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
2
3
|
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
|
4
|
+
import { FAL_AI_SUPPORTED_BLOB_TYPES } from "../../providers/fal-ai";
|
|
3
5
|
import type { BaseArgs, Options, RequestArgs } from "../../types";
|
|
4
6
|
import { base64FromBytes } from "../../utils/base64FromBytes";
|
|
5
7
|
import { omit } from "../../utils/omit";
|
|
@@ -16,6 +18,7 @@ export async function automaticSpeechRecognition(
|
|
|
16
18
|
args: AutomaticSpeechRecognitionArgs,
|
|
17
19
|
options?: Options
|
|
18
20
|
): Promise<AutomaticSpeechRecognitionOutput> {
|
|
21
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
|
|
19
22
|
const payload = await buildPayload(args);
|
|
20
23
|
const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, {
|
|
21
24
|
...options,
|
|
@@ -25,11 +28,9 @@ export async function automaticSpeechRecognition(
|
|
|
25
28
|
if (!isValidOutput) {
|
|
26
29
|
throw new InferenceOutputError("Expected {text: string}");
|
|
27
30
|
}
|
|
28
|
-
return res;
|
|
31
|
+
return providerHelper.getResponse(res);
|
|
29
32
|
}
|
|
30
33
|
|
|
31
|
-
const FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
|
|
32
|
-
|
|
33
34
|
async function buildPayload(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs> {
|
|
34
35
|
if (args.provider === "fal-ai") {
|
|
35
36
|
const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : undefined;
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import type { TextToSpeechInput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
|
-
import { omit } from "../../utils/omit";
|
|
5
4
|
import { innerRequest } from "../../utils/request";
|
|
6
5
|
type TextToSpeechArgs = BaseArgs & TextToSpeechInput;
|
|
7
6
|
|
|
@@ -13,34 +12,11 @@ interface OutputUrlTextToSpeechGeneration {
|
|
|
13
12
|
* Recommended model: espnet/kan-bayashi_ljspeech_vits
|
|
14
13
|
*/
|
|
15
14
|
export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
|
|
16
|
-
|
|
17
|
-
const
|
|
18
|
-
|
|
19
|
-
? {
|
|
20
|
-
...omit(args, ["inputs", "parameters"]),
|
|
21
|
-
...args.parameters,
|
|
22
|
-
text: args.inputs,
|
|
23
|
-
}
|
|
24
|
-
: args;
|
|
25
|
-
const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(payload, {
|
|
15
|
+
const provider = args.provider ?? "hf-inference";
|
|
16
|
+
const providerHelper = getProviderHelper(provider, "text-to-speech");
|
|
17
|
+
const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, {
|
|
26
18
|
...options,
|
|
27
19
|
task: "text-to-speech",
|
|
28
20
|
});
|
|
29
|
-
|
|
30
|
-
return res;
|
|
31
|
-
}
|
|
32
|
-
if (res && typeof res === "object") {
|
|
33
|
-
if ("output" in res) {
|
|
34
|
-
if (typeof res.output === "string") {
|
|
35
|
-
const urlResponse = await fetch(res.output);
|
|
36
|
-
const blob = await urlResponse.blob();
|
|
37
|
-
return blob;
|
|
38
|
-
} else if (Array.isArray(res.output)) {
|
|
39
|
-
const urlResponse = await fetch(res.output[0]);
|
|
40
|
-
const blob = await urlResponse.blob();
|
|
41
|
-
return blob;
|
|
42
|
-
}
|
|
43
|
-
}
|
|
44
|
-
}
|
|
45
|
-
throw new InferenceOutputError("Expected Blob or object with output");
|
|
21
|
+
return providerHelper.getResponse(res);
|
|
46
22
|
}
|
package/src/tasks/audio/utils.ts
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import type { BaseArgs, RequestArgs } from "../../types";
|
|
1
|
+
import type { BaseArgs, InferenceProvider, RequestArgs } from "../../types";
|
|
2
2
|
import { omit } from "../../utils/omit";
|
|
3
3
|
|
|
4
4
|
/**
|
|
@@ -6,6 +6,7 @@ import { omit } from "../../utils/omit";
|
|
|
6
6
|
*/
|
|
7
7
|
export interface LegacyAudioInput {
|
|
8
8
|
data: Blob | ArrayBuffer;
|
|
9
|
+
provider?: InferenceProvider;
|
|
9
10
|
}
|
|
10
11
|
|
|
11
12
|
export function preparePayload(args: BaseArgs & ({ inputs: Blob } | LegacyAudioInput)): RequestArgs {
|
|
@@ -10,8 +10,6 @@ export async function request<T>(
|
|
|
10
10
|
options?: Options & {
|
|
11
11
|
/** In most cases (unless we pass a endpointUrl) we know the task */
|
|
12
12
|
task?: InferenceTask;
|
|
13
|
-
/** Is chat completion compatible */
|
|
14
|
-
chatCompletion?: boolean;
|
|
15
13
|
}
|
|
16
14
|
): Promise<T> {
|
|
17
15
|
console.warn(
|
|
@@ -9,8 +9,6 @@ export async function* streamingRequest<T>(
|
|
|
9
9
|
options?: Options & {
|
|
10
10
|
/** In most cases (unless we pass a endpointUrl) we know the task */
|
|
11
11
|
task?: InferenceTask;
|
|
12
|
-
/** Is chat completion compatible */
|
|
13
|
-
chatCompletion?: boolean;
|
|
14
12
|
}
|
|
15
13
|
): AsyncGenerator<T> {
|
|
16
14
|
console.warn(
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import type { ImageClassificationInput, ImageClassificationOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
4
|
import { innerRequest } from "../../utils/request";
|
|
5
5
|
import { preparePayload, type LegacyImageInput } from "./utils";
|
|
@@ -14,15 +14,11 @@ export async function imageClassification(
|
|
|
14
14
|
args: ImageClassificationArgs,
|
|
15
15
|
options?: Options
|
|
16
16
|
): Promise<ImageClassificationOutput> {
|
|
17
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
|
|
17
18
|
const payload = preparePayload(args);
|
|
18
19
|
const { data: res } = await innerRequest<ImageClassificationOutput>(payload, {
|
|
19
20
|
...options,
|
|
20
21
|
task: "image-classification",
|
|
21
22
|
});
|
|
22
|
-
|
|
23
|
-
Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
|
|
24
|
-
if (!isValidOutput) {
|
|
25
|
-
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
26
|
-
}
|
|
27
|
-
return res;
|
|
23
|
+
return providerHelper.getResponse(res);
|
|
28
24
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import type { ImageSegmentationInput, ImageSegmentationOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
4
|
import { innerRequest } from "../../utils/request";
|
|
5
5
|
import { preparePayload, type LegacyImageInput } from "./utils";
|
|
@@ -14,16 +14,11 @@ export async function imageSegmentation(
|
|
|
14
14
|
args: ImageSegmentationArgs,
|
|
15
15
|
options?: Options
|
|
16
16
|
): Promise<ImageSegmentationOutput> {
|
|
17
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
|
|
17
18
|
const payload = preparePayload(args);
|
|
18
19
|
const { data: res } = await innerRequest<ImageSegmentationOutput>(payload, {
|
|
19
20
|
...options,
|
|
20
21
|
task: "image-segmentation",
|
|
21
22
|
});
|
|
22
|
-
|
|
23
|
-
Array.isArray(res) &&
|
|
24
|
-
res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
|
|
25
|
-
if (!isValidOutput) {
|
|
26
|
-
throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
|
|
27
|
-
}
|
|
28
|
-
return res;
|
|
23
|
+
return providerHelper.getResponse(res);
|
|
29
24
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import type { ImageToImageInput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options, RequestArgs } from "../../types";
|
|
4
4
|
import { base64FromBytes } from "../../utils/base64FromBytes";
|
|
5
5
|
import { innerRequest } from "../../utils/request";
|
|
@@ -11,6 +11,7 @@ export type ImageToImageArgs = BaseArgs & ImageToImageInput;
|
|
|
11
11
|
* Recommended model: lllyasviel/sd-controlnet-depth
|
|
12
12
|
*/
|
|
13
13
|
export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
|
|
14
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
|
|
14
15
|
let reqArgs: RequestArgs;
|
|
15
16
|
if (!args.parameters) {
|
|
16
17
|
reqArgs = {
|
|
@@ -30,9 +31,5 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P
|
|
|
30
31
|
...options,
|
|
31
32
|
task: "image-to-image",
|
|
32
33
|
});
|
|
33
|
-
|
|
34
|
-
if (!isValidOutput) {
|
|
35
|
-
throw new InferenceOutputError("Expected Blob");
|
|
36
|
-
}
|
|
37
|
-
return res;
|
|
34
|
+
return providerHelper.getResponse(res);
|
|
38
35
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import type { ImageToTextInput, ImageToTextOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
4
|
import { innerRequest } from "../../utils/request";
|
|
5
5
|
import type { LegacyImageInput } from "./utils";
|
|
@@ -10,15 +10,12 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput);
|
|
|
10
10
|
* This task reads some image input and outputs the text caption.
|
|
11
11
|
*/
|
|
12
12
|
export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
|
|
13
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
|
|
13
14
|
const payload = preparePayload(args);
|
|
14
15
|
const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, {
|
|
15
16
|
...options,
|
|
16
17
|
task: "image-to-text",
|
|
17
18
|
});
|
|
18
19
|
|
|
19
|
-
|
|
20
|
-
throw new InferenceOutputError("Expected {generated_text: string}");
|
|
21
|
-
}
|
|
22
|
-
|
|
23
|
-
return res?.[0];
|
|
20
|
+
return providerHelper.getResponse(res[0]);
|
|
24
21
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import type { ObjectDetectionInput, ObjectDetectionOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
4
|
import { innerRequest } from "../../utils/request";
|
|
5
5
|
import { preparePayload, type LegacyImageInput } from "./utils";
|
|
@@ -11,26 +11,11 @@ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImage
|
|
|
11
11
|
* Recommended model: facebook/detr-resnet-50
|
|
12
12
|
*/
|
|
13
13
|
export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> {
|
|
14
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
|
|
14
15
|
const payload = preparePayload(args);
|
|
15
16
|
const { data: res } = await innerRequest<ObjectDetectionOutput>(payload, {
|
|
16
17
|
...options,
|
|
17
18
|
task: "object-detection",
|
|
18
19
|
});
|
|
19
|
-
|
|
20
|
-
Array.isArray(res) &&
|
|
21
|
-
res.every(
|
|
22
|
-
(x) =>
|
|
23
|
-
typeof x.label === "string" &&
|
|
24
|
-
typeof x.score === "number" &&
|
|
25
|
-
typeof x.box.xmin === "number" &&
|
|
26
|
-
typeof x.box.ymin === "number" &&
|
|
27
|
-
typeof x.box.xmax === "number" &&
|
|
28
|
-
typeof x.box.ymax === "number"
|
|
29
|
-
);
|
|
30
|
-
if (!isValidOutput) {
|
|
31
|
-
throw new InferenceOutputError(
|
|
32
|
-
"Expected Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>"
|
|
33
|
-
);
|
|
34
|
-
}
|
|
35
|
-
return res;
|
|
20
|
+
return providerHelper.getResponse(res);
|
|
36
21
|
}
|
|
@@ -1,48 +1,15 @@
|
|
|
1
|
-
import type { TextToImageInput
|
|
2
|
-
import {
|
|
3
|
-
import
|
|
4
|
-
import {
|
|
5
|
-
import { omit } from "../../utils/omit";
|
|
1
|
+
import type { TextToImageInput } from "@huggingface/tasks";
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
|
+
import { makeRequestOptions } from "../../lib/makeRequestOptions";
|
|
4
|
+
import type { BaseArgs, Options } from "../../types";
|
|
6
5
|
import { innerRequest } from "../../utils/request";
|
|
7
6
|
|
|
8
7
|
export type TextToImageArgs = BaseArgs & TextToImageInput;
|
|
9
8
|
|
|
10
|
-
interface Base64ImageGeneration {
|
|
11
|
-
data: Array<{
|
|
12
|
-
b64_json: string;
|
|
13
|
-
}>;
|
|
14
|
-
}
|
|
15
|
-
interface OutputUrlImageGeneration {
|
|
16
|
-
output: string[];
|
|
17
|
-
}
|
|
18
|
-
interface HyperbolicTextToImageOutput {
|
|
19
|
-
images: Array<{ image: string }>;
|
|
20
|
-
}
|
|
21
|
-
|
|
22
|
-
interface BlackForestLabsResponse {
|
|
23
|
-
id: string;
|
|
24
|
-
polling_url: string;
|
|
25
|
-
}
|
|
26
|
-
|
|
27
9
|
interface TextToImageOptions extends Options {
|
|
28
10
|
outputType?: "url" | "blob";
|
|
29
11
|
}
|
|
30
12
|
|
|
31
|
-
function getResponseFormatArg(provider: InferenceProvider) {
|
|
32
|
-
switch (provider) {
|
|
33
|
-
case "fal-ai":
|
|
34
|
-
return { sync_mode: true };
|
|
35
|
-
case "nebius":
|
|
36
|
-
return { response_format: "b64_json" };
|
|
37
|
-
case "replicate":
|
|
38
|
-
return undefined;
|
|
39
|
-
case "together":
|
|
40
|
-
return { response_format: "base64" };
|
|
41
|
-
default:
|
|
42
|
-
return undefined;
|
|
43
|
-
}
|
|
44
|
-
}
|
|
45
|
-
|
|
46
13
|
/**
|
|
47
14
|
* This task reads some text input and outputs an image.
|
|
48
15
|
* Recommended model: stabilityai/stable-diffusion-2
|
|
@@ -56,108 +23,13 @@ export async function textToImage(
|
|
|
56
23
|
options?: TextToImageOptions & { outputType?: undefined | "blob" }
|
|
57
24
|
): Promise<Blob>;
|
|
58
25
|
export async function textToImage(args: TextToImageArgs, options?: TextToImageOptions): Promise<Blob | string> {
|
|
59
|
-
const
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
: {
|
|
63
|
-
...omit(args, ["inputs", "parameters"]),
|
|
64
|
-
...args.parameters,
|
|
65
|
-
...getResponseFormatArg(args.provider),
|
|
66
|
-
prompt: args.inputs,
|
|
67
|
-
};
|
|
68
|
-
const { data: res } = await innerRequest<
|
|
69
|
-
| TextToImageOutput
|
|
70
|
-
| Base64ImageGeneration
|
|
71
|
-
| OutputUrlImageGeneration
|
|
72
|
-
| BlackForestLabsResponse
|
|
73
|
-
| HyperbolicTextToImageOutput
|
|
74
|
-
>(payload, {
|
|
26
|
+
const provider = args.provider ?? "hf-inference";
|
|
27
|
+
const providerHelper = getProviderHelper(provider, "text-to-image");
|
|
28
|
+
const { data: res } = await innerRequest<Record<string, unknown>>(args, {
|
|
75
29
|
...options,
|
|
76
30
|
task: "text-to-image",
|
|
77
31
|
});
|
|
78
32
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
return await pollBflResponse(res.polling_url, options?.outputType);
|
|
82
|
-
}
|
|
83
|
-
if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
|
|
84
|
-
if (options?.outputType === "url") {
|
|
85
|
-
return res.images[0].url;
|
|
86
|
-
} else {
|
|
87
|
-
const image = await fetch(res.images[0].url);
|
|
88
|
-
return await image.blob();
|
|
89
|
-
}
|
|
90
|
-
}
|
|
91
|
-
if (
|
|
92
|
-
args.provider === "hyperbolic" &&
|
|
93
|
-
"images" in res &&
|
|
94
|
-
Array.isArray(res.images) &&
|
|
95
|
-
res.images[0] &&
|
|
96
|
-
typeof res.images[0].image === "string"
|
|
97
|
-
) {
|
|
98
|
-
if (options?.outputType === "url") {
|
|
99
|
-
return `data:image/jpeg;base64,${res.images[0].image}`;
|
|
100
|
-
}
|
|
101
|
-
const base64Response = await fetch(`data:image/jpeg;base64,${res.images[0].image}`);
|
|
102
|
-
return await base64Response.blob();
|
|
103
|
-
}
|
|
104
|
-
if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
|
|
105
|
-
const base64Data = res.data[0].b64_json;
|
|
106
|
-
if (options?.outputType === "url") {
|
|
107
|
-
return `data:image/jpeg;base64,${base64Data}`;
|
|
108
|
-
}
|
|
109
|
-
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
|
|
110
|
-
return await base64Response.blob();
|
|
111
|
-
}
|
|
112
|
-
if ("output" in res && Array.isArray(res.output)) {
|
|
113
|
-
if (options?.outputType === "url") {
|
|
114
|
-
return res.output[0];
|
|
115
|
-
}
|
|
116
|
-
const urlResponse = await fetch(res.output[0]);
|
|
117
|
-
const blob = await urlResponse.blob();
|
|
118
|
-
return blob;
|
|
119
|
-
}
|
|
120
|
-
}
|
|
121
|
-
const isValidOutput = res && res instanceof Blob;
|
|
122
|
-
if (!isValidOutput) {
|
|
123
|
-
throw new InferenceOutputError("Expected Blob");
|
|
124
|
-
}
|
|
125
|
-
if (options?.outputType === "url") {
|
|
126
|
-
const b64 = await res.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
|
|
127
|
-
return `data:image/jpeg;base64,${b64}`;
|
|
128
|
-
}
|
|
129
|
-
return res;
|
|
130
|
-
}
|
|
131
|
-
|
|
132
|
-
async function pollBflResponse(url: string, outputType?: "url" | "blob"): Promise<Blob> {
|
|
133
|
-
const urlObj = new URL(url);
|
|
134
|
-
for (let step = 0; step < 5; step++) {
|
|
135
|
-
await delay(1000);
|
|
136
|
-
console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
|
|
137
|
-
urlObj.searchParams.set("attempt", step.toString(10));
|
|
138
|
-
const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
|
|
139
|
-
if (!resp.ok) {
|
|
140
|
-
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
|
|
141
|
-
}
|
|
142
|
-
const payload = await resp.json();
|
|
143
|
-
if (
|
|
144
|
-
typeof payload === "object" &&
|
|
145
|
-
payload &&
|
|
146
|
-
"status" in payload &&
|
|
147
|
-
typeof payload.status === "string" &&
|
|
148
|
-
payload.status === "Ready" &&
|
|
149
|
-
"result" in payload &&
|
|
150
|
-
typeof payload.result === "object" &&
|
|
151
|
-
payload.result &&
|
|
152
|
-
"sample" in payload.result &&
|
|
153
|
-
typeof payload.result.sample === "string"
|
|
154
|
-
) {
|
|
155
|
-
if (outputType === "url") {
|
|
156
|
-
return payload.result.sample;
|
|
157
|
-
}
|
|
158
|
-
const image = await fetch(payload.result.sample);
|
|
159
|
-
return await image.blob();
|
|
160
|
-
}
|
|
161
|
-
}
|
|
162
|
-
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
|
|
33
|
+
const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-image" });
|
|
34
|
+
return providerHelper.getResponse(res, url, info.headers as Record<string, string>, options?.outputType);
|
|
163
35
|
}
|
|
@@ -1,74 +1,23 @@
|
|
|
1
1
|
import type { TextToVideoInput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
3
|
-
import {
|
|
4
|
-
import {
|
|
5
|
-
import type {
|
|
6
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
|
+
import { makeRequestOptions } from "../../lib/makeRequestOptions";
|
|
4
|
+
import type { FalAiQueueOutput } from "../../providers/fal-ai";
|
|
5
|
+
import type { NovitaOutput } from "../../providers/novita";
|
|
6
|
+
import type { ReplicateOutput } from "../../providers/replicate";
|
|
7
|
+
import type { BaseArgs, Options } from "../../types";
|
|
7
8
|
import { innerRequest } from "../../utils/request";
|
|
8
|
-
import { typedInclude } from "../../utils/typedInclude";
|
|
9
9
|
|
|
10
10
|
export type TextToVideoArgs = BaseArgs & TextToVideoInput;
|
|
11
11
|
|
|
12
12
|
export type TextToVideoOutput = Blob;
|
|
13
13
|
|
|
14
|
-
interface ReplicateOutput {
|
|
15
|
-
output: string;
|
|
16
|
-
}
|
|
17
|
-
|
|
18
|
-
interface NovitaOutput {
|
|
19
|
-
video: {
|
|
20
|
-
video_url: string;
|
|
21
|
-
};
|
|
22
|
-
}
|
|
23
|
-
|
|
24
|
-
const SUPPORTED_PROVIDERS = ["fal-ai", "novita", "replicate"] as const satisfies readonly InferenceProvider[];
|
|
25
|
-
|
|
26
14
|
export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise<TextToVideoOutput> {
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
);
|
|
31
|
-
}
|
|
32
|
-
|
|
33
|
-
const payload =
|
|
34
|
-
args.provider === "fal-ai" || args.provider === "replicate" || args.provider === "novita"
|
|
35
|
-
? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs }
|
|
36
|
-
: args;
|
|
37
|
-
const { data, requestContext } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(payload, {
|
|
15
|
+
const provider = args.provider ?? "hf-inference";
|
|
16
|
+
const providerHelper = getProviderHelper(provider, "text-to-video");
|
|
17
|
+
const { data: response } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(args, {
|
|
38
18
|
...options,
|
|
39
19
|
task: "text-to-video",
|
|
40
20
|
});
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
return await pollFalResponse(
|
|
44
|
-
data as FalAiQueueOutput,
|
|
45
|
-
requestContext.url,
|
|
46
|
-
requestContext.info.headers as Record<string, string>
|
|
47
|
-
);
|
|
48
|
-
} else if (args.provider === "novita") {
|
|
49
|
-
const isValidOutput =
|
|
50
|
-
typeof data === "object" &&
|
|
51
|
-
!!data &&
|
|
52
|
-
"video" in data &&
|
|
53
|
-
typeof data.video === "object" &&
|
|
54
|
-
!!data.video &&
|
|
55
|
-
"video_url" in data.video &&
|
|
56
|
-
typeof data.video.video_url === "string" &&
|
|
57
|
-
isUrl(data.video.video_url);
|
|
58
|
-
if (!isValidOutput) {
|
|
59
|
-
throw new InferenceOutputError("Expected { video: { video_url: string } }");
|
|
60
|
-
}
|
|
61
|
-
const urlResponse = await fetch((data as NovitaOutput).video.video_url);
|
|
62
|
-
return await urlResponse.blob();
|
|
63
|
-
} else {
|
|
64
|
-
/// TODO: Replicate: handle the case where the generation request "times out" / is async (ie output is null)
|
|
65
|
-
/// https://replicate.com/docs/topics/predictions/create-a-prediction
|
|
66
|
-
const isValidOutput =
|
|
67
|
-
typeof data === "object" && !!data && "output" in data && typeof data.output === "string" && isUrl(data.output);
|
|
68
|
-
if (!isValidOutput) {
|
|
69
|
-
throw new InferenceOutputError("Expected { output: string }");
|
|
70
|
-
}
|
|
71
|
-
const urlResponse = await fetch(data.output);
|
|
72
|
-
return await urlResponse.blob();
|
|
73
|
-
}
|
|
21
|
+
const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-video" });
|
|
22
|
+
return providerHelper.getResponse(response, url, info.headers as Record<string, string>);
|
|
74
23
|
}
|