@huggingface/inference 3.9.2 → 3.11.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +9 -7
- package/dist/index.cjs +771 -646
- package/dist/index.js +770 -646
- package/dist/src/InferenceClient.d.ts +16 -17
- package/dist/src/InferenceClient.d.ts.map +1 -1
- package/dist/src/lib/getInferenceProviderMapping.d.ts +6 -2
- package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -1
- package/dist/src/lib/getProviderHelper.d.ts.map +1 -1
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/consts.d.ts.map +1 -1
- package/dist/src/providers/ovhcloud.d.ts +38 -0
- package/dist/src/providers/ovhcloud.d.ts.map +1 -0
- package/dist/src/providers/providerHelper.d.ts +1 -1
- package/dist/src/providers/providerHelper.d.ts.map +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
- package/dist/src/snippets/templates.exported.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioToAudio.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
- package/dist/src/types.d.ts +7 -5
- package/dist/src/types.d.ts.map +1 -1
- package/dist/src/utils/typedEntries.d.ts +4 -0
- package/dist/src/utils/typedEntries.d.ts.map +1 -0
- package/package.json +3 -3
- package/src/InferenceClient.ts +32 -43
- package/src/lib/getInferenceProviderMapping.ts +68 -19
- package/src/lib/getProviderHelper.ts +5 -0
- package/src/lib/makeRequestOptions.ts +4 -3
- package/src/providers/consts.ts +1 -0
- package/src/providers/ovhcloud.ts +75 -0
- package/src/providers/providerHelper.ts +1 -1
- package/src/snippets/getInferenceSnippets.ts +5 -4
- package/src/snippets/templates.exported.ts +7 -3
- package/src/tasks/audio/audioClassification.ts +3 -1
- package/src/tasks/audio/audioToAudio.ts +4 -1
- package/src/tasks/audio/automaticSpeechRecognition.ts +3 -1
- package/src/tasks/audio/textToSpeech.ts +2 -1
- package/src/tasks/custom/request.ts +3 -1
- package/src/tasks/custom/streamingRequest.ts +3 -1
- package/src/tasks/cv/imageClassification.ts +3 -1
- package/src/tasks/cv/imageSegmentation.ts +3 -1
- package/src/tasks/cv/imageToImage.ts +3 -1
- package/src/tasks/cv/imageToText.ts +3 -1
- package/src/tasks/cv/objectDetection.ts +3 -1
- package/src/tasks/cv/textToImage.ts +2 -1
- package/src/tasks/cv/textToVideo.ts +2 -1
- package/src/tasks/cv/zeroShotImageClassification.ts +3 -1
- package/src/tasks/multimodal/documentQuestionAnswering.ts +3 -1
- package/src/tasks/multimodal/visualQuestionAnswering.ts +3 -1
- package/src/tasks/nlp/chatCompletion.ts +3 -1
- package/src/tasks/nlp/chatCompletionStream.ts +3 -1
- package/src/tasks/nlp/featureExtraction.ts +3 -1
- package/src/tasks/nlp/fillMask.ts +3 -1
- package/src/tasks/nlp/questionAnswering.ts +4 -1
- package/src/tasks/nlp/sentenceSimilarity.ts +3 -1
- package/src/tasks/nlp/summarization.ts +3 -1
- package/src/tasks/nlp/tableQuestionAnswering.ts +3 -1
- package/src/tasks/nlp/textClassification.ts +3 -1
- package/src/tasks/nlp/textGeneration.ts +3 -1
- package/src/tasks/nlp/textGenerationStream.ts +3 -1
- package/src/tasks/nlp/tokenClassification.ts +3 -1
- package/src/tasks/nlp/translation.ts +3 -1
- package/src/tasks/nlp/zeroShotClassification.ts +3 -1
- package/src/tasks/tabular/tabularClassification.ts +3 -1
- package/src/tasks/tabular/tabularRegression.ts +3 -1
- package/src/types.ts +9 -4
- package/src/utils/typedEntries.ts +5 -0
package/src/InferenceClient.ts
CHANGED
|
@@ -1,73 +1,62 @@
|
|
|
1
1
|
import * as tasks from "./tasks";
|
|
2
|
-
import type { Options
|
|
3
|
-
import
|
|
2
|
+
import type { Options } from "./types";
|
|
3
|
+
import { omit } from "./utils/omit";
|
|
4
|
+
import { typedEntries } from "./utils/typedEntries";
|
|
4
5
|
|
|
5
6
|
/* eslint-disable @typescript-eslint/no-empty-interface */
|
|
6
7
|
/* eslint-disable @typescript-eslint/no-unsafe-declaration-merging */
|
|
7
8
|
|
|
8
9
|
type Task = typeof tasks;
|
|
9
10
|
|
|
10
|
-
type TaskWithNoAccessToken = {
|
|
11
|
-
[key in keyof Task]: (
|
|
12
|
-
args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken">,
|
|
13
|
-
options?: Parameters<Task[key]>[1]
|
|
14
|
-
) => ReturnType<Task[key]>;
|
|
15
|
-
};
|
|
16
|
-
|
|
17
|
-
type TaskWithNoAccessTokenNoEndpointUrl = {
|
|
18
|
-
[key in keyof Task]: (
|
|
19
|
-
args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "endpointUrl">,
|
|
20
|
-
options?: Parameters<Task[key]>[1]
|
|
21
|
-
) => ReturnType<Task[key]>;
|
|
22
|
-
};
|
|
23
|
-
|
|
24
11
|
export class InferenceClient {
|
|
25
12
|
private readonly accessToken: string;
|
|
26
13
|
private readonly defaultOptions: Options;
|
|
27
14
|
|
|
28
|
-
constructor(
|
|
15
|
+
constructor(
|
|
16
|
+
accessToken = "",
|
|
17
|
+
defaultOptions: Options & {
|
|
18
|
+
endpointUrl?: string;
|
|
19
|
+
} = {}
|
|
20
|
+
) {
|
|
29
21
|
this.accessToken = accessToken;
|
|
30
22
|
this.defaultOptions = defaultOptions;
|
|
31
23
|
|
|
32
|
-
for (const [name, fn] of
|
|
24
|
+
for (const [name, fn] of typedEntries(tasks)) {
|
|
33
25
|
Object.defineProperty(this, name, {
|
|
34
26
|
enumerable: false,
|
|
35
|
-
value: (params:
|
|
27
|
+
value: (params: Parameters<typeof fn>[0], options: Parameters<typeof fn>[1]) =>
|
|
36
28
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
37
|
-
fn
|
|
29
|
+
(fn as any)(
|
|
30
|
+
/// ^ The cast of fn to any is necessary, otherwise TS can't compile because the generated union type is too complex
|
|
31
|
+
{ endpointUrl: defaultOptions.endpointUrl, accessToken, ...params },
|
|
32
|
+
{
|
|
33
|
+
...omit(defaultOptions, ["endpointUrl"]),
|
|
34
|
+
...options,
|
|
35
|
+
}
|
|
36
|
+
),
|
|
38
37
|
});
|
|
39
38
|
}
|
|
40
39
|
}
|
|
41
40
|
|
|
42
41
|
/**
|
|
43
|
-
* Returns
|
|
42
|
+
* Returns a new instance of InferenceClient tied to a specified endpoint.
|
|
43
|
+
*
|
|
44
|
+
* For backward compatibility mostly.
|
|
44
45
|
*/
|
|
45
|
-
public endpoint(endpointUrl: string):
|
|
46
|
-
return new
|
|
47
|
-
}
|
|
48
|
-
}
|
|
49
|
-
|
|
50
|
-
export class InferenceClientEndpoint {
|
|
51
|
-
constructor(endpointUrl: string, accessToken = "", defaultOptions: Options = {}) {
|
|
52
|
-
accessToken;
|
|
53
|
-
defaultOptions;
|
|
54
|
-
|
|
55
|
-
for (const [name, fn] of Object.entries(tasks)) {
|
|
56
|
-
Object.defineProperty(this, name, {
|
|
57
|
-
enumerable: false,
|
|
58
|
-
value: (params: RequestArgs, options: Options) =>
|
|
59
|
-
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
60
|
-
fn({ ...params, accessToken, endpointUrl } as any, { ...defaultOptions, ...options }),
|
|
61
|
-
});
|
|
62
|
-
}
|
|
46
|
+
public endpoint(endpointUrl: string): InferenceClient {
|
|
47
|
+
return new InferenceClient(this.accessToken, { ...this.defaultOptions, endpointUrl });
|
|
63
48
|
}
|
|
64
49
|
}
|
|
65
50
|
|
|
66
|
-
export interface InferenceClient extends
|
|
67
|
-
|
|
68
|
-
export interface InferenceClientEndpoint extends TaskWithNoAccessTokenNoEndpointUrl {}
|
|
51
|
+
export interface InferenceClient extends Task {}
|
|
69
52
|
|
|
70
53
|
/**
|
|
71
|
-
* For backward compatibility only.
|
|
54
|
+
* For backward compatibility only, will remove soon.
|
|
55
|
+
* @deprecated replace with InferenceClient
|
|
72
56
|
*/
|
|
73
57
|
export class HfInference extends InferenceClient {}
|
|
58
|
+
/**
|
|
59
|
+
* For backward compatibility only, will remove soon.
|
|
60
|
+
* @deprecated replace with InferenceClient
|
|
61
|
+
*/
|
|
62
|
+
export class InferenceClientEndpoint extends InferenceClient {}
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import type { WidgetType } from "@huggingface/tasks";
|
|
2
|
-
import type { InferenceProvider, ModelId } from "../types";
|
|
3
2
|
import { HF_HUB_URL } from "../config";
|
|
4
3
|
import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts";
|
|
5
4
|
import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference";
|
|
5
|
+
import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../types";
|
|
6
6
|
import { typedInclude } from "../utils/typedInclude";
|
|
7
7
|
|
|
8
8
|
export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
|
|
@@ -20,44 +20,62 @@ export interface InferenceProviderModelMapping {
|
|
|
20
20
|
task: WidgetType;
|
|
21
21
|
}
|
|
22
22
|
|
|
23
|
-
export async function
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
provider: InferenceProvider;
|
|
28
|
-
task: WidgetType;
|
|
29
|
-
},
|
|
30
|
-
options: {
|
|
23
|
+
export async function fetchInferenceProviderMappingForModel(
|
|
24
|
+
modelId: ModelId,
|
|
25
|
+
accessToken?: string,
|
|
26
|
+
options?: {
|
|
31
27
|
fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
|
|
32
28
|
}
|
|
33
|
-
): Promise<
|
|
34
|
-
if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
|
|
35
|
-
return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
|
|
36
|
-
}
|
|
29
|
+
): Promise<InferenceProviderMapping> {
|
|
37
30
|
let inferenceProviderMapping: InferenceProviderMapping | null;
|
|
38
|
-
if (inferenceProviderMappingCache.has(
|
|
31
|
+
if (inferenceProviderMappingCache.has(modelId)) {
|
|
39
32
|
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
|
40
|
-
inferenceProviderMapping = inferenceProviderMappingCache.get(
|
|
33
|
+
inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
|
|
41
34
|
} else {
|
|
42
35
|
const resp = await (options?.fetch ?? fetch)(
|
|
43
|
-
`${HF_HUB_URL}/api/models/${
|
|
36
|
+
`${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
|
|
44
37
|
{
|
|
45
|
-
headers:
|
|
38
|
+
headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
|
|
46
39
|
}
|
|
47
40
|
);
|
|
48
41
|
if (resp.status === 404) {
|
|
49
|
-
throw new Error(`Model ${
|
|
42
|
+
throw new Error(`Model ${modelId} does not exist`);
|
|
50
43
|
}
|
|
51
44
|
inferenceProviderMapping = await resp
|
|
52
45
|
.json()
|
|
53
46
|
.then((json) => json.inferenceProviderMapping)
|
|
54
47
|
.catch(() => null);
|
|
48
|
+
|
|
49
|
+
if (inferenceProviderMapping) {
|
|
50
|
+
inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
|
|
51
|
+
}
|
|
55
52
|
}
|
|
56
53
|
|
|
57
54
|
if (!inferenceProviderMapping) {
|
|
58
|
-
throw new Error(`We have not been able to find inference provider information for model ${
|
|
55
|
+
throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
|
|
59
56
|
}
|
|
57
|
+
return inferenceProviderMapping;
|
|
58
|
+
}
|
|
60
59
|
|
|
60
|
+
export async function getInferenceProviderMapping(
|
|
61
|
+
params: {
|
|
62
|
+
accessToken?: string;
|
|
63
|
+
modelId: ModelId;
|
|
64
|
+
provider: InferenceProvider;
|
|
65
|
+
task: WidgetType;
|
|
66
|
+
},
|
|
67
|
+
options: {
|
|
68
|
+
fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
|
|
69
|
+
}
|
|
70
|
+
): Promise<InferenceProviderModelMapping | null> {
|
|
71
|
+
if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
|
|
72
|
+
return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
|
|
73
|
+
}
|
|
74
|
+
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
|
|
75
|
+
params.modelId,
|
|
76
|
+
params.accessToken,
|
|
77
|
+
options
|
|
78
|
+
);
|
|
61
79
|
const providerMapping = inferenceProviderMapping[params.provider];
|
|
62
80
|
if (providerMapping) {
|
|
63
81
|
const equivalentTasks =
|
|
@@ -78,3 +96,34 @@ export async function getInferenceProviderMapping(
|
|
|
78
96
|
}
|
|
79
97
|
return null;
|
|
80
98
|
}
|
|
99
|
+
|
|
100
|
+
export async function resolveProvider(
|
|
101
|
+
provider?: InferenceProviderOrPolicy,
|
|
102
|
+
modelId?: string,
|
|
103
|
+
endpointUrl?: string
|
|
104
|
+
): Promise<InferenceProvider> {
|
|
105
|
+
if (endpointUrl) {
|
|
106
|
+
if (provider) {
|
|
107
|
+
throw new Error("Specifying both endpointUrl and provider is not supported.");
|
|
108
|
+
}
|
|
109
|
+
/// Defaulting to hf-inference helpers / API
|
|
110
|
+
return "hf-inference";
|
|
111
|
+
}
|
|
112
|
+
if (!provider) {
|
|
113
|
+
console.log(
|
|
114
|
+
"Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
|
|
115
|
+
);
|
|
116
|
+
provider = "auto";
|
|
117
|
+
}
|
|
118
|
+
if (provider === "auto") {
|
|
119
|
+
if (!modelId) {
|
|
120
|
+
throw new Error("Specifying a model is required when provider is 'auto'");
|
|
121
|
+
}
|
|
122
|
+
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
|
|
123
|
+
provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider | undefined;
|
|
124
|
+
}
|
|
125
|
+
if (!provider) {
|
|
126
|
+
throw new Error(`No Inference Provider available for model ${modelId}.`);
|
|
127
|
+
}
|
|
128
|
+
return provider;
|
|
129
|
+
}
|
|
@@ -11,6 +11,7 @@ import * as Nebius from "../providers/nebius";
|
|
|
11
11
|
import * as Novita from "../providers/novita";
|
|
12
12
|
import * as Nscale from "../providers/nscale";
|
|
13
13
|
import * as OpenAI from "../providers/openai";
|
|
14
|
+
import * as OvhCloud from "../providers/ovhcloud";
|
|
14
15
|
import type {
|
|
15
16
|
AudioClassificationTaskHelper,
|
|
16
17
|
AudioToAudioTaskHelper,
|
|
@@ -126,6 +127,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
|
|
|
126
127
|
openai: {
|
|
127
128
|
conversational: new OpenAI.OpenAIConversationalTask(),
|
|
128
129
|
},
|
|
130
|
+
ovhcloud: {
|
|
131
|
+
conversational: new OvhCloud.OvhCloudConversationalTask(),
|
|
132
|
+
"text-generation": new OvhCloud.OvhCloudTextGenerationTask(),
|
|
133
|
+
},
|
|
129
134
|
replicate: {
|
|
130
135
|
"text-to-image": new Replicate.ReplicateTextToImageTask(),
|
|
131
136
|
"text-to-speech": new Replicate.ReplicateTextToSpeechTask(),
|
|
@@ -27,8 +27,8 @@ export async function makeRequestOptions(
|
|
|
27
27
|
task?: InferenceTask;
|
|
28
28
|
}
|
|
29
29
|
): Promise<{ url: string; info: RequestInit }> {
|
|
30
|
-
const {
|
|
31
|
-
const provider =
|
|
30
|
+
const { model: maybeModel } = args;
|
|
31
|
+
const provider = providerHelper.provider;
|
|
32
32
|
const { task } = options ?? {};
|
|
33
33
|
|
|
34
34
|
// Validate inputs
|
|
@@ -113,8 +113,9 @@ export function makeRequestOptionsFromResolvedModel(
|
|
|
113
113
|
): { url: string; info: RequestInit } {
|
|
114
114
|
const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
|
|
115
115
|
void model;
|
|
116
|
+
void maybeProvider;
|
|
116
117
|
|
|
117
|
-
const provider =
|
|
118
|
+
const provider = providerHelper.provider;
|
|
118
119
|
|
|
119
120
|
const { includeCredentials, task, signal, billTo } = options ?? {};
|
|
120
121
|
const authMethod = (() => {
|
package/src/providers/ovhcloud.ts
CHANGED
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* See the registered mapping of HF model ID => OVHcloud model ID here:
|
|
3
|
+
*
|
|
4
|
+
* https://huggingface.co/api/partners/ovhcloud/models
|
|
5
|
+
*
|
|
6
|
+
* This is a publicly available mapping.
|
|
7
|
+
*
|
|
8
|
+
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
|
|
9
|
+
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
|
|
10
|
+
*
|
|
11
|
+
* - If you work at OVHcloud and want to update this mapping, please use the model mapping API we provide on huggingface.co
|
|
12
|
+
* - If you're a community member and want to add a new supported HF model to OVHcloud, please open an issue on the present repo
|
|
13
|
+
* and we will tag OVHcloud team members.
|
|
14
|
+
*
|
|
15
|
+
* Thanks!
|
|
16
|
+
*/
|
|
17
|
+
|
|
18
|
+
import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";
|
|
19
|
+
import type { ChatCompletionOutput, TextGenerationOutput, TextGenerationOutputFinishReason } from "@huggingface/tasks";
|
|
20
|
+
import { InferenceOutputError } from "../lib/InferenceOutputError";
|
|
21
|
+
import type { BodyParams } from "../types";
|
|
22
|
+
import { omit } from "../utils/omit";
|
|
23
|
+
import type { TextGenerationInput } from "@huggingface/tasks";
|
|
24
|
+
|
|
25
|
+
const OVHCLOUD_API_BASE_URL = "https://oai.endpoints.kepler.ai.cloud.ovh.net";
|
|
26
|
+
|
|
27
|
+
interface OvhCloudTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
|
|
28
|
+
choices: Array<{
|
|
29
|
+
text: string;
|
|
30
|
+
finish_reason: TextGenerationOutputFinishReason;
|
|
31
|
+
logprobs: unknown;
|
|
32
|
+
index: number;
|
|
33
|
+
}>;
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
export class OvhCloudConversationalTask extends BaseConversationalTask {
|
|
37
|
+
constructor() {
|
|
38
|
+
super("ovhcloud", OVHCLOUD_API_BASE_URL);
|
|
39
|
+
}
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
export class OvhCloudTextGenerationTask extends BaseTextGenerationTask {
|
|
43
|
+
constructor() {
|
|
44
|
+
super("ovhcloud", OVHCLOUD_API_BASE_URL);
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
override preparePayload(params: BodyParams<TextGenerationInput>): Record<string, unknown> {
|
|
48
|
+
return {
|
|
49
|
+
model: params.model,
|
|
50
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
51
|
+
...(params.args.parameters
|
|
52
|
+
? {
|
|
53
|
+
max_tokens: (params.args.parameters as Record<string, unknown>).max_new_tokens,
|
|
54
|
+
...omit(params.args.parameters as Record<string, unknown>, "max_new_tokens"),
|
|
55
|
+
}
|
|
56
|
+
: undefined),
|
|
57
|
+
prompt: params.args.inputs,
|
|
58
|
+
};
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
override async getResponse(response: OvhCloudTextCompletionOutput): Promise<TextGenerationOutput> {
|
|
62
|
+
if (
|
|
63
|
+
typeof response === "object" &&
|
|
64
|
+
"choices" in response &&
|
|
65
|
+
Array.isArray(response?.choices) &&
|
|
66
|
+
typeof response?.model === "string"
|
|
67
|
+
) {
|
|
68
|
+
const completion = response.choices[0];
|
|
69
|
+
return {
|
|
70
|
+
generated_text: completion.text,
|
|
71
|
+
};
|
|
72
|
+
}
|
|
73
|
+
throw new InferenceOutputError("Expected OVHcloud text generation response format");
|
|
74
|
+
}
|
|
75
|
+
}
|
|
@@ -56,7 +56,7 @@ import { toArray } from "../utils/toArray";
|
|
|
56
56
|
*/
|
|
57
57
|
export abstract class TaskProviderHelper {
|
|
58
58
|
constructor(
|
|
59
|
-
|
|
59
|
+
readonly provider: InferenceProvider,
|
|
60
60
|
private baseUrl: string,
|
|
61
61
|
readonly clientSideRoutingOnly: boolean = false
|
|
62
62
|
) {}
|
|
@@ -8,11 +8,11 @@ import {
|
|
|
8
8
|
} from "@huggingface/tasks";
|
|
9
9
|
import type { PipelineType, WidgetType } from "@huggingface/tasks/src/pipelines.js";
|
|
10
10
|
import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingface/tasks/src/tasks/index.js";
|
|
11
|
+
import type { InferenceProviderModelMapping } from "../lib/getInferenceProviderMapping";
|
|
12
|
+
import { getProviderHelper } from "../lib/getProviderHelper";
|
|
11
13
|
import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions";
|
|
12
14
|
import type { InferenceProvider, InferenceTask, RequestArgs } from "../types";
|
|
13
15
|
import { templates } from "./templates.exported";
|
|
14
|
-
import type { InferenceProviderModelMapping } from "../lib/getInferenceProviderMapping";
|
|
15
|
-
import { getProviderHelper } from "../lib/getProviderHelper";
|
|
16
16
|
|
|
17
17
|
export type InferenceSnippetOptions = { streaming?: boolean; billTo?: string } & Record<string, unknown>;
|
|
18
18
|
|
|
@@ -112,6 +112,7 @@ const HF_JS_METHODS: Partial<Record<WidgetType, string>> = {
|
|
|
112
112
|
"text-generation": "textGeneration",
|
|
113
113
|
"text2text-generation": "textGeneration",
|
|
114
114
|
"token-classification": "tokenClassification",
|
|
115
|
+
"text-to-speech": "textToSpeech",
|
|
115
116
|
translation: "translation",
|
|
116
117
|
};
|
|
117
118
|
|
|
@@ -271,7 +272,7 @@ const prepareConversationalInput = (
|
|
|
271
272
|
return {
|
|
272
273
|
messages: opts?.messages ?? getModelInputSnippet(model),
|
|
273
274
|
...(opts?.temperature ? { temperature: opts?.temperature } : undefined),
|
|
274
|
-
max_tokens: opts?.max_tokens
|
|
275
|
+
...(opts?.max_tokens ? { max_tokens: opts?.max_tokens } : undefined),
|
|
275
276
|
...(opts?.top_p ? { top_p: opts?.top_p } : undefined),
|
|
276
277
|
};
|
|
277
278
|
};
|
|
@@ -310,7 +311,7 @@ const snippets: Partial<
|
|
|
310
311
|
"text-generation": snippetGenerator("basic"),
|
|
311
312
|
"text-to-audio": snippetGenerator("textToAudio"),
|
|
312
313
|
"text-to-image": snippetGenerator("textToImage"),
|
|
313
|
-
"text-to-speech": snippetGenerator("
|
|
314
|
+
"text-to-speech": snippetGenerator("textToSpeech"),
|
|
314
315
|
"text-to-video": snippetGenerator("textToVideo"),
|
|
315
316
|
"text2text-generation": snippetGenerator("basic"),
|
|
316
317
|
"token-classification": snippetGenerator("basic"),
|
|
@@ -7,6 +7,7 @@ export const templates: Record<string, Record<string, Record<string, string>>> =
|
|
|
7
7
|
"basicImage": "async function query(data) {\n\tconst response = await fetch(\n\t\t\"{{ fullUrl }}\",\n\t\t{\n\t\t\theaders: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n\t\t\t\t\"Content-Type\": \"image/jpeg\",\n{% if billTo %}\n\t\t\t\t\"X-HF-Bill-To\": \"{{ billTo }}\",\n{% endif %}\t\t\t},\n\t\t\tmethod: \"POST\",\n\t\t\tbody: JSON.stringify(data),\n\t\t}\n\t);\n\tconst result = await response.json();\n\treturn result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});",
|
|
8
8
|
"textToAudio": "{% if model.library_name == \"transformers\" %}\nasync function query(data) {\n\tconst response = await fetch(\n\t\t\"{{ fullUrl }}\",\n\t\t{\n\t\t\theaders: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n\t\t\t\t\"Content-Type\": \"application/json\",\n{% if billTo %}\n\t\t\t\t\"X-HF-Bill-To\": \"{{ billTo }}\",\n{% endif %}\t\t\t},\n\t\t\tmethod: \"POST\",\n\t\t\tbody: JSON.stringify(data),\n\t\t}\n\t);\n\tconst result = await response.blob();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n // Returns a byte object of the Audio wavform. Use it directly!\n});\n{% else %}\nasync function query(data) {\n\tconst response = await fetch(\n\t\t\"{{ fullUrl }}\",\n\t\t{\n\t\t\theaders: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n\t\t\t\t\"Content-Type\": \"application/json\",\n\t\t\t},\n\t\t\tmethod: \"POST\",\n\t\t\tbody: JSON.stringify(data),\n\t\t}\n\t);\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});\n{% endif %} ",
|
|
9
9
|
"textToImage": "async function query(data) {\n\tconst response = await fetch(\n\t\t\"{{ fullUrl }}\",\n\t\t{\n\t\t\theaders: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n\t\t\t\t\"Content-Type\": \"application/json\",\n{% if billTo %}\n\t\t\t\t\"X-HF-Bill-To\": \"{{ billTo }}\",\n{% endif %}\t\t\t},\n\t\t\tmethod: \"POST\",\n\t\t\tbody: JSON.stringify(data),\n\t\t}\n\t);\n\tconst result = await response.blob();\n\treturn result;\n}\n\n\nquery({ {{ providerInputs.asTsString }} }).then((response) => {\n // Use image\n});",
|
|
10
|
+
"textToSpeech": "{% if model.library_name == \"transformers\" %}\nasync function query(data) {\n\tconst response = await fetch(\n\t\t\"{{ fullUrl }}\",\n\t\t{\n\t\t\theaders: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n\t\t\t\t\"Content-Type\": \"application/json\",\n{% if billTo %}\n\t\t\t\t\"X-HF-Bill-To\": \"{{ billTo }}\",\n{% endif %}\t\t\t},\n\t\t\tmethod: \"POST\",\n\t\t\tbody: JSON.stringify(data),\n\t\t}\n\t);\n\tconst result = await response.blob();\n return result;\n}\n\nquery({ text: {{ inputs.asObj.inputs }} }).then((response) => {\n // Returns a byte object of the Audio wavform. Use it directly!\n});\n{% else %}\nasync function query(data) {\n\tconst response = await fetch(\n\t\t\"{{ fullUrl }}\",\n\t\t{\n\t\t\theaders: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n\t\t\t\t\"Content-Type\": \"application/json\",\n\t\t\t},\n\t\t\tmethod: \"POST\",\n\t\t\tbody: JSON.stringify(data),\n\t\t}\n\t);\n const result = await response.json();\n return result;\n}\n\nquery({ text: {{ inputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});\n{% endif %} ",
|
|
10
11
|
"zeroShotClassification": "async function query(data) {\n const response = await fetch(\n\t\t\"{{ fullUrl }}\",\n {\n headers: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n \"Content-Type\": \"application/json\",\n{% if billTo %}\n \"X-HF-Bill-To\": \"{{ billTo }}\",\n{% endif %} },\n method: \"POST\",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({\n inputs: {{ providerInputs.asObj.inputs }},\n parameters: { candidate_labels: [\"refund\", \"legal\", \"faq\"] }\n}).then((response) => {\n console.log(JSON.stringify(response));\n});"
|
|
11
12
|
},
|
|
12
13
|
"huggingface.js": {
|
|
@@ -16,7 +17,8 @@ export const templates: Record<string, Record<string, Record<string, string>>> =
|
|
|
16
17
|
"conversational": "import { InferenceClient } from \"@huggingface/inference\";\n\nconst client = new InferenceClient(\"{{ accessToken }}\");\n\nconst chatCompletion = await client.chatCompletion({\n provider: \"{{ provider }}\",\n model: \"{{ model.id }}\",\n{{ inputs.asTsString }}\n}{% if billTo %}, {\n billTo: \"{{ billTo }}\",\n}{% endif %});\n\nconsole.log(chatCompletion.choices[0].message);",
|
|
17
18
|
"conversationalStream": "import { InferenceClient } from \"@huggingface/inference\";\n\nconst client = new InferenceClient(\"{{ accessToken }}\");\n\nlet out = \"\";\n\nconst stream = client.chatCompletionStream({\n provider: \"{{ provider }}\",\n model: \"{{ model.id }}\",\n{{ inputs.asTsString }}\n}{% if billTo %}, {\n billTo: \"{{ billTo }}\",\n}{% endif %});\n\nfor await (const chunk of stream) {\n\tif (chunk.choices && chunk.choices.length > 0) {\n\t\tconst newContent = chunk.choices[0].delta.content;\n\t\tout += newContent;\n\t\tconsole.log(newContent);\n\t}\n}",
|
|
18
19
|
"textToImage": "import { InferenceClient } from \"@huggingface/inference\";\n\nconst client = new InferenceClient(\"{{ accessToken }}\");\n\nconst image = await client.textToImage({\n provider: \"{{ provider }}\",\n model: \"{{ model.id }}\",\n\tinputs: {{ inputs.asObj.inputs }},\n\tparameters: { num_inference_steps: 5 },\n}{% if billTo %}, {\n billTo: \"{{ billTo }}\",\n}{% endif %});\n/// Use the generated image (it's a Blob)",
|
|
19
|
-
"
|
|
20
|
+
"textToSpeech": "import { InferenceClient } from \"@huggingface/inference\";\n\nconst client = new InferenceClient(\"{{ accessToken }}\");\n\nconst audio = await client.textToSpeech({\n provider: \"{{ provider }}\",\n model: \"{{ model.id }}\",\n\tinputs: {{ inputs.asObj.inputs }},\n}{% if billTo %}, {\n billTo: \"{{ billTo }}\",\n}{% endif %});\n// Use the generated audio (it's a Blob)",
|
|
21
|
+
"textToVideo": "import { InferenceClient } from \"@huggingface/inference\";\n\nconst client = new InferenceClient(\"{{ accessToken }}\");\n\nconst video = await client.textToVideo({\n provider: \"{{ provider }}\",\n model: \"{{ model.id }}\",\n\tinputs: {{ inputs.asObj.inputs }},\n}{% if billTo %}, {\n billTo: \"{{ billTo }}\",\n}{% endif %});\n// Use the generated video (it's a Blob)"
|
|
20
22
|
},
|
|
21
23
|
"openai": {
|
|
22
24
|
"conversational": "import { OpenAI } from \"openai\";\n\nconst client = new OpenAI({\n\tbaseURL: \"{{ baseUrl }}\",\n\tapiKey: \"{{ accessToken }}\",\n{% if billTo %}\n\tdefaultHeaders: {\n\t\t\"X-HF-Bill-To\": \"{{ billTo }}\" \n\t}\n{% endif %}\n});\n\nconst chatCompletion = await client.chat.completions.create({\n\tmodel: \"{{ providerModelId }}\",\n{{ inputs.asTsString }}\n});\n\nconsole.log(chatCompletion.choices[0].message);",
|
|
@@ -25,7 +27,7 @@ export const templates: Record<string, Record<string, Record<string, string>>> =
|
|
|
25
27
|
},
|
|
26
28
|
"python": {
|
|
27
29
|
"fal_client": {
|
|
28
|
-
"textToImage": "{% if provider == \"fal-ai\" %}\nimport fal_client\n\nresult = fal_client.subscribe(\n \"{{ providerModelId }}\",\n arguments={\n \"prompt\": {{ inputs.asObj.inputs }},\n },\n)\nprint(result)\n{% endif %} "
|
|
30
|
+
"textToImage": "{% if provider == \"fal-ai\" %}\nimport fal_client\n\n{% if providerInputs.asObj.loras is defined and providerInputs.asObj.loras != none %}\nresult = fal_client.subscribe(\n \"{{ providerModelId }}\",\n arguments={\n \"prompt\": {{ inputs.asObj.inputs }},\n \"loras\":{{ providerInputs.asObj.loras | tojson }},\n },\n)\n{% else %}\nresult = fal_client.subscribe(\n \"{{ providerModelId }}\",\n arguments={\n \"prompt\": {{ inputs.asObj.inputs }},\n },\n)\n{% endif %} \nprint(result)\n{% endif %} "
|
|
29
31
|
},
|
|
30
32
|
"huggingface_hub": {
|
|
31
33
|
"basic": "result = client.{{ methodName }}(\n inputs={{ inputs.asObj.inputs }},\n model=\"{{ model.id }}\",\n)",
|
|
@@ -37,6 +39,7 @@ export const templates: Record<string, Record<string, Record<string, string>>> =
|
|
|
37
39
|
"imageToImage": "# output is a PIL.Image object\nimage = client.image_to_image(\n \"{{ inputs.asObj.inputs }}\",\n prompt=\"{{ inputs.asObj.parameters.prompt }}\",\n model=\"{{ model.id }}\",\n) ",
|
|
38
40
|
"importInferenceClient": "from huggingface_hub import InferenceClient\n\nclient = InferenceClient(\n provider=\"{{ provider }}\",\n api_key=\"{{ accessToken }}\",\n{% if billTo %}\n bill_to=\"{{ billTo }}\",\n{% endif %}\n)",
|
|
39
41
|
"textToImage": "# output is a PIL.Image object\nimage = client.text_to_image(\n {{ inputs.asObj.inputs }},\n model=\"{{ model.id }}\",\n) ",
|
|
42
|
+
"textToSpeech": "# audio is returned as bytes\naudio = client.text_to_speech(\n {{ inputs.asObj.inputs }},\n model=\"{{ model.id }}\",\n) \n",
|
|
40
43
|
"textToVideo": "video = client.text_to_video(\n {{ inputs.asObj.inputs }},\n model=\"{{ model.id }}\",\n) "
|
|
41
44
|
},
|
|
42
45
|
"openai": {
|
|
@@ -53,8 +56,9 @@ export const templates: Record<string, Record<string, Record<string, string>>> =
|
|
|
53
56
|
"imageToImage": "def query(payload):\n with open(payload[\"inputs\"], \"rb\") as f:\n img = f.read()\n payload[\"inputs\"] = base64.b64encode(img).decode(\"utf-8\")\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nimage_bytes = query({\n{{ providerInputs.asJsonString }}\n})\n\n# You can access the image with PIL.Image for example\nimport io\nfrom PIL import Image\nimage = Image.open(io.BytesIO(image_bytes)) ",
|
|
54
57
|
"importRequests": "{% if importBase64 %}\nimport base64\n{% endif %}\n{% if importJson %}\nimport json\n{% endif %}\nimport requests\n\nAPI_URL = \"{{ fullUrl }}\"\nheaders = {\n \"Authorization\": \"{{ authorizationHeader }}\",\n{% if billTo %}\n \"X-HF-Bill-To\": \"{{ billTo }}\"\n{% endif %}\n}",
|
|
55
58
|
"tabular": "def query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nresponse = query({\n \"inputs\": {\n \"data\": {{ providerInputs.asObj.inputs }}\n },\n}) ",
|
|
56
|
-
"textToAudio": "{% if model.library_name == \"transformers\" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\naudio_bytes = query({\n \"inputs\": {{
|
|
59
|
+
"textToAudio": "{% if model.library_name == \"transformers\" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\naudio_bytes = query({\n \"inputs\": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio_bytes)\n{% else %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\naudio, sampling_rate = query({\n \"inputs\": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio, rate=sampling_rate)\n{% endif %} ",
|
|
57
60
|
"textToImage": "{% if provider == \"hf-inference\" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nimage_bytes = query({\n \"inputs\": {{ providerInputs.asObj.inputs }},\n})\n\n# You can access the image with PIL.Image for example\nimport io\nfrom PIL import Image\nimage = Image.open(io.BytesIO(image_bytes))\n{% endif %}",
|
|
61
|
+
"textToSpeech": "{% if model.library_name == \"transformers\" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\naudio_bytes = query({\n \"text\": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio_bytes)\n{% else %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\naudio, sampling_rate = query({\n \"text\": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio, rate=sampling_rate)\n{% endif %} ",
|
|
58
62
|
"zeroShotClassification": "def query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\noutput = query({\n \"inputs\": {{ providerInputs.asObj.inputs }},\n \"parameters\": {\"candidate_labels\": [\"refund\", \"legal\", \"faq\"]},\n}) ",
|
|
59
63
|
"zeroShotImageClassification": "def query(data):\n with open(data[\"image_path\"], \"rb\") as f:\n img = f.read()\n payload={\n \"parameters\": data[\"parameters\"],\n \"inputs\": base64.b64encode(img).decode(\"utf-8\")\n }\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\noutput = query({\n \"image_path\": {{ providerInputs.asObj.inputs }},\n \"parameters\": {\"candidate_labels\": [\"cat\", \"dog\", \"llama\"]},\n}) "
|
|
60
64
|
}
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import type { AudioClassificationInput, AudioClassificationOutput } from "@huggingface/tasks";
|
|
2
|
+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
|
|
2
3
|
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
4
|
import type { BaseArgs, Options } from "../../types";
|
|
4
5
|
import { innerRequest } from "../../utils/request";
|
|
@@ -15,7 +16,8 @@ export async function audioClassification(
|
|
|
15
16
|
args: AudioClassificationArgs,
|
|
16
17
|
options?: Options
|
|
17
18
|
): Promise<AudioClassificationOutput> {
|
|
18
|
-
const
|
|
19
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
20
|
+
const providerHelper = getProviderHelper(provider, "audio-classification");
|
|
19
21
|
const payload = preparePayload(args);
|
|
20
22
|
const { data: res } = await innerRequest<AudioClassificationOutput>(payload, providerHelper, {
|
|
21
23
|
...options,
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
|
|
1
2
|
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
2
3
|
import type { BaseArgs, Options } from "../../types";
|
|
3
4
|
import { innerRequest } from "../../utils/request";
|
|
@@ -36,7 +37,9 @@ export interface AudioToAudioOutput {
|
|
|
36
37
|
* Example model: speechbrain/sepformer-wham does audio source separation.
|
|
37
38
|
*/
|
|
38
39
|
export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
|
|
39
|
-
const
|
|
40
|
+
const model = "inputs" in args ? args.model : undefined;
|
|
41
|
+
const provider = await resolveProvider(args.provider, model);
|
|
42
|
+
const providerHelper = getProviderHelper(provider, "audio-to-audio");
|
|
40
43
|
const payload = preparePayload(args);
|
|
41
44
|
const { data: res } = await innerRequest<AudioToAudioOutput>(payload, providerHelper, {
|
|
42
45
|
...options,
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
|
|
2
|
+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
|
|
2
3
|
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
4
|
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
|
4
5
|
import { FAL_AI_SUPPORTED_BLOB_TYPES } from "../../providers/fal-ai";
|
|
@@ -18,7 +19,8 @@ export async function automaticSpeechRecognition(
|
|
|
18
19
|
args: AutomaticSpeechRecognitionArgs,
|
|
19
20
|
options?: Options
|
|
20
21
|
): Promise<AutomaticSpeechRecognitionOutput> {
|
|
21
|
-
const
|
|
22
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
23
|
+
const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
|
|
22
24
|
const payload = await buildPayload(args);
|
|
23
25
|
const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, providerHelper, {
|
|
24
26
|
...options,
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import type { TextToSpeechInput } from "@huggingface/tasks";
|
|
2
|
+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
|
|
2
3
|
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
4
|
import type { BaseArgs, Options } from "../../types";
|
|
4
5
|
import { innerRequest } from "../../utils/request";
|
|
@@ -12,7 +13,7 @@ interface OutputUrlTextToSpeechGeneration {
|
|
|
12
13
|
* Recommended model: espnet/kan-bayashi_ljspeech_vits
|
|
13
14
|
*/
|
|
14
15
|
export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
|
|
15
|
-
const provider = args.provider
|
|
16
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
16
17
|
const providerHelper = getProviderHelper(provider, "text-to-speech");
|
|
17
18
|
const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, providerHelper, {
|
|
18
19
|
...options,
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
|
|
1
2
|
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
2
3
|
import type { InferenceTask, Options, RequestArgs } from "../../types";
|
|
3
4
|
import { innerRequest } from "../../utils/request";
|
|
@@ -16,7 +17,8 @@ export async function request<T>(
|
|
|
16
17
|
console.warn(
|
|
17
18
|
"The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
|
|
18
19
|
);
|
|
19
|
-
const
|
|
20
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
21
|
+
const providerHelper = getProviderHelper(provider, options?.task);
|
|
20
22
|
const result = await innerRequest<T>(args, providerHelper, options);
|
|
21
23
|
return result.data;
|
|
22
24
|
}
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
|
|
1
2
|
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
2
3
|
import type { InferenceTask, Options, RequestArgs } from "../../types";
|
|
3
4
|
import { innerStreamingRequest } from "../../utils/request";
|
|
@@ -16,6 +17,7 @@ export async function* streamingRequest<T>(
|
|
|
16
17
|
console.warn(
|
|
17
18
|
"The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
|
|
18
19
|
);
|
|
19
|
-
const
|
|
20
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
21
|
+
const providerHelper = getProviderHelper(provider, options?.task);
|
|
20
22
|
yield* innerStreamingRequest(args, providerHelper, options);
|
|
21
23
|
}
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import type { ImageClassificationInput, ImageClassificationOutput } from "@huggingface/tasks";
|
|
2
|
+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
|
|
2
3
|
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
4
|
import type { BaseArgs, Options } from "../../types";
|
|
4
5
|
import { innerRequest } from "../../utils/request";
|
|
@@ -14,7 +15,8 @@ export async function imageClassification(
|
|
|
14
15
|
args: ImageClassificationArgs,
|
|
15
16
|
options?: Options
|
|
16
17
|
): Promise<ImageClassificationOutput> {
|
|
17
|
-
const
|
|
18
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
19
|
+
const providerHelper = getProviderHelper(provider, "image-classification");
|
|
18
20
|
const payload = preparePayload(args);
|
|
19
21
|
const { data: res } = await innerRequest<ImageClassificationOutput>(payload, providerHelper, {
|
|
20
22
|
...options,
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import type { ImageSegmentationInput, ImageSegmentationOutput } from "@huggingface/tasks";
|
|
2
|
+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
|
|
2
3
|
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
4
|
import type { BaseArgs, Options } from "../../types";
|
|
4
5
|
import { innerRequest } from "../../utils/request";
|
|
@@ -14,7 +15,8 @@ export async function imageSegmentation(
|
|
|
14
15
|
args: ImageSegmentationArgs,
|
|
15
16
|
options?: Options
|
|
16
17
|
): Promise<ImageSegmentationOutput> {
|
|
17
|
-
const
|
|
18
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
19
|
+
const providerHelper = getProviderHelper(provider, "image-segmentation");
|
|
18
20
|
const payload = preparePayload(args);
|
|
19
21
|
const { data: res } = await innerRequest<ImageSegmentationOutput>(payload, providerHelper, {
|
|
20
22
|
...options,
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import type { ImageToImageInput } from "@huggingface/tasks";
|
|
2
|
+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
|
|
2
3
|
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
4
|
import type { BaseArgs, Options, RequestArgs } from "../../types";
|
|
4
5
|
import { base64FromBytes } from "../../utils/base64FromBytes";
|
|
@@ -11,7 +12,8 @@ export type ImageToImageArgs = BaseArgs & ImageToImageInput;
|
|
|
11
12
|
* Recommended model: lllyasviel/sd-controlnet-depth
|
|
12
13
|
*/
|
|
13
14
|
export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
|
|
14
|
-
const
|
|
15
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
16
|
+
const providerHelper = getProviderHelper(provider, "image-to-image");
|
|
15
17
|
let reqArgs: RequestArgs;
|
|
16
18
|
if (!args.parameters) {
|
|
17
19
|
reqArgs = {
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import type { ImageToTextInput, ImageToTextOutput } from "@huggingface/tasks";
|
|
2
|
+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
|
|
2
3
|
import { getProviderHelper } from "../../lib/getProviderHelper";
|
|
3
4
|
import type { BaseArgs, Options } from "../../types";
|
|
4
5
|
import { innerRequest } from "../../utils/request";
|
|
@@ -10,7 +11,8 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput);
|
|
|
10
11
|
* This task reads some image input and outputs the text caption.
|
|
11
12
|
*/
|
|
12
13
|
export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
|
|
13
|
-
const
|
|
14
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
15
|
+
const providerHelper = getProviderHelper(provider, "image-to-text");
|
|
14
16
|
const payload = preparePayload(args);
|
|
15
17
|
const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, providerHelper, {
|
|
16
18
|
...options,
|