@huggingface/inference 3.6.2 → 3.7.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +0 -25
- package/dist/index.cjs +1232 -898
- package/dist/index.js +1234 -900
- package/dist/src/config.d.ts +1 -0
- package/dist/src/config.d.ts.map +1 -1
- package/dist/src/lib/getProviderHelper.d.ts +37 -0
- package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
- package/dist/src/lib/makeRequestOptions.d.ts +0 -2
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/black-forest-labs.d.ts +14 -18
- package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
- package/dist/src/providers/cerebras.d.ts +4 -2
- package/dist/src/providers/cerebras.d.ts.map +1 -1
- package/dist/src/providers/cohere.d.ts +5 -2
- package/dist/src/providers/cohere.d.ts.map +1 -1
- package/dist/src/providers/fal-ai.d.ts +50 -3
- package/dist/src/providers/fal-ai.d.ts.map +1 -1
- package/dist/src/providers/fireworks-ai.d.ts +5 -2
- package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
- package/dist/src/providers/hf-inference.d.ts +125 -2
- package/dist/src/providers/hf-inference.d.ts.map +1 -1
- package/dist/src/providers/hyperbolic.d.ts +31 -2
- package/dist/src/providers/hyperbolic.d.ts.map +1 -1
- package/dist/src/providers/nebius.d.ts +20 -18
- package/dist/src/providers/nebius.d.ts.map +1 -1
- package/dist/src/providers/novita.d.ts +21 -18
- package/dist/src/providers/novita.d.ts.map +1 -1
- package/dist/src/providers/openai.d.ts +4 -2
- package/dist/src/providers/openai.d.ts.map +1 -1
- package/dist/src/providers/providerHelper.d.ts +182 -0
- package/dist/src/providers/providerHelper.d.ts.map +1 -0
- package/dist/src/providers/replicate.d.ts +23 -19
- package/dist/src/providers/replicate.d.ts.map +1 -1
- package/dist/src/providers/sambanova.d.ts +4 -2
- package/dist/src/providers/sambanova.d.ts.map +1 -1
- package/dist/src/providers/together.d.ts +32 -2
- package/dist/src/providers/together.d.ts.map +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/audio/utils.d.ts +2 -1
- package/dist/src/tasks/audio/utils.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts +1 -2
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts +1 -2
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/index.d.ts +6 -6
- package/dist/src/tasks/index.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
- package/dist/src/types.d.ts +10 -13
- package/dist/src/types.d.ts.map +1 -1
- package/dist/src/utils/request.d.ts +27 -0
- package/dist/src/utils/request.d.ts.map +1 -0
- package/package.json +3 -3
- package/src/config.ts +1 -0
- package/src/lib/getProviderHelper.ts +270 -0
- package/src/lib/makeRequestOptions.ts +36 -90
- package/src/providers/black-forest-labs.ts +73 -22
- package/src/providers/cerebras.ts +6 -27
- package/src/providers/cohere.ts +9 -28
- package/src/providers/fal-ai.ts +195 -77
- package/src/providers/fireworks-ai.ts +8 -29
- package/src/providers/hf-inference.ts +555 -34
- package/src/providers/hyperbolic.ts +107 -29
- package/src/providers/nebius.ts +65 -29
- package/src/providers/novita.ts +68 -32
- package/src/providers/openai.ts +6 -32
- package/src/providers/providerHelper.ts +354 -0
- package/src/providers/replicate.ts +124 -34
- package/src/providers/sambanova.ts +5 -30
- package/src/providers/together.ts +92 -28
- package/src/snippets/getInferenceSnippets.ts +16 -9
- package/src/snippets/templates.exported.ts +2 -2
- package/src/tasks/audio/audioClassification.ts +6 -9
- package/src/tasks/audio/audioToAudio.ts +5 -28
- package/src/tasks/audio/automaticSpeechRecognition.ts +7 -6
- package/src/tasks/audio/textToSpeech.ts +6 -30
- package/src/tasks/audio/utils.ts +2 -1
- package/src/tasks/custom/request.ts +7 -34
- package/src/tasks/custom/streamingRequest.ts +5 -87
- package/src/tasks/cv/imageClassification.ts +5 -9
- package/src/tasks/cv/imageSegmentation.ts +5 -10
- package/src/tasks/cv/imageToImage.ts +5 -8
- package/src/tasks/cv/imageToText.ts +8 -13
- package/src/tasks/cv/objectDetection.ts +6 -21
- package/src/tasks/cv/textToImage.ts +10 -138
- package/src/tasks/cv/textToVideo.ts +11 -59
- package/src/tasks/cv/zeroShotImageClassification.ts +7 -12
- package/src/tasks/index.ts +6 -6
- package/src/tasks/multimodal/documentQuestionAnswering.ts +10 -26
- package/src/tasks/multimodal/visualQuestionAnswering.ts +6 -12
- package/src/tasks/nlp/chatCompletion.ts +7 -23
- package/src/tasks/nlp/chatCompletionStream.ts +4 -5
- package/src/tasks/nlp/featureExtraction.ts +5 -20
- package/src/tasks/nlp/fillMask.ts +5 -18
- package/src/tasks/nlp/questionAnswering.ts +5 -23
- package/src/tasks/nlp/sentenceSimilarity.ts +5 -18
- package/src/tasks/nlp/summarization.ts +5 -8
- package/src/tasks/nlp/tableQuestionAnswering.ts +5 -29
- package/src/tasks/nlp/textClassification.ts +8 -14
- package/src/tasks/nlp/textGeneration.ts +13 -80
- package/src/tasks/nlp/textGenerationStream.ts +2 -2
- package/src/tasks/nlp/tokenClassification.ts +8 -24
- package/src/tasks/nlp/translation.ts +5 -8
- package/src/tasks/nlp/zeroShotClassification.ts +8 -22
- package/src/tasks/tabular/tabularClassification.ts +5 -8
- package/src/tasks/tabular/tabularRegression.ts +5 -8
- package/src/types.ts +11 -14
- package/src/utils/request.ts +161 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import type { SummarizationInput, SummarizationOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
|
-
import {
|
|
4
|
+
import { innerRequest } from "../../utils/request";
|
|
5
5
|
|
|
6
6
|
export type SummarizationArgs = BaseArgs & SummarizationInput;
|
|
7
7
|
|
|
@@ -9,13 +9,10 @@ export type SummarizationArgs = BaseArgs & SummarizationInput;
|
|
|
9
9
|
* This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.
|
|
10
10
|
*/
|
|
11
11
|
export async function summarization(args: SummarizationArgs, options?: Options): Promise<SummarizationOutput> {
|
|
12
|
-
const
|
|
12
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization");
|
|
13
|
+
const { data: res } = await innerRequest<SummarizationOutput[]>(args, {
|
|
13
14
|
...options,
|
|
14
15
|
task: "summarization",
|
|
15
16
|
});
|
|
16
|
-
|
|
17
|
-
if (!isValidOutput) {
|
|
18
|
-
throw new InferenceOutputError("Expected Array<{summary_text: string}>");
|
|
19
|
-
}
|
|
20
|
-
return res?.[0];
|
|
17
|
+
return providerHelper.getResponse(res);
|
|
21
18
|
}
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import type { TableQuestionAnsweringInput, TableQuestionAnsweringOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
|
-
import {
|
|
4
|
+
import { innerRequest } from "../../utils/request";
|
|
5
5
|
|
|
6
6
|
export type TableQuestionAnsweringArgs = BaseArgs & TableQuestionAnsweringInput;
|
|
7
7
|
|
|
@@ -12,34 +12,10 @@ export async function tableQuestionAnswering(
|
|
|
12
12
|
args: TableQuestionAnsweringArgs,
|
|
13
13
|
options?: Options
|
|
14
14
|
): Promise<TableQuestionAnsweringOutput[number]> {
|
|
15
|
-
const
|
|
15
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
|
|
16
|
+
const { data: res } = await innerRequest<TableQuestionAnsweringOutput | TableQuestionAnsweringOutput[number]>(args, {
|
|
16
17
|
...options,
|
|
17
18
|
task: "table-question-answering",
|
|
18
19
|
});
|
|
19
|
-
|
|
20
|
-
if (!isValidOutput) {
|
|
21
|
-
throw new InferenceOutputError(
|
|
22
|
-
"Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
|
|
23
|
-
);
|
|
24
|
-
}
|
|
25
|
-
return Array.isArray(res) ? res[0] : res;
|
|
26
|
-
}
|
|
27
|
-
|
|
28
|
-
function validate(elem: unknown): elem is TableQuestionAnsweringOutput[number] {
|
|
29
|
-
return (
|
|
30
|
-
typeof elem === "object" &&
|
|
31
|
-
!!elem &&
|
|
32
|
-
"aggregator" in elem &&
|
|
33
|
-
typeof elem.aggregator === "string" &&
|
|
34
|
-
"answer" in elem &&
|
|
35
|
-
typeof elem.answer === "string" &&
|
|
36
|
-
"cells" in elem &&
|
|
37
|
-
Array.isArray(elem.cells) &&
|
|
38
|
-
elem.cells.every((x: unknown): x is string => typeof x === "string") &&
|
|
39
|
-
"coordinates" in elem &&
|
|
40
|
-
Array.isArray(elem.coordinates) &&
|
|
41
|
-
elem.coordinates.every(
|
|
42
|
-
(coord: unknown): coord is number[] => Array.isArray(coord) && coord.every((x) => typeof x === "number")
|
|
43
|
-
)
|
|
44
|
-
);
|
|
20
|
+
return providerHelper.getResponse(res);
|
|
45
21
|
}
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import type { TextClassificationInput, TextClassificationOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
|
-
import {
|
|
4
|
+
import { innerRequest } from "../../utils/request";
|
|
5
5
|
|
|
6
6
|
export type TextClassificationArgs = BaseArgs & TextClassificationInput;
|
|
7
7
|
|
|
@@ -12,16 +12,10 @@ export async function textClassification(
|
|
|
12
12
|
args: TextClassificationArgs,
|
|
13
13
|
options?: Options
|
|
14
14
|
): Promise<TextClassificationOutput> {
|
|
15
|
-
const
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
)
|
|
21
|
-
const isValidOutput =
|
|
22
|
-
Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number");
|
|
23
|
-
if (!isValidOutput) {
|
|
24
|
-
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
25
|
-
}
|
|
26
|
-
return res;
|
|
15
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
|
|
16
|
+
const { data: res } = await innerRequest<TextClassificationOutput>(args, {
|
|
17
|
+
...options,
|
|
18
|
+
task: "text-classification",
|
|
19
|
+
});
|
|
20
|
+
return providerHelper.getResponse(res);
|
|
27
21
|
}
|
|
@@ -1,33 +1,11 @@
|
|
|
1
|
-
import type {
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
TextGenerationOutput,
|
|
5
|
-
TextGenerationOutputFinishReason,
|
|
6
|
-
} from "@huggingface/tasks";
|
|
7
|
-
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
|
1
|
+
import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks";
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
|
+
import type { HyperbolicTextCompletionOutput } from "../../providers/hyperbolic";
|
|
8
4
|
import type { BaseArgs, Options } from "../../types";
|
|
9
|
-
import {
|
|
10
|
-
import { request } from "../custom/request";
|
|
11
|
-
import { omit } from "../../utils/omit";
|
|
5
|
+
import { innerRequest } from "../../utils/request";
|
|
12
6
|
|
|
13
7
|
export type { TextGenerationInput, TextGenerationOutput };
|
|
14
8
|
|
|
15
|
-
interface TogeteherTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
|
|
16
|
-
choices: Array<{
|
|
17
|
-
text: string;
|
|
18
|
-
finish_reason: TextGenerationOutputFinishReason;
|
|
19
|
-
seed: number;
|
|
20
|
-
logprobs: unknown;
|
|
21
|
-
index: number;
|
|
22
|
-
}>;
|
|
23
|
-
}
|
|
24
|
-
|
|
25
|
-
interface HyperbolicTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
|
|
26
|
-
choices: Array<{
|
|
27
|
-
message: { content: string };
|
|
28
|
-
}>;
|
|
29
|
-
}
|
|
30
|
-
|
|
31
9
|
/**
|
|
32
10
|
* Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
|
|
33
11
|
*/
|
|
@@ -35,58 +13,13 @@ export async function textGeneration(
|
|
|
35
13
|
args: BaseArgs & TextGenerationInput,
|
|
36
14
|
options?: Options
|
|
37
15
|
): Promise<TextGenerationOutput> {
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
throw new InferenceOutputError("Expected ChatCompletionOutput");
|
|
48
|
-
}
|
|
49
|
-
const completion = raw.choices[0];
|
|
50
|
-
return {
|
|
51
|
-
generated_text: completion.text,
|
|
52
|
-
};
|
|
53
|
-
} else if (args.provider === "hyperbolic") {
|
|
54
|
-
const payload = {
|
|
55
|
-
messages: [{ content: args.inputs, role: "user" }],
|
|
56
|
-
...(args.parameters
|
|
57
|
-
? {
|
|
58
|
-
max_tokens: args.parameters.max_new_tokens,
|
|
59
|
-
...omit(args.parameters, "max_new_tokens"),
|
|
60
|
-
}
|
|
61
|
-
: undefined),
|
|
62
|
-
...omit(args, ["inputs", "parameters"]),
|
|
63
|
-
};
|
|
64
|
-
const raw = await request<HyperbolicTextCompletionOutput>(payload, {
|
|
65
|
-
...options,
|
|
66
|
-
task: "text-generation",
|
|
67
|
-
});
|
|
68
|
-
const isValidOutput =
|
|
69
|
-
typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
|
|
70
|
-
if (!isValidOutput) {
|
|
71
|
-
throw new InferenceOutputError("Expected ChatCompletionOutput");
|
|
72
|
-
}
|
|
73
|
-
const completion = raw.choices[0];
|
|
74
|
-
return {
|
|
75
|
-
generated_text: completion.message.content,
|
|
76
|
-
};
|
|
77
|
-
} else {
|
|
78
|
-
const res = toArray(
|
|
79
|
-
await request<TextGenerationOutput | TextGenerationOutput[]>(args, {
|
|
80
|
-
...options,
|
|
81
|
-
task: "text-generation",
|
|
82
|
-
})
|
|
83
|
-
);
|
|
84
|
-
|
|
85
|
-
const isValidOutput =
|
|
86
|
-
Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
|
|
87
|
-
if (!isValidOutput) {
|
|
88
|
-
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
|
|
89
|
-
}
|
|
90
|
-
return (res as TextGenerationOutput[])?.[0];
|
|
91
|
-
}
|
|
16
|
+
const provider = args.provider ?? "hf-inference";
|
|
17
|
+
const providerHelper = getProviderHelper(provider, "text-generation");
|
|
18
|
+
const { data: response } = await innerRequest<
|
|
19
|
+
HyperbolicTextCompletionOutput | TextGenerationOutput | TextGenerationOutput[]
|
|
20
|
+
>(args, {
|
|
21
|
+
...options,
|
|
22
|
+
task: "text-generation",
|
|
23
|
+
});
|
|
24
|
+
return providerHelper.getResponse(response);
|
|
92
25
|
}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import type { TextGenerationInput } from "@huggingface/tasks";
|
|
2
2
|
import type { BaseArgs, Options } from "../../types";
|
|
3
|
-
import {
|
|
3
|
+
import { innerStreamingRequest } from "../../utils/request";
|
|
4
4
|
|
|
5
5
|
export interface TextGenerationStreamToken {
|
|
6
6
|
/** Token ID from the model tokenizer */
|
|
@@ -89,7 +89,7 @@ export async function* textGenerationStream(
|
|
|
89
89
|
args: BaseArgs & TextGenerationInput,
|
|
90
90
|
options?: Options
|
|
91
91
|
): AsyncGenerator<TextGenerationStreamOutput> {
|
|
92
|
-
yield*
|
|
92
|
+
yield* innerStreamingRequest<TextGenerationStreamOutput>(args, {
|
|
93
93
|
...options,
|
|
94
94
|
task: "text-generation",
|
|
95
95
|
});
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import type { TokenClassificationInput, TokenClassificationOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
|
+
import { innerRequest } from "../../utils/request";
|
|
4
5
|
import { toArray } from "../../utils/toArray";
|
|
5
|
-
import { request } from "../custom/request";
|
|
6
6
|
|
|
7
7
|
export type TokenClassificationArgs = BaseArgs & TokenClassificationInput;
|
|
8
8
|
|
|
@@ -13,26 +13,10 @@ export async function tokenClassification(
|
|
|
13
13
|
args: TokenClassificationArgs,
|
|
14
14
|
options?: Options
|
|
15
15
|
): Promise<TokenClassificationOutput> {
|
|
16
|
-
const
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
);
|
|
22
|
-
const isValidOutput =
|
|
23
|
-
Array.isArray(res) &&
|
|
24
|
-
res.every(
|
|
25
|
-
(x) =>
|
|
26
|
-
typeof x.end === "number" &&
|
|
27
|
-
typeof x.entity_group === "string" &&
|
|
28
|
-
typeof x.score === "number" &&
|
|
29
|
-
typeof x.start === "number" &&
|
|
30
|
-
typeof x.word === "string"
|
|
31
|
-
);
|
|
32
|
-
if (!isValidOutput) {
|
|
33
|
-
throw new InferenceOutputError(
|
|
34
|
-
"Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
|
|
35
|
-
);
|
|
36
|
-
}
|
|
37
|
-
return res;
|
|
16
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "token-classification");
|
|
17
|
+
const { data: res } = await innerRequest<TokenClassificationOutput[number] | TokenClassificationOutput>(args, {
|
|
18
|
+
...options,
|
|
19
|
+
task: "token-classification",
|
|
20
|
+
});
|
|
21
|
+
return providerHelper.getResponse(res);
|
|
38
22
|
}
|
|
@@ -1,20 +1,17 @@
|
|
|
1
1
|
import type { TranslationInput, TranslationOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
|
-
import {
|
|
4
|
+
import { innerRequest } from "../../utils/request";
|
|
5
5
|
|
|
6
6
|
export type TranslationArgs = BaseArgs & TranslationInput;
|
|
7
7
|
/**
|
|
8
8
|
* This task is well known to translate text from one language to another. Recommended model: Helsinki-NLP/opus-mt-ru-en.
|
|
9
9
|
*/
|
|
10
10
|
export async function translation(args: TranslationArgs, options?: Options): Promise<TranslationOutput> {
|
|
11
|
-
const
|
|
11
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation");
|
|
12
|
+
const { data: res } = await innerRequest<TranslationOutput>(args, {
|
|
12
13
|
...options,
|
|
13
14
|
task: "translation",
|
|
14
15
|
});
|
|
15
|
-
|
|
16
|
-
if (!isValidOutput) {
|
|
17
|
-
throw new InferenceOutputError("Expected type Array<{translation_text: string}>");
|
|
18
|
-
}
|
|
19
|
-
return res?.length === 1 ? res?.[0] : res;
|
|
16
|
+
return providerHelper.getResponse(res);
|
|
20
17
|
}
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import type { ZeroShotClassificationInput, ZeroShotClassificationOutput } from "@huggingface/tasks";
|
|
2
|
-
import {
|
|
2
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
|
+
import { innerRequest } from "../../utils/request";
|
|
4
5
|
import { toArray } from "../../utils/toArray";
|
|
5
|
-
import { request } from "../custom/request";
|
|
6
6
|
|
|
7
7
|
export type ZeroShotClassificationArgs = BaseArgs & ZeroShotClassificationInput;
|
|
8
8
|
|
|
@@ -13,24 +13,10 @@ export async function zeroShotClassification(
|
|
|
13
13
|
args: ZeroShotClassificationArgs,
|
|
14
14
|
options?: Options
|
|
15
15
|
): Promise<ZeroShotClassificationOutput> {
|
|
16
|
-
const
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
);
|
|
22
|
-
const isValidOutput =
|
|
23
|
-
Array.isArray(res) &&
|
|
24
|
-
res.every(
|
|
25
|
-
(x) =>
|
|
26
|
-
Array.isArray(x.labels) &&
|
|
27
|
-
x.labels.every((_label) => typeof _label === "string") &&
|
|
28
|
-
Array.isArray(x.scores) &&
|
|
29
|
-
x.scores.every((_score) => typeof _score === "number") &&
|
|
30
|
-
typeof x.sequence === "string"
|
|
31
|
-
);
|
|
32
|
-
if (!isValidOutput) {
|
|
33
|
-
throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
|
|
34
|
-
}
|
|
35
|
-
return res;
|
|
16
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification");
|
|
17
|
+
const { data: res } = await innerRequest<ZeroShotClassificationOutput[number] | ZeroShotClassificationOutput>(args, {
|
|
18
|
+
...options,
|
|
19
|
+
task: "zero-shot-classification",
|
|
20
|
+
});
|
|
21
|
+
return providerHelper.getResponse(res);
|
|
36
22
|
}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
2
2
|
import type { BaseArgs, Options } from "../../types";
|
|
3
|
-
import {
|
|
3
|
+
import { innerRequest } from "../../utils/request";
|
|
4
4
|
|
|
5
5
|
export type TabularClassificationArgs = BaseArgs & {
|
|
6
6
|
inputs: {
|
|
@@ -25,13 +25,10 @@ export async function tabularClassification(
|
|
|
25
25
|
args: TabularClassificationArgs,
|
|
26
26
|
options?: Options
|
|
27
27
|
): Promise<TabularClassificationOutput> {
|
|
28
|
-
const
|
|
28
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification");
|
|
29
|
+
const { data: res } = await innerRequest<TabularClassificationOutput>(args, {
|
|
29
30
|
...options,
|
|
30
31
|
task: "tabular-classification",
|
|
31
32
|
});
|
|
32
|
-
|
|
33
|
-
if (!isValidOutput) {
|
|
34
|
-
throw new InferenceOutputError("Expected number[]");
|
|
35
|
-
}
|
|
36
|
-
return res;
|
|
33
|
+
return providerHelper.getResponse(res);
|
|
37
34
|
}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
2
2
|
import type { BaseArgs, Options } from "../../types";
|
|
3
|
-
import {
|
|
3
|
+
import { innerRequest } from "../../utils/request";
|
|
4
4
|
|
|
5
5
|
export type TabularRegressionArgs = BaseArgs & {
|
|
6
6
|
inputs: {
|
|
@@ -25,13 +25,10 @@ export async function tabularRegression(
|
|
|
25
25
|
args: TabularRegressionArgs,
|
|
26
26
|
options?: Options
|
|
27
27
|
): Promise<TabularRegressionOutput> {
|
|
28
|
-
const
|
|
28
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression");
|
|
29
|
+
const { data: res } = await innerRequest<TabularRegressionOutput>(args, {
|
|
29
30
|
...options,
|
|
30
31
|
task: "tabular-regression",
|
|
31
32
|
});
|
|
32
|
-
|
|
33
|
-
if (!isValidOutput) {
|
|
34
|
-
throw new InferenceOutputError("Expected number[]");
|
|
35
|
-
}
|
|
36
|
-
return res;
|
|
33
|
+
return providerHelper.getResponse(res);
|
|
37
34
|
}
|
package/src/types.ts
CHANGED
|
@@ -24,9 +24,17 @@ export interface Options {
|
|
|
24
24
|
* (Default: "same-origin"). String | Boolean. Credentials to use for the request. If this is a string, it will be passed straight on. If it's a boolean, true will be "include" and false will not send credentials at all.
|
|
25
25
|
*/
|
|
26
26
|
includeCredentials?: string | boolean;
|
|
27
|
+
|
|
28
|
+
/**
|
|
29
|
+
* The billing account to use for the requests.
|
|
30
|
+
*
|
|
31
|
+
* By default the requests are billed on the user's account.
|
|
32
|
+
* Requests can only be billed to an organization the user is a member of, and which has subscribed to Enterprise Hub.
|
|
33
|
+
*/
|
|
34
|
+
billTo?: string;
|
|
27
35
|
}
|
|
28
36
|
|
|
29
|
-
export type InferenceTask = Exclude<PipelineType, "other"
|
|
37
|
+
export type InferenceTask = Exclude<PipelineType, "other"> | "conversational";
|
|
30
38
|
|
|
31
39
|
export const INFERENCE_PROVIDERS = [
|
|
32
40
|
"black-forest-labs",
|
|
@@ -93,14 +101,6 @@ export type RequestArgs = BaseArgs &
|
|
|
93
101
|
parameters?: Record<string, unknown>;
|
|
94
102
|
};
|
|
95
103
|
|
|
96
|
-
export interface ProviderConfig {
|
|
97
|
-
makeBaseUrl: ((task?: InferenceTask) => string) | (() => string);
|
|
98
|
-
makeBody: (params: BodyParams) => Record<string, unknown>;
|
|
99
|
-
makeHeaders: (params: HeaderParams) => Record<string, string>;
|
|
100
|
-
makeUrl: (params: UrlParams) => string;
|
|
101
|
-
clientSideRoutingOnly?: boolean;
|
|
102
|
-
}
|
|
103
|
-
|
|
104
104
|
export type AuthMethod = "none" | "hf-token" | "credentials-include" | "provider-key";
|
|
105
105
|
|
|
106
106
|
export interface HeaderParams {
|
|
@@ -110,15 +110,12 @@ export interface HeaderParams {
|
|
|
110
110
|
|
|
111
111
|
export interface UrlParams {
|
|
112
112
|
authMethod: AuthMethod;
|
|
113
|
-
baseUrl: string;
|
|
114
113
|
model: string;
|
|
115
114
|
task?: InferenceTask;
|
|
116
|
-
chatCompletion?: boolean;
|
|
117
115
|
}
|
|
118
116
|
|
|
119
|
-
export interface BodyParams {
|
|
120
|
-
args:
|
|
121
|
-
chatCompletion?: boolean;
|
|
117
|
+
export interface BodyParams<T extends Record<string, unknown> = Record<string, unknown>> {
|
|
118
|
+
args: T;
|
|
122
119
|
model: string;
|
|
123
120
|
task?: InferenceTask;
|
|
124
121
|
}
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
import { makeRequestOptions } from "../lib/makeRequestOptions";
|
|
2
|
+
import type { InferenceTask, Options, RequestArgs } from "../types";
|
|
3
|
+
import type { EventSourceMessage } from "../vendor/fetch-event-source/parse";
|
|
4
|
+
import { getLines, getMessages } from "../vendor/fetch-event-source/parse";
|
|
5
|
+
|
|
6
|
+
export interface ResponseWrapper<T> {
|
|
7
|
+
data: T;
|
|
8
|
+
requestContext: {
|
|
9
|
+
url: string;
|
|
10
|
+
info: RequestInit;
|
|
11
|
+
};
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
/**
|
|
15
|
+
* Primitive to make custom calls to the inference provider
|
|
16
|
+
*/
|
|
17
|
+
export async function innerRequest<T>(
|
|
18
|
+
args: RequestArgs,
|
|
19
|
+
options?: Options & {
|
|
20
|
+
/** In most cases (unless we pass a endpointUrl) we know the task */
|
|
21
|
+
task?: InferenceTask;
|
|
22
|
+
/** Is chat completion compatible */
|
|
23
|
+
chatCompletion?: boolean;
|
|
24
|
+
}
|
|
25
|
+
): Promise<ResponseWrapper<T>> {
|
|
26
|
+
const { url, info } = await makeRequestOptions(args, options);
|
|
27
|
+
const response = await (options?.fetch ?? fetch)(url, info);
|
|
28
|
+
|
|
29
|
+
const requestContext: ResponseWrapper<T>["requestContext"] = { url, info };
|
|
30
|
+
|
|
31
|
+
if (options?.retry_on_error !== false && response.status === 503) {
|
|
32
|
+
return innerRequest(args, options);
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
if (!response.ok) {
|
|
36
|
+
const contentType = response.headers.get("Content-Type");
|
|
37
|
+
if (["application/json", "application/problem+json"].some((ct) => contentType?.startsWith(ct))) {
|
|
38
|
+
const output = await response.json();
|
|
39
|
+
if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
|
|
40
|
+
throw new Error(
|
|
41
|
+
`Server ${args.model} does not seem to support chat completion. Error: ${JSON.stringify(output.error)}`
|
|
42
|
+
);
|
|
43
|
+
}
|
|
44
|
+
if (output.error || output.detail) {
|
|
45
|
+
throw new Error(JSON.stringify(output.error ?? output.detail));
|
|
46
|
+
} else {
|
|
47
|
+
throw new Error(output);
|
|
48
|
+
}
|
|
49
|
+
}
|
|
50
|
+
const message = contentType?.startsWith("text/plain;") ? await response.text() : undefined;
|
|
51
|
+
throw new Error(message ?? "An error occurred while fetching the blob");
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
if (response.headers.get("Content-Type")?.startsWith("application/json")) {
|
|
55
|
+
const data = (await response.json()) as T;
|
|
56
|
+
return { data, requestContext };
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
const blob = (await response.blob()) as T;
|
|
60
|
+
return { data: blob as unknown as T, requestContext };
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
/**
|
|
64
|
+
* Primitive to make custom inference calls that expect server-sent events, and returns the response through a generator
|
|
65
|
+
*/
|
|
66
|
+
export async function* innerStreamingRequest<T>(
|
|
67
|
+
args: RequestArgs,
|
|
68
|
+
options?: Options & {
|
|
69
|
+
/** In most cases (unless we pass a endpointUrl) we know the task */
|
|
70
|
+
task?: InferenceTask;
|
|
71
|
+
/** Is chat completion compatible */
|
|
72
|
+
chatCompletion?: boolean;
|
|
73
|
+
}
|
|
74
|
+
): AsyncGenerator<T> {
|
|
75
|
+
const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
|
|
76
|
+
const response = await (options?.fetch ?? fetch)(url, info);
|
|
77
|
+
|
|
78
|
+
if (options?.retry_on_error !== false && response.status === 503) {
|
|
79
|
+
return yield* innerStreamingRequest(args, options);
|
|
80
|
+
}
|
|
81
|
+
if (!response.ok) {
|
|
82
|
+
if (response.headers.get("Content-Type")?.startsWith("application/json")) {
|
|
83
|
+
const output = await response.json();
|
|
84
|
+
if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
|
|
85
|
+
throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
|
|
86
|
+
}
|
|
87
|
+
if (typeof output.error === "string") {
|
|
88
|
+
throw new Error(output.error);
|
|
89
|
+
}
|
|
90
|
+
if (output.error && "message" in output.error && typeof output.error.message === "string") {
|
|
91
|
+
/// OpenAI errors
|
|
92
|
+
throw new Error(output.error.message);
|
|
93
|
+
}
|
|
94
|
+
// Sambanova errors
|
|
95
|
+
if (typeof output.message === "string") {
|
|
96
|
+
throw new Error(output.message);
|
|
97
|
+
}
|
|
98
|
+
}
|
|
99
|
+
|
|
100
|
+
throw new Error(`Server response contains error: ${response.status}`);
|
|
101
|
+
}
|
|
102
|
+
if (!response.headers.get("content-type")?.startsWith("text/event-stream")) {
|
|
103
|
+
throw new Error(
|
|
104
|
+
`Server does not support event stream content type, it returned ` + response.headers.get("content-type")
|
|
105
|
+
);
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
if (!response.body) {
|
|
109
|
+
return;
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
const reader = response.body.getReader();
|
|
113
|
+
let events: EventSourceMessage[] = [];
|
|
114
|
+
|
|
115
|
+
const onEvent = (event: EventSourceMessage) => {
|
|
116
|
+
// accumulate events in array
|
|
117
|
+
events.push(event);
|
|
118
|
+
};
|
|
119
|
+
|
|
120
|
+
const onChunk = getLines(
|
|
121
|
+
getMessages(
|
|
122
|
+
() => {},
|
|
123
|
+
() => {},
|
|
124
|
+
onEvent
|
|
125
|
+
)
|
|
126
|
+
);
|
|
127
|
+
|
|
128
|
+
try {
|
|
129
|
+
while (true) {
|
|
130
|
+
const { done, value } = await reader.read();
|
|
131
|
+
if (done) {
|
|
132
|
+
return;
|
|
133
|
+
}
|
|
134
|
+
onChunk(value);
|
|
135
|
+
for (const event of events) {
|
|
136
|
+
if (event.data.length > 0) {
|
|
137
|
+
if (event.data === "[DONE]") {
|
|
138
|
+
return;
|
|
139
|
+
}
|
|
140
|
+
const data = JSON.parse(event.data);
|
|
141
|
+
if (typeof data === "object" && data !== null && "error" in data) {
|
|
142
|
+
const errorStr =
|
|
143
|
+
typeof data.error === "string"
|
|
144
|
+
? data.error
|
|
145
|
+
: typeof data.error === "object" &&
|
|
146
|
+
data.error &&
|
|
147
|
+
"message" in data.error &&
|
|
148
|
+
typeof data.error.message === "string"
|
|
149
|
+
? data.error.message
|
|
150
|
+
: JSON.stringify(data.error);
|
|
151
|
+
throw new Error(`Error forwarded from backend: ` + errorStr);
|
|
152
|
+
}
|
|
153
|
+
yield data as T;
|
|
154
|
+
}
|
|
155
|
+
}
|
|
156
|
+
events = [];
|
|
157
|
+
}
|
|
158
|
+
} finally {
|
|
159
|
+
reader.releaseLock();
|
|
160
|
+
}
|
|
161
|
+
}
|