@huggingface/inference 3.7.0 → 3.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.cjs +1369 -941
- package/dist/index.js +1371 -943
- package/dist/src/lib/getInferenceProviderMapping.d.ts +21 -0
- package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -0
- package/dist/src/lib/getProviderHelper.d.ts +37 -0
- package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
- package/dist/src/lib/makeRequestOptions.d.ts +5 -5
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/black-forest-labs.d.ts +14 -18
- package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
- package/dist/src/providers/cerebras.d.ts +4 -2
- package/dist/src/providers/cerebras.d.ts.map +1 -1
- package/dist/src/providers/cohere.d.ts +5 -2
- package/dist/src/providers/cohere.d.ts.map +1 -1
- package/dist/src/providers/consts.d.ts +2 -3
- package/dist/src/providers/consts.d.ts.map +1 -1
- package/dist/src/providers/fal-ai.d.ts +50 -3
- package/dist/src/providers/fal-ai.d.ts.map +1 -1
- package/dist/src/providers/fireworks-ai.d.ts +5 -2
- package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
- package/dist/src/providers/hf-inference.d.ts +126 -2
- package/dist/src/providers/hf-inference.d.ts.map +1 -1
- package/dist/src/providers/hyperbolic.d.ts +31 -2
- package/dist/src/providers/hyperbolic.d.ts.map +1 -1
- package/dist/src/providers/nebius.d.ts +20 -18
- package/dist/src/providers/nebius.d.ts.map +1 -1
- package/dist/src/providers/novita.d.ts +21 -18
- package/dist/src/providers/novita.d.ts.map +1 -1
- package/dist/src/providers/openai.d.ts +4 -2
- package/dist/src/providers/openai.d.ts.map +1 -1
- package/dist/src/providers/providerHelper.d.ts +182 -0
- package/dist/src/providers/providerHelper.d.ts.map +1 -0
- package/dist/src/providers/replicate.d.ts +23 -19
- package/dist/src/providers/replicate.d.ts.map +1 -1
- package/dist/src/providers/sambanova.d.ts +4 -2
- package/dist/src/providers/sambanova.d.ts.map +1 -1
- package/dist/src/providers/together.d.ts +32 -2
- package/dist/src/providers/together.d.ts.map +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts +2 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/audio/utils.d.ts +2 -1
- package/dist/src/tasks/audio/utils.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts +0 -2
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts +0 -2
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/index.d.ts +6 -6
- package/dist/src/tasks/index.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
- package/dist/src/types.d.ts +5 -13
- package/dist/src/types.d.ts.map +1 -1
- package/dist/src/utils/request.d.ts +3 -2
- package/dist/src/utils/request.d.ts.map +1 -1
- package/package.json +3 -3
- package/src/lib/getInferenceProviderMapping.ts +96 -0
- package/src/lib/getProviderHelper.ts +270 -0
- package/src/lib/makeRequestOptions.ts +78 -97
- package/src/providers/black-forest-labs.ts +73 -22
- package/src/providers/cerebras.ts +6 -27
- package/src/providers/cohere.ts +9 -28
- package/src/providers/consts.ts +5 -2
- package/src/providers/fal-ai.ts +224 -77
- package/src/providers/fireworks-ai.ts +8 -29
- package/src/providers/hf-inference.ts +557 -34
- package/src/providers/hyperbolic.ts +107 -29
- package/src/providers/nebius.ts +65 -29
- package/src/providers/novita.ts +68 -32
- package/src/providers/openai.ts +6 -32
- package/src/providers/providerHelper.ts +354 -0
- package/src/providers/replicate.ts +124 -34
- package/src/providers/sambanova.ts +5 -30
- package/src/providers/together.ts +92 -28
- package/src/snippets/getInferenceSnippets.ts +39 -14
- package/src/snippets/templates.exported.ts +25 -25
- package/src/tasks/audio/audioClassification.ts +5 -8
- package/src/tasks/audio/audioToAudio.ts +4 -27
- package/src/tasks/audio/automaticSpeechRecognition.ts +5 -4
- package/src/tasks/audio/textToSpeech.ts +5 -29
- package/src/tasks/audio/utils.ts +2 -1
- package/src/tasks/custom/request.ts +3 -3
- package/src/tasks/custom/streamingRequest.ts +4 -3
- package/src/tasks/cv/imageClassification.ts +4 -8
- package/src/tasks/cv/imageSegmentation.ts +4 -9
- package/src/tasks/cv/imageToImage.ts +4 -7
- package/src/tasks/cv/imageToText.ts +4 -7
- package/src/tasks/cv/objectDetection.ts +4 -19
- package/src/tasks/cv/textToImage.ts +9 -137
- package/src/tasks/cv/textToVideo.ts +17 -64
- package/src/tasks/cv/zeroShotImageClassification.ts +4 -8
- package/src/tasks/index.ts +6 -6
- package/src/tasks/multimodal/documentQuestionAnswering.ts +4 -19
- package/src/tasks/multimodal/visualQuestionAnswering.ts +4 -12
- package/src/tasks/nlp/chatCompletion.ts +5 -20
- package/src/tasks/nlp/chatCompletionStream.ts +4 -3
- package/src/tasks/nlp/featureExtraction.ts +4 -19
- package/src/tasks/nlp/fillMask.ts +4 -17
- package/src/tasks/nlp/questionAnswering.ts +11 -26
- package/src/tasks/nlp/sentenceSimilarity.ts +4 -8
- package/src/tasks/nlp/summarization.ts +4 -7
- package/src/tasks/nlp/tableQuestionAnswering.ts +10 -30
- package/src/tasks/nlp/textClassification.ts +4 -9
- package/src/tasks/nlp/textGeneration.ts +11 -79
- package/src/tasks/nlp/textGenerationStream.ts +3 -1
- package/src/tasks/nlp/tokenClassification.ts +11 -23
- package/src/tasks/nlp/translation.ts +4 -7
- package/src/tasks/nlp/zeroShotClassification.ts +11 -21
- package/src/tasks/tabular/tabularClassification.ts +4 -7
- package/src/tasks/tabular/tabularRegression.ts +4 -7
- package/src/types.ts +5 -14
- package/src/utils/request.ts +7 -4
- package/dist/src/lib/getProviderModelId.d.ts +0 -10
- package/dist/src/lib/getProviderModelId.d.ts.map +0 -1
- package/src/lib/getProviderModelId.ts +0 -74
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import type { TableQuestionAnsweringInput, TableQuestionAnsweringOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
4
|
import { innerRequest } from "../../utils/request";
|
|
5
5
|
|
|
@@ -12,34 +12,14 @@ export async function tableQuestionAnswering(
|
|
|
12
12
|
args: TableQuestionAnsweringArgs,
|
|
13
13
|
options?: Options
|
|
14
14
|
): Promise<TableQuestionAnsweringOutput[number]> {
|
|
15
|
-
const
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
);
|
|
24
|
-
}
|
|
25
|
-
return Array.isArray(res) ? res[0] : res;
|
|
26
|
-
}
|
|
27
|
-
|
|
28
|
-
function validate(elem: unknown): elem is TableQuestionAnsweringOutput[number] {
|
|
29
|
-
return (
|
|
30
|
-
typeof elem === "object" &&
|
|
31
|
-
!!elem &&
|
|
32
|
-
"aggregator" in elem &&
|
|
33
|
-
typeof elem.aggregator === "string" &&
|
|
34
|
-
"answer" in elem &&
|
|
35
|
-
typeof elem.answer === "string" &&
|
|
36
|
-
"cells" in elem &&
|
|
37
|
-
Array.isArray(elem.cells) &&
|
|
38
|
-
elem.cells.every((x: unknown): x is string => typeof x === "string") &&
|
|
39
|
-
"coordinates" in elem &&
|
|
40
|
-
Array.isArray(elem.coordinates) &&
|
|
41
|
-
elem.coordinates.every(
|
|
42
|
-
(coord: unknown): coord is number[] => Array.isArray(coord) && coord.every((x) => typeof x === "number")
|
|
43
|
-
)
|
|
15
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
|
|
16
|
+
const { data: res } = await innerRequest<TableQuestionAnsweringOutput | TableQuestionAnsweringOutput[number]>(
|
|
17
|
+
args,
|
|
18
|
+
providerHelper,
|
|
19
|
+
{
|
|
20
|
+
...options,
|
|
21
|
+
task: "table-question-answering",
|
|
22
|
+
}
|
|
44
23
|
);
|
|
24
|
+
return providerHelper.getResponse(res);
|
|
45
25
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import type { TextClassificationInput, TextClassificationOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
4
|
import { innerRequest } from "../../utils/request";
|
|
5
5
|
|
|
@@ -12,15 +12,10 @@ export async function textClassification(
|
|
|
12
12
|
args: TextClassificationArgs,
|
|
13
13
|
options?: Options
|
|
14
14
|
): Promise<TextClassificationOutput> {
|
|
15
|
-
const
|
|
15
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
|
|
16
|
+
const { data: res } = await innerRequest<TextClassificationOutput>(args, providerHelper, {
|
|
16
17
|
...options,
|
|
17
18
|
task: "text-classification",
|
|
18
19
|
});
|
|
19
|
-
|
|
20
|
-
const isValidOutput =
|
|
21
|
-
Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number");
|
|
22
|
-
if (!isValidOutput) {
|
|
23
|
-
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
24
|
-
}
|
|
25
|
-
return output;
|
|
20
|
+
return providerHelper.getResponse(res);
|
|
26
21
|
}
|
|
@@ -1,33 +1,11 @@
|
|
|
1
|
-
import type {
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
TextGenerationOutput,
|
|
5
|
-
TextGenerationOutputFinishReason,
|
|
6
|
-
} from "@huggingface/tasks";
|
|
7
|
-
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
|
1
|
+
import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks";
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
|
+
import type { HyperbolicTextCompletionOutput } from "../../providers/hyperbolic";
|
|
8
4
|
import type { BaseArgs, Options } from "../../types";
|
|
9
|
-
import { omit } from "../../utils/omit";
|
|
10
5
|
import { innerRequest } from "../../utils/request";
|
|
11
|
-
import { toArray } from "../../utils/toArray";
|
|
12
6
|
|
|
13
7
|
export type { TextGenerationInput, TextGenerationOutput };
|
|
14
8
|
|
|
15
|
-
interface TogeteherTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
|
|
16
|
-
choices: Array<{
|
|
17
|
-
text: string;
|
|
18
|
-
finish_reason: TextGenerationOutputFinishReason;
|
|
19
|
-
seed: number;
|
|
20
|
-
logprobs: unknown;
|
|
21
|
-
index: number;
|
|
22
|
-
}>;
|
|
23
|
-
}
|
|
24
|
-
|
|
25
|
-
interface HyperbolicTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
|
|
26
|
-
choices: Array<{
|
|
27
|
-
message: { content: string };
|
|
28
|
-
}>;
|
|
29
|
-
}
|
|
30
|
-
|
|
31
9
|
/**
|
|
32
10
|
* Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
|
|
33
11
|
*/
|
|
@@ -35,58 +13,12 @@ export async function textGeneration(
|
|
|
35
13
|
args: BaseArgs & TextGenerationInput,
|
|
36
14
|
options?: Options
|
|
37
15
|
): Promise<TextGenerationOutput> {
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
if (!isValidOutput) {
|
|
47
|
-
throw new InferenceOutputError("Expected ChatCompletionOutput");
|
|
48
|
-
}
|
|
49
|
-
const completion = raw.choices[0];
|
|
50
|
-
return {
|
|
51
|
-
generated_text: completion.text,
|
|
52
|
-
};
|
|
53
|
-
} else if (args.provider === "hyperbolic") {
|
|
54
|
-
const payload = {
|
|
55
|
-
messages: [{ content: args.inputs, role: "user" }],
|
|
56
|
-
...(args.parameters
|
|
57
|
-
? {
|
|
58
|
-
max_tokens: args.parameters.max_new_tokens,
|
|
59
|
-
...omit(args.parameters, "max_new_tokens"),
|
|
60
|
-
}
|
|
61
|
-
: undefined),
|
|
62
|
-
...omit(args, ["inputs", "parameters"]),
|
|
63
|
-
};
|
|
64
|
-
const raw = (
|
|
65
|
-
await innerRequest<HyperbolicTextCompletionOutput>(payload, {
|
|
66
|
-
...options,
|
|
67
|
-
task: "text-generation",
|
|
68
|
-
})
|
|
69
|
-
).data;
|
|
70
|
-
const isValidOutput =
|
|
71
|
-
typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
|
|
72
|
-
if (!isValidOutput) {
|
|
73
|
-
throw new InferenceOutputError("Expected ChatCompletionOutput");
|
|
74
|
-
}
|
|
75
|
-
const completion = raw.choices[0];
|
|
76
|
-
return {
|
|
77
|
-
generated_text: completion.message.content,
|
|
78
|
-
};
|
|
79
|
-
} else {
|
|
80
|
-
const { data: res } = await innerRequest<TextGenerationOutput | TextGenerationOutput[]>(args, {
|
|
81
|
-
...options,
|
|
82
|
-
task: "text-generation",
|
|
83
|
-
});
|
|
84
|
-
const output = toArray(res);
|
|
85
|
-
const isValidOutput =
|
|
86
|
-
Array.isArray(output) && output.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
|
|
87
|
-
if (!isValidOutput) {
|
|
88
|
-
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
|
|
89
|
-
}
|
|
90
|
-
return (output as TextGenerationOutput[])?.[0];
|
|
91
|
-
}
|
|
16
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
|
|
17
|
+
const { data: response } = await innerRequest<
|
|
18
|
+
HyperbolicTextCompletionOutput | TextGenerationOutput | TextGenerationOutput[]
|
|
19
|
+
>(args, providerHelper, {
|
|
20
|
+
...options,
|
|
21
|
+
task: "text-generation",
|
|
22
|
+
});
|
|
23
|
+
return providerHelper.getResponse(response);
|
|
92
24
|
}
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import type { TextGenerationInput } from "@huggingface/tasks";
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
2
3
|
import type { BaseArgs, Options } from "../../types";
|
|
3
4
|
import { innerStreamingRequest } from "../../utils/request";
|
|
4
5
|
|
|
@@ -89,7 +90,8 @@ export async function* textGenerationStream(
|
|
|
89
90
|
args: BaseArgs & TextGenerationInput,
|
|
90
91
|
options?: Options
|
|
91
92
|
): AsyncGenerator<TextGenerationStreamOutput> {
|
|
92
|
-
|
|
93
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
|
|
94
|
+
yield* innerStreamingRequest<TextGenerationStreamOutput>(args, providerHelper, {
|
|
93
95
|
...options,
|
|
94
96
|
task: "text-generation",
|
|
95
97
|
});
|
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
import type { TokenClassificationInput, TokenClassificationOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
4
|
import { innerRequest } from "../../utils/request";
|
|
5
|
-
import { toArray } from "../../utils/toArray";
|
|
6
5
|
|
|
7
6
|
export type TokenClassificationArgs = BaseArgs & TokenClassificationInput;
|
|
8
7
|
|
|
@@ -13,25 +12,14 @@ export async function tokenClassification(
|
|
|
13
12
|
args: TokenClassificationArgs,
|
|
14
13
|
options?: Options
|
|
15
14
|
): Promise<TokenClassificationOutput> {
|
|
16
|
-
const
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
typeof x.entity_group === "string" &&
|
|
27
|
-
typeof x.score === "number" &&
|
|
28
|
-
typeof x.start === "number" &&
|
|
29
|
-
typeof x.word === "string"
|
|
30
|
-
);
|
|
31
|
-
if (!isValidOutput) {
|
|
32
|
-
throw new InferenceOutputError(
|
|
33
|
-
"Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
|
|
34
|
-
);
|
|
35
|
-
}
|
|
36
|
-
return output;
|
|
15
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "token-classification");
|
|
16
|
+
const { data: res } = await innerRequest<TokenClassificationOutput[number] | TokenClassificationOutput>(
|
|
17
|
+
args,
|
|
18
|
+
providerHelper,
|
|
19
|
+
{
|
|
20
|
+
...options,
|
|
21
|
+
task: "token-classification",
|
|
22
|
+
}
|
|
23
|
+
);
|
|
24
|
+
return providerHelper.getResponse(res);
|
|
37
25
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import type { TranslationInput, TranslationOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
4
|
import { innerRequest } from "../../utils/request";
|
|
5
5
|
|
|
@@ -8,13 +8,10 @@ export type TranslationArgs = BaseArgs & TranslationInput;
|
|
|
8
8
|
* This task is well known to translate text from one language to another. Recommended model: Helsinki-NLP/opus-mt-ru-en.
|
|
9
9
|
*/
|
|
10
10
|
export async function translation(args: TranslationArgs, options?: Options): Promise<TranslationOutput> {
|
|
11
|
-
const
|
|
11
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation");
|
|
12
|
+
const { data: res } = await innerRequest<TranslationOutput>(args, providerHelper, {
|
|
12
13
|
...options,
|
|
13
14
|
task: "translation",
|
|
14
15
|
});
|
|
15
|
-
|
|
16
|
-
if (!isValidOutput) {
|
|
17
|
-
throw new InferenceOutputError("Expected type Array<{translation_text: string}>");
|
|
18
|
-
}
|
|
19
|
-
return res?.length === 1 ? res?.[0] : res;
|
|
16
|
+
return providerHelper.getResponse(res);
|
|
20
17
|
}
|
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
import type { ZeroShotClassificationInput, ZeroShotClassificationOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
4
|
import { innerRequest } from "../../utils/request";
|
|
5
|
-
import { toArray } from "../../utils/toArray";
|
|
6
5
|
|
|
7
6
|
export type ZeroShotClassificationArgs = BaseArgs & ZeroShotClassificationInput;
|
|
8
7
|
|
|
@@ -13,23 +12,14 @@ export async function zeroShotClassification(
|
|
|
13
12
|
args: ZeroShotClassificationArgs,
|
|
14
13
|
options?: Options
|
|
15
14
|
): Promise<ZeroShotClassificationOutput> {
|
|
16
|
-
const
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
x.labels.every((_label) => typeof _label === "string") &&
|
|
27
|
-
Array.isArray(x.scores) &&
|
|
28
|
-
x.scores.every((_score) => typeof _score === "number") &&
|
|
29
|
-
typeof x.sequence === "string"
|
|
30
|
-
);
|
|
31
|
-
if (!isValidOutput) {
|
|
32
|
-
throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
|
|
33
|
-
}
|
|
34
|
-
return output;
|
|
15
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification");
|
|
16
|
+
const { data: res } = await innerRequest<ZeroShotClassificationOutput[number] | ZeroShotClassificationOutput>(
|
|
17
|
+
args,
|
|
18
|
+
providerHelper,
|
|
19
|
+
{
|
|
20
|
+
...options,
|
|
21
|
+
task: "zero-shot-classification",
|
|
22
|
+
}
|
|
23
|
+
);
|
|
24
|
+
return providerHelper.getResponse(res);
|
|
35
25
|
}
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
2
2
|
import type { BaseArgs, Options } from "../../types";
|
|
3
3
|
import { innerRequest } from "../../utils/request";
|
|
4
4
|
|
|
@@ -25,13 +25,10 @@ export async function tabularClassification(
|
|
|
25
25
|
args: TabularClassificationArgs,
|
|
26
26
|
options?: Options
|
|
27
27
|
): Promise<TabularClassificationOutput> {
|
|
28
|
-
const
|
|
28
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification");
|
|
29
|
+
const { data: res } = await innerRequest<TabularClassificationOutput>(args, providerHelper, {
|
|
29
30
|
...options,
|
|
30
31
|
task: "tabular-classification",
|
|
31
32
|
});
|
|
32
|
-
|
|
33
|
-
if (!isValidOutput) {
|
|
34
|
-
throw new InferenceOutputError("Expected number[]");
|
|
35
|
-
}
|
|
36
|
-
return res;
|
|
33
|
+
return providerHelper.getResponse(res);
|
|
37
34
|
}
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
2
2
|
import type { BaseArgs, Options } from "../../types";
|
|
3
3
|
import { innerRequest } from "../../utils/request";
|
|
4
4
|
|
|
@@ -25,13 +25,10 @@ export async function tabularRegression(
|
|
|
25
25
|
args: TabularRegressionArgs,
|
|
26
26
|
options?: Options
|
|
27
27
|
): Promise<TabularRegressionOutput> {
|
|
28
|
-
const
|
|
28
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression");
|
|
29
|
+
const { data: res } = await innerRequest<TabularRegressionOutput>(args, providerHelper, {
|
|
29
30
|
...options,
|
|
30
31
|
task: "tabular-regression",
|
|
31
32
|
});
|
|
32
|
-
|
|
33
|
-
if (!isValidOutput) {
|
|
34
|
-
throw new InferenceOutputError("Expected number[]");
|
|
35
|
-
}
|
|
36
|
-
return res;
|
|
33
|
+
return providerHelper.getResponse(res);
|
|
37
34
|
}
|
package/src/types.ts
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import type { ChatCompletionInput, PipelineType } from "@huggingface/tasks";
|
|
2
|
+
import type { InferenceProviderModelMapping } from "./lib/getInferenceProviderMapping";
|
|
2
3
|
|
|
3
4
|
/**
|
|
4
5
|
* HF model id, like "meta-llama/Llama-3.3-70B-Instruct"
|
|
@@ -34,7 +35,7 @@ export interface Options {
|
|
|
34
35
|
billTo?: string;
|
|
35
36
|
}
|
|
36
37
|
|
|
37
|
-
export type InferenceTask = Exclude<PipelineType, "other"
|
|
38
|
+
export type InferenceTask = Exclude<PipelineType, "other"> | "conversational";
|
|
38
39
|
|
|
39
40
|
export const INFERENCE_PROVIDERS = [
|
|
40
41
|
"black-forest-labs",
|
|
@@ -101,14 +102,6 @@ export type RequestArgs = BaseArgs &
|
|
|
101
102
|
parameters?: Record<string, unknown>;
|
|
102
103
|
};
|
|
103
104
|
|
|
104
|
-
export interface ProviderConfig {
|
|
105
|
-
makeBaseUrl: ((task?: InferenceTask) => string) | (() => string);
|
|
106
|
-
makeBody: (params: BodyParams) => Record<string, unknown>;
|
|
107
|
-
makeHeaders: (params: HeaderParams) => Record<string, string>;
|
|
108
|
-
makeUrl: (params: UrlParams) => string;
|
|
109
|
-
clientSideRoutingOnly?: boolean;
|
|
110
|
-
}
|
|
111
|
-
|
|
112
105
|
export type AuthMethod = "none" | "hf-token" | "credentials-include" | "provider-key";
|
|
113
106
|
|
|
114
107
|
export interface HeaderParams {
|
|
@@ -118,15 +111,13 @@ export interface HeaderParams {
|
|
|
118
111
|
|
|
119
112
|
export interface UrlParams {
|
|
120
113
|
authMethod: AuthMethod;
|
|
121
|
-
baseUrl: string;
|
|
122
114
|
model: string;
|
|
123
115
|
task?: InferenceTask;
|
|
124
|
-
chatCompletion?: boolean;
|
|
125
116
|
}
|
|
126
117
|
|
|
127
|
-
export interface BodyParams {
|
|
128
|
-
args:
|
|
129
|
-
chatCompletion?: boolean;
|
|
118
|
+
export interface BodyParams<T extends Record<string, unknown> = Record<string, unknown>> {
|
|
119
|
+
args: T;
|
|
130
120
|
model: string;
|
|
121
|
+
mapping?: InferenceProviderModelMapping | undefined;
|
|
131
122
|
task?: InferenceTask;
|
|
132
123
|
}
|
package/src/utils/request.ts
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import type { getProviderHelper } from "../lib/getProviderHelper";
|
|
1
2
|
import { makeRequestOptions } from "../lib/makeRequestOptions";
|
|
2
3
|
import type { InferenceTask, Options, RequestArgs } from "../types";
|
|
3
4
|
import type { EventSourceMessage } from "../vendor/fetch-event-source/parse";
|
|
@@ -16,6 +17,7 @@ export interface ResponseWrapper<T> {
|
|
|
16
17
|
*/
|
|
17
18
|
export async function innerRequest<T>(
|
|
18
19
|
args: RequestArgs,
|
|
20
|
+
providerHelper: ReturnType<typeof getProviderHelper>,
|
|
19
21
|
options?: Options & {
|
|
20
22
|
/** In most cases (unless we pass a endpointUrl) we know the task */
|
|
21
23
|
task?: InferenceTask;
|
|
@@ -23,13 +25,13 @@ export async function innerRequest<T>(
|
|
|
23
25
|
chatCompletion?: boolean;
|
|
24
26
|
}
|
|
25
27
|
): Promise<ResponseWrapper<T>> {
|
|
26
|
-
const { url, info } = await makeRequestOptions(args, options);
|
|
28
|
+
const { url, info } = await makeRequestOptions(args, providerHelper, options);
|
|
27
29
|
const response = await (options?.fetch ?? fetch)(url, info);
|
|
28
30
|
|
|
29
31
|
const requestContext: ResponseWrapper<T>["requestContext"] = { url, info };
|
|
30
32
|
|
|
31
33
|
if (options?.retry_on_error !== false && response.status === 503) {
|
|
32
|
-
return innerRequest(args, options);
|
|
34
|
+
return innerRequest(args, providerHelper, options);
|
|
33
35
|
}
|
|
34
36
|
|
|
35
37
|
if (!response.ok) {
|
|
@@ -65,6 +67,7 @@ export async function innerRequest<T>(
|
|
|
65
67
|
*/
|
|
66
68
|
export async function* innerStreamingRequest<T>(
|
|
67
69
|
args: RequestArgs,
|
|
70
|
+
providerHelper: ReturnType<typeof getProviderHelper>,
|
|
68
71
|
options?: Options & {
|
|
69
72
|
/** In most cases (unless we pass a endpointUrl) we know the task */
|
|
70
73
|
task?: InferenceTask;
|
|
@@ -72,11 +75,11 @@ export async function* innerStreamingRequest<T>(
|
|
|
72
75
|
chatCompletion?: boolean;
|
|
73
76
|
}
|
|
74
77
|
): AsyncGenerator<T> {
|
|
75
|
-
const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
|
|
78
|
+
const { url, info } = await makeRequestOptions({ ...args, stream: true }, providerHelper, options);
|
|
76
79
|
const response = await (options?.fetch ?? fetch)(url, info);
|
|
77
80
|
|
|
78
81
|
if (options?.retry_on_error !== false && response.status === 503) {
|
|
79
|
-
return yield* innerStreamingRequest(args, options);
|
|
82
|
+
return yield* innerStreamingRequest(args, providerHelper, options);
|
|
80
83
|
}
|
|
81
84
|
if (!response.ok) {
|
|
82
85
|
if (response.headers.get("Content-Type")?.startsWith("application/json")) {
|
|
@@ -1,10 +0,0 @@
|
|
|
1
|
-
import type { InferenceProvider, InferenceTask, Options, RequestArgs } from "../types";
|
|
2
|
-
export declare function getProviderModelId(params: {
|
|
3
|
-
model: string;
|
|
4
|
-
provider: InferenceProvider;
|
|
5
|
-
}, args: RequestArgs, options?: {
|
|
6
|
-
task?: InferenceTask;
|
|
7
|
-
chatCompletion?: boolean;
|
|
8
|
-
fetch?: Options["fetch"];
|
|
9
|
-
}): Promise<string>;
|
|
10
|
-
//# sourceMappingURL=getProviderModelId.d.ts.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"getProviderModelId.d.ts","sourceRoot":"","sources":["../../../src/lib/getProviderModelId.ts"],"names":[],"mappings":"AACA,OAAO,KAAK,EAAE,iBAAiB,EAAE,aAAa,EAAW,OAAO,EAAE,WAAW,EAAE,MAAM,UAAU,CAAC;AAShG,wBAAsB,kBAAkB,CACvC,MAAM,EAAE;IACP,KAAK,EAAE,MAAM,CAAC;IACd,QAAQ,EAAE,iBAAiB,CAAC;CAC5B,EACD,IAAI,EAAE,WAAW,EACjB,OAAO,GAAE;IACR,IAAI,CAAC,EAAE,aAAa,CAAC;IACrB,cAAc,CAAC,EAAE,OAAO,CAAC;IACzB,KAAK,CAAC,EAAE,OAAO,CAAC,OAAO,CAAC,CAAC;CACpB,GACJ,OAAO,CAAC,MAAM,CAAC,CAoDjB"}
|
|
@@ -1,74 +0,0 @@
|
|
|
1
|
-
import type { WidgetType } from "@huggingface/tasks";
|
|
2
|
-
import type { InferenceProvider, InferenceTask, ModelId, Options, RequestArgs } from "../types";
|
|
3
|
-
import { HF_HUB_URL } from "../config";
|
|
4
|
-
import { HARDCODED_MODEL_ID_MAPPING } from "../providers/consts";
|
|
5
|
-
|
|
6
|
-
type InferenceProviderMapping = Partial<
|
|
7
|
-
Record<InferenceProvider, { providerId: string; status: "live" | "staging"; task: WidgetType }>
|
|
8
|
-
>;
|
|
9
|
-
const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
|
|
10
|
-
|
|
11
|
-
export async function getProviderModelId(
|
|
12
|
-
params: {
|
|
13
|
-
model: string;
|
|
14
|
-
provider: InferenceProvider;
|
|
15
|
-
},
|
|
16
|
-
args: RequestArgs,
|
|
17
|
-
options: {
|
|
18
|
-
task?: InferenceTask;
|
|
19
|
-
chatCompletion?: boolean;
|
|
20
|
-
fetch?: Options["fetch"];
|
|
21
|
-
} = {}
|
|
22
|
-
): Promise<string> {
|
|
23
|
-
if (params.provider === "hf-inference") {
|
|
24
|
-
return params.model;
|
|
25
|
-
}
|
|
26
|
-
if (!options.task) {
|
|
27
|
-
throw new Error("task must be specified when using a third-party provider");
|
|
28
|
-
}
|
|
29
|
-
const task: WidgetType =
|
|
30
|
-
options.task === "text-generation" && options.chatCompletion ? "conversational" : options.task;
|
|
31
|
-
|
|
32
|
-
// A dict called HARDCODED_MODEL_ID_MAPPING takes precedence in all cases (useful for dev purposes)
|
|
33
|
-
if (HARDCODED_MODEL_ID_MAPPING[params.provider]?.[params.model]) {
|
|
34
|
-
return HARDCODED_MODEL_ID_MAPPING[params.provider][params.model];
|
|
35
|
-
}
|
|
36
|
-
|
|
37
|
-
let inferenceProviderMapping: InferenceProviderMapping | null;
|
|
38
|
-
if (inferenceProviderMappingCache.has(params.model)) {
|
|
39
|
-
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
|
40
|
-
inferenceProviderMapping = inferenceProviderMappingCache.get(params.model)!;
|
|
41
|
-
} else {
|
|
42
|
-
inferenceProviderMapping = await (options?.fetch ?? fetch)(
|
|
43
|
-
`${HF_HUB_URL}/api/models/${params.model}?expand[]=inferenceProviderMapping`,
|
|
44
|
-
{
|
|
45
|
-
headers: args.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${args.accessToken}` } : {},
|
|
46
|
-
}
|
|
47
|
-
)
|
|
48
|
-
.then((resp) => resp.json())
|
|
49
|
-
.then((json) => json.inferenceProviderMapping)
|
|
50
|
-
.catch(() => null);
|
|
51
|
-
}
|
|
52
|
-
|
|
53
|
-
if (!inferenceProviderMapping) {
|
|
54
|
-
throw new Error(`We have not been able to find inference provider information for model ${params.model}.`);
|
|
55
|
-
}
|
|
56
|
-
|
|
57
|
-
const providerMapping = inferenceProviderMapping[params.provider];
|
|
58
|
-
if (providerMapping) {
|
|
59
|
-
if (providerMapping.task !== task) {
|
|
60
|
-
throw new Error(
|
|
61
|
-
`Model ${params.model} is not supported for task ${task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
|
|
62
|
-
);
|
|
63
|
-
}
|
|
64
|
-
if (providerMapping.status === "staging") {
|
|
65
|
-
console.warn(
|
|
66
|
-
`Model ${params.model} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
|
|
67
|
-
);
|
|
68
|
-
}
|
|
69
|
-
// TODO: how is it handled server-side if model has multiple tasks (e.g. `text-generation` + `conversational`)?
|
|
70
|
-
return providerMapping.providerId;
|
|
71
|
-
}
|
|
72
|
-
|
|
73
|
-
throw new Error(`Model ${params.model} is not supported provider ${params.provider}.`);
|
|
74
|
-
}
|