@huggingface/inference 3.10.0 → 3.12.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.cjs +713 -643
- package/dist/index.js +712 -643
- package/dist/src/InferenceClient.d.ts +16 -17
- package/dist/src/InferenceClient.d.ts.map +1 -1
- package/dist/src/lib/getInferenceProviderMapping.d.ts +5 -1
- package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -1
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/providerHelper.d.ts +1 -1
- package/dist/src/providers/providerHelper.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioToAudio.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
- package/dist/src/types.d.ts +6 -4
- package/dist/src/types.d.ts.map +1 -1
- package/dist/src/utils/typedEntries.d.ts +4 -0
- package/dist/src/utils/typedEntries.d.ts.map +1 -0
- package/package.json +3 -3
- package/src/InferenceClient.ts +32 -43
- package/src/lib/getInferenceProviderMapping.ts +68 -19
- package/src/lib/makeRequestOptions.ts +4 -3
- package/src/providers/hf-inference.ts +1 -1
- package/src/providers/providerHelper.ts +1 -1
- package/src/snippets/getInferenceSnippets.ts +1 -1
- package/src/tasks/audio/audioClassification.ts +3 -1
- package/src/tasks/audio/audioToAudio.ts +4 -1
- package/src/tasks/audio/automaticSpeechRecognition.ts +3 -1
- package/src/tasks/audio/textToSpeech.ts +2 -1
- package/src/tasks/custom/request.ts +3 -1
- package/src/tasks/custom/streamingRequest.ts +3 -1
- package/src/tasks/cv/imageClassification.ts +3 -1
- package/src/tasks/cv/imageSegmentation.ts +3 -1
- package/src/tasks/cv/imageToImage.ts +3 -1
- package/src/tasks/cv/imageToText.ts +3 -1
- package/src/tasks/cv/objectDetection.ts +3 -1
- package/src/tasks/cv/textToImage.ts +2 -1
- package/src/tasks/cv/textToVideo.ts +2 -1
- package/src/tasks/cv/zeroShotImageClassification.ts +3 -1
- package/src/tasks/multimodal/documentQuestionAnswering.ts +3 -1
- package/src/tasks/multimodal/visualQuestionAnswering.ts +3 -1
- package/src/tasks/nlp/chatCompletion.ts +3 -1
- package/src/tasks/nlp/chatCompletionStream.ts +3 -1
- package/src/tasks/nlp/featureExtraction.ts +3 -1
- package/src/tasks/nlp/fillMask.ts +3 -1
- package/src/tasks/nlp/questionAnswering.ts +4 -1
- package/src/tasks/nlp/sentenceSimilarity.ts +3 -1
- package/src/tasks/nlp/summarization.ts +3 -1
- package/src/tasks/nlp/tableQuestionAnswering.ts +3 -1
- package/src/tasks/nlp/textClassification.ts +3 -1
- package/src/tasks/nlp/textGeneration.ts +3 -1
- package/src/tasks/nlp/textGenerationStream.ts +3 -1
- package/src/tasks/nlp/tokenClassification.ts +3 -1
- package/src/tasks/nlp/translation.ts +3 -1
- package/src/tasks/nlp/zeroShotClassification.ts +3 -1
- package/src/tasks/tabular/tabularClassification.ts +3 -1
- package/src/tasks/tabular/tabularRegression.ts +3 -1
- package/src/types.ts +8 -4
- package/src/utils/typedEntries.ts +5 -0
package/src/lib/getInferenceProviderMapping.ts

@@ -1,8 +1,8 @@
 import type { WidgetType } from "@huggingface/tasks";
-import type { InferenceProvider, ModelId } from "../types";
 import { HF_HUB_URL } from "../config";
 import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts";
 import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference";
+import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../types";
 import { typedInclude } from "../utils/typedInclude";
 
 export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();

@@ -20,44 +20,62 @@ export interface InferenceProviderModelMapping
 	task: WidgetType;
 }
 
-export async function
-
-
-
-		provider: InferenceProvider;
-		task: WidgetType;
-	},
-	options: {
+export async function fetchInferenceProviderMappingForModel(
+	modelId: ModelId,
+	accessToken?: string,
+	options?: {
 		fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
 	}
-): Promise<
-	if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
-		return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
-	}
+): Promise<InferenceProviderMapping> {
 	let inferenceProviderMapping: InferenceProviderMapping | null;
-	if (inferenceProviderMappingCache.has(
+	if (inferenceProviderMappingCache.has(modelId)) {
 		// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
-		inferenceProviderMapping = inferenceProviderMappingCache.get(
+		inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
 	} else {
 		const resp = await (options?.fetch ?? fetch)(
-			`${HF_HUB_URL}/api/models/${
+			`${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
 			{
-				headers:
+				headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
 			}
 		);
 		if (resp.status === 404) {
-			throw new Error(`Model ${
+			throw new Error(`Model ${modelId} does not exist`);
 		}
 		inferenceProviderMapping = await resp
 			.json()
 			.then((json) => json.inferenceProviderMapping)
 			.catch(() => null);
+
+		if (inferenceProviderMapping) {
+			inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
+		}
 	}
 
 	if (!inferenceProviderMapping) {
-		throw new Error(`We have not been able to find inference provider information for model ${
+		throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
 	}
+	return inferenceProviderMapping;
+}
 
+export async function getInferenceProviderMapping(
+	params: {
+		accessToken?: string;
+		modelId: ModelId;
+		provider: InferenceProvider;
+		task: WidgetType;
+	},
+	options: {
+		fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
+	}
+): Promise<InferenceProviderModelMapping | null> {
+	if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
+		return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
+	}
+	const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
+		params.modelId,
+		params.accessToken,
+		options
+	);
 	const providerMapping = inferenceProviderMapping[params.provider];
 	if (providerMapping) {
 		const equivalentTasks =
@@ -78,3 +96,34 @@ export async function getInferenceProviderMapping(
 	}
 	return null;
 }
+
+export async function resolveProvider(
+	provider?: InferenceProviderOrPolicy,
+	modelId?: string,
+	endpointUrl?: string
+): Promise<InferenceProvider> {
+	if (endpointUrl) {
+		if (provider) {
+			throw new Error("Specifying both endpointUrl and provider is not supported.");
+		}
+		/// Defaulting to hf-inference helpers / API
+		return "hf-inference";
+	}
+	if (!provider) {
+		console.log(
+			"Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
+		);
+		provider = "auto";
+	}
+	if (provider === "auto") {
+		if (!modelId) {
+			throw new Error("Specifying a model is required when provider is 'auto'");
+		}
+		const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
+		provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider | undefined;
+	}
+	if (!provider) {
+		throw new Error(`No Inference Provider available for model ${modelId}.`);
+	}
+	return provider;
+}
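The two new helpers above move provider selection to call time: fetchInferenceProviderMappingForModel asks the Hub which providers serve a given model (and caches the answer per model id), and resolveProvider turns the user-supplied provider, the new "auto" policy, or an endpointUrl into a concrete provider. For reference, the Hub endpoint involved can be queried directly; a minimal sketch with plain fetch (the model id is only an example, and the response shape in the comment is abbreviated):

// Sketch of the request the new helper makes (the result is cached per model id).
const modelId = "meta-llama/Llama-3.1-8B-Instruct"; // example model
const resp = await fetch(
	`https://huggingface.co/api/models/${modelId}?expand[]=inferenceProviderMapping`
);
const { inferenceProviderMapping } = await resp.json();
// With provider "auto", the client picks the first key of this object,
// ordered by the user's preference on hf.co/settings/inference-providers.
console.log(Object.keys(inferenceProviderMapping));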
package/src/lib/makeRequestOptions.ts

@@ -27,8 +27,8 @@ export async function makeRequestOptions(
 		task?: InferenceTask;
 	}
 ): Promise<{ url: string; info: RequestInit }> {
-	const {
-	const provider =
+	const { model: maybeModel } = args;
+	const provider = providerHelper.provider;
 	const { task } = options ?? {};
 
 	// Validate inputs

@@ -113,8 +113,9 @@ export function makeRequestOptionsFromResolvedModel(
 ): { url: string; info: RequestInit } {
 	const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
 	void model;
+	void maybeProvider;
 
-	const provider =
+	const provider = providerHelper.provider;
 
 	const { includeCredentials, task, signal, billTo } = options ?? {};
 	const authMethod = (() => {
package/src/providers/hf-inference.ts

@@ -106,7 +106,7 @@ export class HFInferenceTask extends TaskProviderHelper
 	makeRoute(params: UrlParams): string {
 		if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
 			// when deployed on hf-inference, those two tasks are automatically compatible with one another.
-			return `
+			return `models/${params.model}/pipeline/${params.task}`;
 		}
 		return `models/${params.model}`;
 	}
package/src/providers/providerHelper.ts

@@ -56,7 +56,7 @@ import { toArray } from "../utils/toArray";
 */
 export abstract class TaskProviderHelper {
 	constructor(
-
+		readonly provider: InferenceProvider,
 		private baseUrl: string,
 		readonly clientSideRoutingOnly: boolean = false
 	) {}
package/src/snippets/getInferenceSnippets.ts

@@ -272,7 +272,7 @@ const prepareConversationalInput = (
 	return {
 		messages: opts?.messages ?? getModelInputSnippet(model),
 		...(opts?.temperature ? { temperature: opts?.temperature } : undefined),
-		max_tokens: opts?.max_tokens
+		...(opts?.max_tokens ? { max_tokens: opts?.max_tokens } : undefined),
 		...(opts?.top_p ? { top_p: opts?.top_p } : undefined),
 	};
 };
package/src/tasks/audio/audioClassification.ts

@@ -1,4 +1,5 @@
 import type { AudioClassificationInput, AudioClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";

@@ -15,7 +16,8 @@ export async function audioClassification(
 	args: AudioClassificationArgs,
 	options?: Options
 ): Promise<AudioClassificationOutput> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "audio-classification");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<AudioClassificationOutput>(payload, providerHelper, {
 		...options,
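Each task wrapper now follows the same pattern: resolve the provider first (from args.provider, args.model and args.endpointUrl), then fetch the provider helper for that task. Callers do not have to change anything, but they can now pass the "auto" policy instead of a fixed provider. A hedged sketch using the exported audioClassification function (the file path, model id and the legacy `{ data }` input form are assumptions for illustration):

import { readFile } from "node:fs/promises";
import { audioClassification } from "@huggingface/inference";

// Placeholder audio file; any Blob works for the legacy `data` input.
const data = new Blob([await readFile("sample.flac")]);

const labels = await audioClassification({
	accessToken: process.env.HF_TOKEN,
	model: "superb/hubert-large-superb-er", // example model
	provider: "auto", // resolved at call time via the Hub's provider mapping
	data,
});
console.log(labels); // [{ label, score }, ...]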
package/src/tasks/audio/audioToAudio.ts

@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";

@@ -36,7 +37,9 @@ export interface AudioToAudioOutput
 * Example model: speechbrain/sepformer-wham does audio source separation.
 */
 export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
-	const
+	const model = "inputs" in args ? args.model : undefined;
+	const provider = await resolveProvider(args.provider, model);
+	const providerHelper = getProviderHelper(provider, "audio-to-audio");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<AudioToAudioOutput>(payload, providerHelper, {
 		...options,
package/src/tasks/audio/automaticSpeechRecognition.ts

@@ -1,4 +1,5 @@
 import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import { FAL_AI_SUPPORTED_BLOB_TYPES } from "../../providers/fal-ai";

@@ -18,7 +19,8 @@ export async function automaticSpeechRecognition(
 	args: AutomaticSpeechRecognitionArgs,
 	options?: Options
 ): Promise<AutomaticSpeechRecognitionOutput> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
 	const payload = await buildPayload(args);
 	const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, providerHelper, {
 		...options,
package/src/tasks/audio/textToSpeech.ts

@@ -1,4 +1,5 @@
 import type { TextToSpeechInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";

@@ -12,7 +13,7 @@ interface OutputUrlTextToSpeechGeneration {
 * Recommended model: espnet/kan-bayashi_ljspeech_vits
 */
 export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
-	const provider = args.provider
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
 	const providerHelper = getProviderHelper(provider, "text-to-speech");
 	const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, providerHelper, {
 		...options,
package/src/tasks/custom/request.ts

@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { InferenceTask, Options, RequestArgs } from "../../types";
 import { innerRequest } from "../../utils/request";

@@ -16,7 +17,8 @@ export async function request<T>(
 	console.warn(
 		"The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
 	);
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, options?.task);
 	const result = await innerRequest<T>(args, providerHelper, options);
 	return result.data;
 }
package/src/tasks/custom/streamingRequest.ts

@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { InferenceTask, Options, RequestArgs } from "../../types";
 import { innerStreamingRequest } from "../../utils/request";

@@ -16,6 +17,7 @@ export async function* streamingRequest<T>(
 	console.warn(
 		"The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
 	);
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, options?.task);
 	yield* innerStreamingRequest(args, providerHelper, options);
 }
package/src/tasks/cv/imageClassification.ts

@@ -1,4 +1,5 @@
 import type { ImageClassificationInput, ImageClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";

@@ -14,7 +15,8 @@ export async function imageClassification(
 	args: ImageClassificationArgs,
 	options?: Options
 ): Promise<ImageClassificationOutput> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "image-classification");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<ImageClassificationOutput>(payload, providerHelper, {
 		...options,
package/src/tasks/cv/imageSegmentation.ts

@@ -1,4 +1,5 @@
 import type { ImageSegmentationInput, ImageSegmentationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";

@@ -14,7 +15,8 @@ export async function imageSegmentation(
 	args: ImageSegmentationArgs,
 	options?: Options
 ): Promise<ImageSegmentationOutput> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "image-segmentation");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<ImageSegmentationOutput>(payload, providerHelper, {
 		...options,
package/src/tasks/cv/imageToImage.ts

@@ -1,4 +1,5 @@
 import type { ImageToImageInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";

@@ -11,7 +12,8 @@ export type ImageToImageArgs = BaseArgs & ImageToImageInput;
 * Recommended model: lllyasviel/sd-controlnet-depth
 */
 export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "image-to-image");
 	let reqArgs: RequestArgs;
 	if (!args.parameters) {
 		reqArgs = {
package/src/tasks/cv/imageToText.ts

@@ -1,4 +1,5 @@
 import type { ImageToTextInput, ImageToTextOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";

@@ -10,7 +11,8 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput);
 * This task reads some image input and outputs the text caption.
 */
 export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "image-to-text");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, providerHelper, {
 		...options,
package/src/tasks/cv/objectDetection.ts

@@ -1,4 +1,5 @@
 import type { ObjectDetectionInput, ObjectDetectionOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";

@@ -11,7 +12,8 @@ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImage
 * Recommended model: facebook/detr-resnet-50
 */
 export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "object-detection");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<ObjectDetectionOutput>(payload, providerHelper, {
 		...options,
package/src/tasks/cv/textToImage.ts

@@ -1,4 +1,5 @@
 import type { TextToImageInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import { makeRequestOptions } from "../../lib/makeRequestOptions";
 import type { BaseArgs, Options } from "../../types";

@@ -23,7 +24,7 @@ export async function textToImage(
 	options?: TextToImageOptions & { outputType?: undefined | "blob" }
 ): Promise<Blob>;
 export async function textToImage(args: TextToImageArgs, options?: TextToImageOptions): Promise<Blob | string> {
-	const provider = args.provider
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
 	const providerHelper = getProviderHelper(provider, "text-to-image");
 	const { data: res } = await innerRequest<Record<string, unknown>>(args, providerHelper, {
 		...options,
package/src/tasks/cv/textToVideo.ts

@@ -1,4 +1,5 @@
 import type { TextToVideoInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import { makeRequestOptions } from "../../lib/makeRequestOptions";
 import type { FalAiQueueOutput } from "../../providers/fal-ai";

@@ -12,7 +13,7 @@ export type TextToVideoArgs = BaseArgs & TextToVideoInput;
 export type TextToVideoOutput = Blob;
 
 export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise<TextToVideoOutput> {
-	const provider = args.provider
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
 	const providerHelper = getProviderHelper(provider, "text-to-video");
 	const { data: response } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(
 		args,
package/src/tasks/cv/zeroShotImageClassification.ts

@@ -1,4 +1,5 @@
 import type { ZeroShotImageClassificationInput, ZeroShotImageClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";

@@ -44,7 +45,8 @@ export async function zeroShotImageClassification(
 	args: ZeroShotImageClassificationArgs,
 	options?: Options
 ): Promise<ZeroShotImageClassificationOutput> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "zero-shot-image-classification");
 	const payload = await preparePayload(args);
 	const { data: res } = await innerRequest<ZeroShotImageClassificationOutput>(payload, providerHelper, {
 		...options,
package/src/tasks/multimodal/documentQuestionAnswering.ts

@@ -3,6 +3,7 @@ import type {
 	DocumentQuestionAnsweringInputData,
 	DocumentQuestionAnsweringOutput,
 } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";

@@ -19,7 +20,8 @@ export async function documentQuestionAnswering(
 	args: DocumentQuestionAnsweringArgs,
 	options?: Options
 ): Promise<DocumentQuestionAnsweringOutput[number]> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "document-question-answering");
 	const reqArgs: RequestArgs = {
 		...args,
 		inputs: {
package/src/tasks/multimodal/visualQuestionAnswering.ts

@@ -3,6 +3,7 @@ import type {
 	VisualQuestionAnsweringInputData,
 	VisualQuestionAnsweringOutput,
 } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";

@@ -19,7 +20,8 @@ export async function visualQuestionAnswering(
 	args: VisualQuestionAnsweringArgs,
 	options?: Options
 ): Promise<VisualQuestionAnsweringOutput[number]> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "visual-question-answering");
 	const reqArgs: RequestArgs = {
 		...args,
 		inputs: {
package/src/tasks/nlp/chatCompletion.ts

@@ -1,4 +1,5 @@
 import type { ChatCompletionInput, ChatCompletionOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";

@@ -10,7 +11,8 @@ export async function chatCompletion(
 	args: BaseArgs & ChatCompletionInput,
 	options?: Options
 ): Promise<ChatCompletionOutput> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "conversational");
 	const { data: response } = await innerRequest<ChatCompletionOutput>(args, providerHelper, {
 		...options,
 		task: "conversational",
package/src/tasks/nlp/chatCompletionStream.ts

@@ -1,4 +1,5 @@
 import type { ChatCompletionInput, ChatCompletionStreamOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerStreamingRequest } from "../../utils/request";

@@ -10,7 +11,8 @@ export async function* chatCompletionStream(
 	args: BaseArgs & ChatCompletionInput,
 	options?: Options
 ): AsyncGenerator<ChatCompletionStreamOutput> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "conversational");
 	yield* innerStreamingRequest<ChatCompletionStreamOutput>(args, providerHelper, {
 		...options,
 		task: "conversational",
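The conversational tasks follow the same change; for streaming, the provider is resolved once before the stream is opened. A short sketch (the model id is an example; leaving `provider` out triggers the logged fall-back to "auto"):

import { chatCompletionStream } from "@huggingface/inference";

// provider intentionally omitted: resolveProvider() logs that it defaults to
// "auto" and uses the first provider listed for the model.
for await (const chunk of chatCompletionStream({
	accessToken: process.env.HF_TOKEN,
	model: "meta-llama/Llama-3.1-8B-Instruct", // example model
	messages: [{ role: "user", content: "Write a haiku about diffs." }],
})) {
	process.stdout.write(chunk.choices[0]?.delta?.content ?? "");
}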
package/src/tasks/nlp/featureExtraction.ts

@@ -1,4 +1,5 @@
 import type { FeatureExtractionInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";

@@ -22,7 +23,8 @@ export async function featureExtraction(
 	args: FeatureExtractionArgs,
 	options?: Options
 ): Promise<FeatureExtractionOutput> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "feature-extraction");
 	const { data: res } = await innerRequest<FeatureExtractionOutput>(args, providerHelper, {
 		...options,
 		task: "feature-extraction",
package/src/tasks/nlp/fillMask.ts

@@ -1,4 +1,5 @@
 import type { FillMaskInput, FillMaskOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";

@@ -9,7 +10,8 @@ export type FillMaskArgs = BaseArgs & FillMaskInput;
 * Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
 */
 export async function fillMask(args: FillMaskArgs, options?: Options): Promise<FillMaskOutput> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "fill-mask");
 	const { data: res } = await innerRequest<FillMaskOutput>(args, providerHelper, {
 		...options,
 		task: "fill-mask",
package/src/tasks/nlp/questionAnswering.ts

@@ -1,4 +1,6 @@
 import type { QuestionAnsweringInput, QuestionAnsweringOutput } from "@huggingface/tasks";
+
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";

@@ -12,7 +14,8 @@ export async function questionAnswering(
 	args: QuestionAnsweringArgs,
 	options?: Options
 ): Promise<QuestionAnsweringOutput[number]> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "question-answering");
 	const { data: res } = await innerRequest<QuestionAnsweringOutput | QuestionAnsweringOutput[number]>(
 		args,
 		providerHelper,
package/src/tasks/nlp/sentenceSimilarity.ts

@@ -1,4 +1,5 @@
 import type { SentenceSimilarityInput, SentenceSimilarityOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";

@@ -12,7 +13,8 @@ export async function sentenceSimilarity(
 	args: SentenceSimilarityArgs,
 	options?: Options
 ): Promise<SentenceSimilarityOutput> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "sentence-similarity");
 	const { data: res } = await innerRequest<SentenceSimilarityOutput>(args, providerHelper, {
 		...options,
 		task: "sentence-similarity",
package/src/tasks/nlp/summarization.ts

@@ -1,4 +1,5 @@
 import type { SummarizationInput, SummarizationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";

@@ -9,7 +10,8 @@ export type SummarizationArgs = BaseArgs & SummarizationInput;
 * This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.
 */
 export async function summarization(args: SummarizationArgs, options?: Options): Promise<SummarizationOutput> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "summarization");
 	const { data: res } = await innerRequest<SummarizationOutput[]>(args, providerHelper, {
 		...options,
 		task: "summarization",
package/src/tasks/nlp/tableQuestionAnswering.ts

@@ -1,4 +1,5 @@
 import type { TableQuestionAnsweringInput, TableQuestionAnsweringOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";

@@ -12,7 +13,8 @@ export async function tableQuestionAnswering(
 	args: TableQuestionAnsweringArgs,
 	options?: Options
 ): Promise<TableQuestionAnsweringOutput[number]> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "table-question-answering");
 	const { data: res } = await innerRequest<TableQuestionAnsweringOutput | TableQuestionAnsweringOutput[number]>(
 		args,
 		providerHelper,
package/src/tasks/nlp/textClassification.ts

@@ -1,4 +1,5 @@
 import type { TextClassificationInput, TextClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";

@@ -12,7 +13,8 @@ export async function textClassification(
 	args: TextClassificationArgs,
 	options?: Options
 ): Promise<TextClassificationOutput> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "text-classification");
 	const { data: res } = await innerRequest<TextClassificationOutput>(args, providerHelper, {
 		...options,
 		task: "text-classification",
package/src/tasks/nlp/textGeneration.ts

@@ -1,4 +1,5 @@
 import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { HyperbolicTextCompletionOutput } from "../../providers/hyperbolic";
 import type { BaseArgs, Options } from "../../types";

@@ -13,7 +14,8 @@ export async function textGeneration(
 	args: BaseArgs & TextGenerationInput,
 	options?: Options
 ): Promise<TextGenerationOutput> {
-	const
+	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+	const providerHelper = getProviderHelper(provider, "text-generation");
 	const { data: response } = await innerRequest<
 		HyperbolicTextCompletionOutput | TextGenerationOutput | TextGenerationOutput[]
 	>(args, providerHelper, {