@huggingface/inference 3.7.1 → 3.8.0
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- package/dist/index.cjs +247 -132
- package/dist/index.js +247 -132
- package/dist/src/lib/getInferenceProviderMapping.d.ts +21 -0
- package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -0
- package/dist/src/lib/makeRequestOptions.d.ts +5 -3
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/consts.d.ts +2 -3
- package/dist/src/providers/consts.d.ts.map +1 -1
- package/dist/src/providers/fal-ai.d.ts.map +1 -1
- package/dist/src/providers/hf-inference.d.ts +1 -0
- package/dist/src/providers/hf-inference.d.ts.map +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts +2 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/types.d.ts +2 -0
- package/dist/src/types.d.ts.map +1 -1
- package/dist/src/utils/request.d.ts +3 -2
- package/dist/src/utils/request.d.ts.map +1 -1
- package/package.json +3 -3
- package/src/lib/getInferenceProviderMapping.ts +96 -0
- package/src/lib/makeRequestOptions.ts +50 -12
- package/src/providers/consts.ts +5 -2
- package/src/providers/fal-ai.ts +31 -2
- package/src/providers/hf-inference.ts +8 -6
- package/src/snippets/getInferenceSnippets.ts +26 -8
- package/src/snippets/templates.exported.ts +25 -25
- package/src/tasks/audio/audioClassification.ts +1 -1
- package/src/tasks/audio/audioToAudio.ts +1 -1
- package/src/tasks/audio/automaticSpeechRecognition.ts +1 -1
- package/src/tasks/audio/textToSpeech.ts +1 -1
- package/src/tasks/custom/request.ts +3 -1
- package/src/tasks/custom/streamingRequest.ts +4 -1
- package/src/tasks/cv/imageClassification.ts +1 -1
- package/src/tasks/cv/imageSegmentation.ts +1 -1
- package/src/tasks/cv/imageToImage.ts +1 -1
- package/src/tasks/cv/imageToText.ts +1 -1
- package/src/tasks/cv/objectDetection.ts +1 -1
- package/src/tasks/cv/textToImage.ts +2 -2
- package/src/tasks/cv/textToVideo.ts +9 -5
- package/src/tasks/cv/zeroShotImageClassification.ts +1 -1
- package/src/tasks/multimodal/documentQuestionAnswering.ts +1 -0
- package/src/tasks/multimodal/visualQuestionAnswering.ts +1 -1
- package/src/tasks/nlp/chatCompletion.ts +1 -1
- package/src/tasks/nlp/chatCompletionStream.ts +3 -1
- package/src/tasks/nlp/featureExtraction.ts +1 -1
- package/src/tasks/nlp/fillMask.ts +1 -1
- package/src/tasks/nlp/questionAnswering.ts +8 -4
- package/src/tasks/nlp/sentenceSimilarity.ts +1 -1
- package/src/tasks/nlp/summarization.ts +1 -1
- package/src/tasks/nlp/tableQuestionAnswering.ts +8 -4
- package/src/tasks/nlp/textClassification.ts +1 -1
- package/src/tasks/nlp/textGeneration.ts +2 -3
- package/src/tasks/nlp/textGenerationStream.ts +3 -1
- package/src/tasks/nlp/tokenClassification.ts +8 -5
- package/src/tasks/nlp/translation.ts +1 -1
- package/src/tasks/nlp/zeroShotClassification.ts +8 -5
- package/src/tasks/tabular/tabularClassification.ts +1 -1
- package/src/tasks/tabular/tabularRegression.ts +1 -1
- package/src/types.ts +2 -0
- package/src/utils/request.ts +7 -4
- package/dist/src/lib/getProviderModelId.d.ts +0 -10
- package/dist/src/lib/getProviderModelId.d.ts.map +0 -1
- package/src/lib/getProviderModelId.ts +0 -74
|
@@ -9,7 +9,7 @@ export type TranslationArgs = BaseArgs & TranslationInput;
|
|
|
9
9
|
*/
|
|
10
10
|
export async function translation(args: TranslationArgs, options?: Options): Promise<TranslationOutput> {
|
|
11
11
|
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation");
|
|
12
|
-
const { data: res } = await innerRequest<TranslationOutput>(args, {
|
|
12
|
+
const { data: res } = await innerRequest<TranslationOutput>(args, providerHelper, {
|
|
13
13
|
...options,
|
|
14
14
|
task: "translation",
|
|
15
15
|
});
|
|
@@ -2,7 +2,6 @@ import type { ZeroShotClassificationInput, ZeroShotClassificationOutput } from "
|
|
|
2
2
|
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
3
|
import type { BaseArgs, Options } from "../../types";
|
|
4
4
|
import { innerRequest } from "../../utils/request";
|
|
5
|
-
import { toArray } from "../../utils/toArray";
|
|
6
5
|
|
|
7
6
|
export type ZeroShotClassificationArgs = BaseArgs & ZeroShotClassificationInput;
|
|
8
7
|
|
|
@@ -14,9 +13,13 @@ export async function zeroShotClassification(
|
|
|
14
13
|
options?: Options
|
|
15
14
|
): Promise<ZeroShotClassificationOutput> {
|
|
16
15
|
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification");
|
|
17
|
-
const { data: res } = await innerRequest<ZeroShotClassificationOutput[number] | ZeroShotClassificationOutput>(
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
16
|
+
const { data: res } = await innerRequest<ZeroShotClassificationOutput[number] | ZeroShotClassificationOutput>(
|
|
17
|
+
args,
|
|
18
|
+
providerHelper,
|
|
19
|
+
{
|
|
20
|
+
...options,
|
|
21
|
+
task: "zero-shot-classification",
|
|
22
|
+
}
|
|
23
|
+
);
|
|
21
24
|
return providerHelper.getResponse(res);
|
|
22
25
|
}
|
|
@@ -26,7 +26,7 @@ export async function tabularClassification(
|
|
|
26
26
|
options?: Options
|
|
27
27
|
): Promise<TabularClassificationOutput> {
|
|
28
28
|
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification");
|
|
29
|
-
const { data: res } = await innerRequest<TabularClassificationOutput>(args, {
|
|
29
|
+
const { data: res } = await innerRequest<TabularClassificationOutput>(args, providerHelper, {
|
|
30
30
|
...options,
|
|
31
31
|
task: "tabular-classification",
|
|
32
32
|
});
|
|
@@ -26,7 +26,7 @@ export async function tabularRegression(
|
|
|
26
26
|
options?: Options
|
|
27
27
|
): Promise<TabularRegressionOutput> {
|
|
28
28
|
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression");
|
|
29
|
-
const { data: res } = await innerRequest<TabularRegressionOutput>(args, {
|
|
29
|
+
const { data: res } = await innerRequest<TabularRegressionOutput>(args, providerHelper, {
|
|
30
30
|
...options,
|
|
31
31
|
task: "tabular-regression",
|
|
32
32
|
});
|
package/src/types.ts
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import type { ChatCompletionInput, PipelineType } from "@huggingface/tasks";
|
|
2
|
+
import type { InferenceProviderModelMapping } from "./lib/getInferenceProviderMapping";
|
|
2
3
|
|
|
3
4
|
/**
|
|
4
5
|
* HF model id, like "meta-llama/Llama-3.3-70B-Instruct"
|
|
@@ -117,5 +118,6 @@ export interface UrlParams {
|
|
|
117
118
|
export interface BodyParams<T extends Record<string, unknown> = Record<string, unknown>> {
|
|
118
119
|
args: T;
|
|
119
120
|
model: string;
|
|
121
|
+
mapping?: InferenceProviderModelMapping | undefined;
|
|
120
122
|
task?: InferenceTask;
|
|
121
123
|
}
|
package/src/utils/request.ts
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import type { getProviderHelper } from "../lib/getProviderHelper";
|
|
1
2
|
import { makeRequestOptions } from "../lib/makeRequestOptions";
|
|
2
3
|
import type { InferenceTask, Options, RequestArgs } from "../types";
|
|
3
4
|
import type { EventSourceMessage } from "../vendor/fetch-event-source/parse";
|
|
@@ -16,6 +17,7 @@ export interface ResponseWrapper<T> {
|
|
|
16
17
|
*/
|
|
17
18
|
export async function innerRequest<T>(
|
|
18
19
|
args: RequestArgs,
|
|
20
|
+
providerHelper: ReturnType<typeof getProviderHelper>,
|
|
19
21
|
options?: Options & {
|
|
20
22
|
/** In most cases (unless we pass a endpointUrl) we know the task */
|
|
21
23
|
task?: InferenceTask;
|
|
@@ -23,13 +25,13 @@ export async function innerRequest<T>(
|
|
|
23
25
|
chatCompletion?: boolean;
|
|
24
26
|
}
|
|
25
27
|
): Promise<ResponseWrapper<T>> {
|
|
26
|
-
const { url, info } = await makeRequestOptions(args, options);
|
|
28
|
+
const { url, info } = await makeRequestOptions(args, providerHelper, options);
|
|
27
29
|
const response = await (options?.fetch ?? fetch)(url, info);
|
|
28
30
|
|
|
29
31
|
const requestContext: ResponseWrapper<T>["requestContext"] = { url, info };
|
|
30
32
|
|
|
31
33
|
if (options?.retry_on_error !== false && response.status === 503) {
|
|
32
|
-
return innerRequest(args, options);
|
|
34
|
+
return innerRequest(args, providerHelper, options);
|
|
33
35
|
}
|
|
34
36
|
|
|
35
37
|
if (!response.ok) {
|
|
@@ -65,6 +67,7 @@ export async function innerRequest<T>(
|
|
|
65
67
|
*/
|
|
66
68
|
export async function* innerStreamingRequest<T>(
|
|
67
69
|
args: RequestArgs,
|
|
70
|
+
providerHelper: ReturnType<typeof getProviderHelper>,
|
|
68
71
|
options?: Options & {
|
|
69
72
|
/** In most cases (unless we pass a endpointUrl) we know the task */
|
|
70
73
|
task?: InferenceTask;
|
|
@@ -72,11 +75,11 @@ export async function* innerStreamingRequest<T>(
|
|
|
72
75
|
chatCompletion?: boolean;
|
|
73
76
|
}
|
|
74
77
|
): AsyncGenerator<T> {
|
|
75
|
-
const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
|
|
78
|
+
const { url, info } = await makeRequestOptions({ ...args, stream: true }, providerHelper, options);
|
|
76
79
|
const response = await (options?.fetch ?? fetch)(url, info);
|
|
77
80
|
|
|
78
81
|
if (options?.retry_on_error !== false && response.status === 503) {
|
|
79
|
-
return yield* innerStreamingRequest(args, options);
|
|
82
|
+
return yield* innerStreamingRequest(args, providerHelper, options);
|
|
80
83
|
}
|
|
81
84
|
if (!response.ok) {
|
|
82
85
|
if (response.headers.get("Content-Type")?.startsWith("application/json")) {
|
|
@@ -1,10 +0,0 @@
|
|
|
1
|
-
import type { InferenceProvider, InferenceTask, Options, RequestArgs } from "../types";
|
|
2
|
-
export declare function getProviderModelId(params: {
|
|
3
|
-
model: string;
|
|
4
|
-
provider: InferenceProvider;
|
|
5
|
-
}, args: RequestArgs, options?: {
|
|
6
|
-
task?: InferenceTask;
|
|
7
|
-
chatCompletion?: boolean;
|
|
8
|
-
fetch?: Options["fetch"];
|
|
9
|
-
}): Promise<string>;
|
|
10
|
-
//# sourceMappingURL=getProviderModelId.d.ts.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"getProviderModelId.d.ts","sourceRoot":"","sources":["../../../src/lib/getProviderModelId.ts"],"names":[],"mappings":"AACA,OAAO,KAAK,EAAE,iBAAiB,EAAE,aAAa,EAAW,OAAO,EAAE,WAAW,EAAE,MAAM,UAAU,CAAC;AAShG,wBAAsB,kBAAkB,CACvC,MAAM,EAAE;IACP,KAAK,EAAE,MAAM,CAAC;IACd,QAAQ,EAAE,iBAAiB,CAAC;CAC5B,EACD,IAAI,EAAE,WAAW,EACjB,OAAO,GAAE;IACR,IAAI,CAAC,EAAE,aAAa,CAAC;IACrB,cAAc,CAAC,EAAE,OAAO,CAAC;IACzB,KAAK,CAAC,EAAE,OAAO,CAAC,OAAO,CAAC,CAAC;CACpB,GACJ,OAAO,CAAC,MAAM,CAAC,CAoDjB"}
|
|
@@ -1,74 +0,0 @@
|
|
|
1
|
-
import type { WidgetType } from "@huggingface/tasks";
|
|
2
|
-
import type { InferenceProvider, InferenceTask, ModelId, Options, RequestArgs } from "../types";
|
|
3
|
-
import { HF_HUB_URL } from "../config";
|
|
4
|
-
import { HARDCODED_MODEL_ID_MAPPING } from "../providers/consts";
|
|
5
|
-
|
|
6
|
-
type InferenceProviderMapping = Partial<
|
|
7
|
-
Record<InferenceProvider, { providerId: string; status: "live" | "staging"; task: WidgetType }>
|
|
8
|
-
>;
|
|
9
|
-
const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
|
|
10
|
-
|
|
11
|
-
export async function getProviderModelId(
|
|
12
|
-
params: {
|
|
13
|
-
model: string;
|
|
14
|
-
provider: InferenceProvider;
|
|
15
|
-
},
|
|
16
|
-
args: RequestArgs,
|
|
17
|
-
options: {
|
|
18
|
-
task?: InferenceTask;
|
|
19
|
-
chatCompletion?: boolean;
|
|
20
|
-
fetch?: Options["fetch"];
|
|
21
|
-
} = {}
|
|
22
|
-
): Promise<string> {
|
|
23
|
-
if (params.provider === "hf-inference") {
|
|
24
|
-
return params.model;
|
|
25
|
-
}
|
|
26
|
-
if (!options.task) {
|
|
27
|
-
throw new Error("task must be specified when using a third-party provider");
|
|
28
|
-
}
|
|
29
|
-
const task: WidgetType =
|
|
30
|
-
options.task === "text-generation" && options.chatCompletion ? "conversational" : options.task;
|
|
31
|
-
|
|
32
|
-
// A dict called HARDCODED_MODEL_ID_MAPPING takes precedence in all cases (useful for dev purposes)
|
|
33
|
-
if (HARDCODED_MODEL_ID_MAPPING[params.provider]?.[params.model]) {
|
|
34
|
-
return HARDCODED_MODEL_ID_MAPPING[params.provider][params.model];
|
|
35
|
-
}
|
|
36
|
-
|
|
37
|
-
let inferenceProviderMapping: InferenceProviderMapping | null;
|
|
38
|
-
if (inferenceProviderMappingCache.has(params.model)) {
|
|
39
|
-
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
|
40
|
-
inferenceProviderMapping = inferenceProviderMappingCache.get(params.model)!;
|
|
41
|
-
} else {
|
|
42
|
-
inferenceProviderMapping = await (options?.fetch ?? fetch)(
|
|
43
|
-
`${HF_HUB_URL}/api/models/${params.model}?expand[]=inferenceProviderMapping`,
|
|
44
|
-
{
|
|
45
|
-
headers: args.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${args.accessToken}` } : {},
|
|
46
|
-
}
|
|
47
|
-
)
|
|
48
|
-
.then((resp) => resp.json())
|
|
49
|
-
.then((json) => json.inferenceProviderMapping)
|
|
50
|
-
.catch(() => null);
|
|
51
|
-
}
|
|
52
|
-
|
|
53
|
-
if (!inferenceProviderMapping) {
|
|
54
|
-
throw new Error(`We have not been able to find inference provider information for model ${params.model}.`);
|
|
55
|
-
}
|
|
56
|
-
|
|
57
|
-
const providerMapping = inferenceProviderMapping[params.provider];
|
|
58
|
-
if (providerMapping) {
|
|
59
|
-
if (providerMapping.task !== task) {
|
|
60
|
-
throw new Error(
|
|
61
|
-
`Model ${params.model} is not supported for task ${task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
|
|
62
|
-
);
|
|
63
|
-
}
|
|
64
|
-
if (providerMapping.status === "staging") {
|
|
65
|
-
console.warn(
|
|
66
|
-
`Model ${params.model} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
|
|
67
|
-
);
|
|
68
|
-
}
|
|
69
|
-
// TODO: how is it handled server-side if model has multiple tasks (e.g. `text-generation` + `conversational`)?
|
|
70
|
-
return providerMapping.providerId;
|
|
71
|
-
}
|
|
72
|
-
|
|
73
|
-
throw new Error(`Model ${params.model} is not supported provider ${params.provider}.`);
|
|
74
|
-
}
|