@huggingface/inference 2.8.0 → 3.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +1 -1
- package/README.md +39 -16
- package/dist/index.cjs +364 -134
- package/dist/index.js +359 -134
- package/dist/src/config.d.ts +3 -0
- package/dist/src/config.d.ts.map +1 -0
- package/dist/src/index.d.ts +5 -0
- package/dist/src/index.d.ts.map +1 -1
- package/dist/src/lib/getDefaultTask.d.ts +0 -1
- package/dist/src/lib/getDefaultTask.d.ts.map +1 -1
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/fal-ai.d.ts +6 -0
- package/dist/src/providers/fal-ai.d.ts.map +1 -0
- package/dist/src/providers/replicate.d.ts +6 -0
- package/dist/src/providers/replicate.d.ts.map +1 -0
- package/dist/src/providers/sambanova.d.ts +6 -0
- package/dist/src/providers/sambanova.d.ts.map +1 -0
- package/dist/src/providers/together.d.ts +12 -0
- package/dist/src/providers/together.d.ts.map +1 -0
- package/dist/src/providers/types.d.ts +4 -0
- package/dist/src/providers/types.d.ts.map +1 -0
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts +1 -1
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts +8 -0
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/types.d.ts +16 -2
- package/dist/src/types.d.ts.map +1 -1
- package/package.json +2 -2
- package/src/config.ts +2 -0
- package/src/index.ts +5 -0
- package/src/lib/getDefaultTask.ts +1 -1
- package/src/lib/makeRequestOptions.ts +199 -59
- package/src/providers/fal-ai.ts +15 -0
- package/src/providers/replicate.ts +16 -0
- package/src/providers/sambanova.ts +23 -0
- package/src/providers/together.ts +58 -0
- package/src/providers/types.ts +6 -0
- package/src/tasks/audio/automaticSpeechRecognition.ts +10 -1
- package/src/tasks/custom/request.ts +12 -6
- package/src/tasks/custom/streamingRequest.ts +18 -3
- package/src/tasks/cv/textToImage.ts +44 -1
- package/src/tasks/nlp/chatCompletion.ts +2 -2
- package/src/tasks/nlp/textGeneration.ts +43 -9
- package/src/types.ts +20 -2
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
import type { ProviderMapping } from "./types";
|
|
2
|
+
|
|
3
|
+
/**
 * Identifier of a model as it is named on the Together AI platform,
 * e.g. "meta-llama/Llama-3.3-70B-Instruct-Turbo".
 *
 * Purely documentary: TypeScript treats it as a plain `string`, but using the alias in
 * signatures makes it clear which side (Hugging Face vs. Together) an id belongs to.
 */
type TogetherId = string;

/** Base URL of the Together AI API. */
export const TOGETHER_API_BASE_URL = "https://api.together.xyz";
|
|
9
|
+
|
|
10
|
+
/**
 * Maps Hugging Face model ids (keys) to the Together AI model ids (values) that serve them, grouped by task.
 *
 * Invariant: every value must identify the *same* model as its key; only the naming differs
 * between the two platforms (for example, several Together ids carry a "-Turbo" suffix).
 * Pointing a key at a different model would silently route users to a model they did not ask for.
 *
 * List of Together model ids: https://docs.together.ai/reference/models-1
 */
export const TOGETHER_SUPPORTED_MODEL_IDS: ProviderMapping<TogetherId> = {
	"text-to-image": {
		"black-forest-labs/FLUX.1-Canny-dev": "black-forest-labs/FLUX.1-canny",
		"black-forest-labs/FLUX.1-Depth-dev": "black-forest-labs/FLUX.1-depth",
		"black-forest-labs/FLUX.1-dev": "black-forest-labs/FLUX.1-dev",
		"black-forest-labs/FLUX.1-Redux-dev": "black-forest-labs/FLUX.1-redux",
		// Must stay on schnell: FLUX.1-pro is a different model.
		"black-forest-labs/FLUX.1-schnell": "black-forest-labs/FLUX.1-schnell",
		"stabilityai/stable-diffusion-xl-base-1.0": "stabilityai/stable-diffusion-xl-base-1.0",
	},
	conversational: {
		"databricks/dbrx-instruct": "databricks/dbrx-instruct",
		"deepseek-ai/deepseek-llm-67b-chat": "deepseek-ai/deepseek-llm-67b-chat",
		"google/gemma-2-9b-it": "google/gemma-2-9b-it",
		// Gemma 1 2B, not the Gemma 2 27B model.
		"google/gemma-2b-it": "google/gemma-2b-it",
		"llava-hf/llava-v1.6-mistral-7b-hf": "llava-hf/llava-v1.6-mistral-7b-hf",
		"meta-llama/Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
		"meta-llama/Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
		"meta-llama/Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
		// Together serves this model under the "Llama-Vision-Free" id.
		"meta-llama/Llama-3.2-11B-Vision-Instruct": "meta-llama/Llama-Vision-Free",
		"meta-llama/Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
		"meta-llama/Llama-3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
		"meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
		"meta-llama/Meta-Llama-3-70B-Instruct": "meta-llama/Llama-3-70b-chat-hf",
		"meta-llama/Meta-Llama-3-8B-Instruct": "togethercomputer/Llama-3-8b-chat-hf-int4",
		// The 405B text model, not the 11B vision model.
		"meta-llama/Meta-Llama-3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
		"meta-llama/Meta-Llama-3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
		"meta-llama/Meta-Llama-3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K",
		"microsoft/WizardLM-2-8x22B": "microsoft/WizardLM-2-8x22B",
		"mistralai/Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
		"mistralai/Mixtral-8x22B-Instruct-v0.1": "mistralai/Mixtral-8x22B-Instruct-v0.1",
		"mistralai/Mixtral-8x7B-Instruct-v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1",
		"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
		"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
		"Qwen/Qwen2-72B-Instruct": "Qwen/Qwen2-72B-Instruct",
		"Qwen/Qwen2.5-72B-Instruct": "Qwen/Qwen2.5-72B-Instruct-Turbo",
		"Qwen/Qwen2.5-7B-Instruct": "Qwen/Qwen2.5-7B-Instruct-Turbo",
		"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen/Qwen2.5-Coder-32B-Instruct",
		"Qwen/QwQ-32B-Preview": "Qwen/QwQ-32B-Preview",
		"scb10x/llama-3-typhoon-v1.5-8b-instruct": "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct",
		"scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": "scb10x/scb10x-llama3-typhoon-v1-5x-4f316",
	},
	"text-generation": {
		"meta-llama/Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B",
		"mistralai/Mixtral-8x7B-v0.1": "mistralai/Mixtral-8x7B-v0.1",
	},
};
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
|
2
|
-
import type { BaseArgs, Options } from "../../types";
|
|
2
|
+
import type { BaseArgs, Options, RequestArgs } from "../../types";
|
|
3
|
+
import { base64FromBytes } from "../../utils/base64FromBytes";
|
|
3
4
|
import { request } from "../custom/request";
|
|
4
5
|
|
|
5
6
|
export type AutomaticSpeechRecognitionArgs = BaseArgs & {
|
|
@@ -24,6 +25,14 @@ export async function automaticSpeechRecognition(
|
|
|
24
25
|
args: AutomaticSpeechRecognitionArgs,
|
|
25
26
|
options?: Options
|
|
26
27
|
): Promise<AutomaticSpeechRecognitionOutput> {
|
|
28
|
+
if (args.provider === "fal-ai") {
|
|
29
|
+
const contentType = args.data instanceof Blob ? args.data.type : "audio/mpeg";
|
|
30
|
+
const base64audio = base64FromBytes(
|
|
31
|
+
new Uint8Array(args.data instanceof ArrayBuffer ? args.data : await args.data.arrayBuffer())
|
|
32
|
+
);
|
|
33
|
+
(args as RequestArgs & { audio_url: string }).audio_url = `data:${contentType};base64,${base64audio}`;
|
|
34
|
+
delete (args as RequestArgs & { data: unknown }).data;
|
|
35
|
+
}
|
|
27
36
|
const res = await request<AutomaticSpeechRecognitionOutput>(args, {
|
|
28
37
|
...options,
|
|
29
38
|
taskHint: "automatic-speech-recognition",
|
|
@@ -2,7 +2,7 @@ import type { InferenceTask, Options, RequestArgs } from "../../types";
|
|
|
2
2
|
import { makeRequestOptions } from "../../lib/makeRequestOptions";
|
|
3
3
|
|
|
4
4
|
/**
|
|
5
|
-
* Primitive to make custom calls to
|
|
5
|
+
* Primitive to make custom calls to the inference provider
|
|
6
6
|
*/
|
|
7
7
|
export async function request<T>(
|
|
8
8
|
args: RequestArgs,
|
|
@@ -26,16 +26,22 @@ export async function request<T>(
|
|
|
26
26
|
}
|
|
27
27
|
|
|
28
28
|
if (!response.ok) {
|
|
29
|
-
|
|
29
|
+
const contentType = response.headers.get("Content-Type");
|
|
30
|
+
if (["application/json", "application/problem+json"].some((ct) => contentType?.startsWith(ct))) {
|
|
30
31
|
const output = await response.json();
|
|
31
32
|
if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
|
|
32
|
-
throw new Error(
|
|
33
|
+
throw new Error(
|
|
34
|
+
`Server ${args.model} does not seem to support chat completion. Error: ${JSON.stringify(output.error)}`
|
|
35
|
+
);
|
|
33
36
|
}
|
|
34
|
-
if (output.error) {
|
|
35
|
-
throw new Error(output.error);
|
|
37
|
+
if (output.error || output.detail) {
|
|
38
|
+
throw new Error(JSON.stringify(output.error ?? output.detail));
|
|
39
|
+
} else {
|
|
40
|
+
throw new Error(output);
|
|
36
41
|
}
|
|
37
42
|
}
|
|
38
|
-
|
|
43
|
+
const message = contentType?.startsWith("text/plain;") ? await response.text() : undefined;
|
|
44
|
+
throw new Error(message ?? "An error occurred while fetching the blob");
|
|
39
45
|
}
|
|
40
46
|
|
|
41
47
|
if (response.headers.get("Content-Type")?.startsWith("application/json")) {
|
|
@@ -32,9 +32,13 @@ export async function* streamingRequest<T>(
|
|
|
32
32
|
if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
|
|
33
33
|
throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
|
|
34
34
|
}
|
|
35
|
-
if (output.error) {
|
|
35
|
+
if (typeof output.error === "string") {
|
|
36
36
|
throw new Error(output.error);
|
|
37
37
|
}
|
|
38
|
+
if (output.error && "message" in output.error && typeof output.error.message === "string") {
|
|
39
|
+
/// OpenAI errors
|
|
40
|
+
throw new Error(output.error.message);
|
|
41
|
+
}
|
|
38
42
|
}
|
|
39
43
|
|
|
40
44
|
throw new Error(`Server response contains error: ${response.status}`);
|
|
@@ -68,7 +72,9 @@ export async function* streamingRequest<T>(
|
|
|
68
72
|
try {
|
|
69
73
|
while (true) {
|
|
70
74
|
const { done, value } = await reader.read();
|
|
71
|
-
if (done)
|
|
75
|
+
if (done) {
|
|
76
|
+
return;
|
|
77
|
+
}
|
|
72
78
|
onChunk(value);
|
|
73
79
|
for (const event of events) {
|
|
74
80
|
if (event.data.length > 0) {
|
|
@@ -77,7 +83,16 @@ export async function* streamingRequest<T>(
|
|
|
77
83
|
}
|
|
78
84
|
const data = JSON.parse(event.data);
|
|
79
85
|
if (typeof data === "object" && data !== null && "error" in data) {
|
|
80
|
-
|
|
86
|
+
const errorStr =
|
|
87
|
+
typeof data.error === "string"
|
|
88
|
+
? data.error
|
|
89
|
+
: typeof data.error === "object" &&
|
|
90
|
+
data.error &&
|
|
91
|
+
"message" in data.error &&
|
|
92
|
+
typeof data.error.message === "string"
|
|
93
|
+
? data.error.message
|
|
94
|
+
: JSON.stringify(data.error);
|
|
95
|
+
throw new Error(`Error forwarded from backend: ` + errorStr);
|
|
81
96
|
}
|
|
82
97
|
yield data as T;
|
|
83
98
|
}
|
|
@@ -8,6 +8,15 @@ export type TextToImageArgs = BaseArgs & {
|
|
|
8
8
|
*/
|
|
9
9
|
inputs: string;
|
|
10
10
|
|
|
11
|
+
/**
|
|
12
|
+
* Same param but for external providers like Together, Replicate
|
|
13
|
+
*/
|
|
14
|
+
prompt?: string;
|
|
15
|
+
response_format?: "base64";
|
|
16
|
+
input?: {
|
|
17
|
+
prompt: string;
|
|
18
|
+
};
|
|
19
|
+
|
|
11
20
|
parameters?: {
|
|
12
21
|
/**
|
|
13
22
|
* An optional negative prompt for the image generation
|
|
@@ -34,15 +43,49 @@ export type TextToImageArgs = BaseArgs & {
|
|
|
34
43
|
|
|
35
44
|
export type TextToImageOutput = Blob;
|
|
36
45
|
|
|
46
|
+
/**
 * Provider response that embeds generated images as base64 strings.
 * `textToImage` decodes the first entry, `data[0].b64_json`.
 */
interface Base64ImageGeneration {
	data: { b64_json: string }[];
}

/**
 * Provider response that lists URLs of generated images under `output`.
 * `textToImage` downloads the first URL.
 */
interface OutputUrlImageGeneration {
	output: string[];
}
|
|
54
|
+
|
|
37
55
|
/**
|
|
38
56
|
* This task reads some text input and outputs an image.
|
|
39
57
|
* Recommended model: stabilityai/stable-diffusion-2
|
|
40
58
|
*/
|
|
41
59
|
export async function textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageOutput> {
|
|
42
|
-
|
|
60
|
+
if (args.provider === "together" || args.provider === "fal-ai") {
|
|
61
|
+
args.prompt = args.inputs;
|
|
62
|
+
args.inputs = "";
|
|
63
|
+
args.response_format = "base64";
|
|
64
|
+
} else if (args.provider === "replicate") {
|
|
65
|
+
args.input = { prompt: args.inputs };
|
|
66
|
+
delete (args as unknown as { inputs: unknown }).inputs;
|
|
67
|
+
}
|
|
68
|
+
const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration>(args, {
|
|
43
69
|
...options,
|
|
44
70
|
taskHint: "text-to-image",
|
|
45
71
|
});
|
|
72
|
+
if (res && typeof res === "object") {
|
|
73
|
+
if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
|
|
74
|
+
const image = await fetch(res.images[0].url);
|
|
75
|
+
return await image.blob();
|
|
76
|
+
}
|
|
77
|
+
if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
|
|
78
|
+
const base64Data = res.data[0].b64_json;
|
|
79
|
+
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
|
|
80
|
+
const blob = await base64Response.blob();
|
|
81
|
+
return blob;
|
|
82
|
+
}
|
|
83
|
+
if ("output" in res && Array.isArray(res.output)) {
|
|
84
|
+
const urlResponse = await fetch(res.output[0]);
|
|
85
|
+
const blob = await urlResponse.blob();
|
|
86
|
+
return blob;
|
|
87
|
+
}
|
|
88
|
+
}
|
|
46
89
|
const isValidOutput = res && res instanceof Blob;
|
|
47
90
|
if (!isValidOutput) {
|
|
48
91
|
throw new InferenceOutputError("Expected Blob");
|
|
@@ -6,7 +6,6 @@ import type { ChatCompletionInput, ChatCompletionOutput } from "@huggingface/tas
|
|
|
6
6
|
/**
|
|
7
7
|
* Use the chat completion endpoint to generate a response to a prompt, using OpenAI message completion API no stream
|
|
8
8
|
*/
|
|
9
|
-
|
|
10
9
|
export async function chatCompletion(
|
|
11
10
|
args: BaseArgs & ChatCompletionInput,
|
|
12
11
|
options?: Options
|
|
@@ -22,7 +21,8 @@ export async function chatCompletion(
|
|
|
22
21
|
typeof res?.created === "number" &&
|
|
23
22
|
typeof res?.id === "string" &&
|
|
24
23
|
typeof res?.model === "string" &&
|
|
25
|
-
|
|
24
|
+
/// Together.ai does not output a system_fingerprint
|
|
25
|
+
(res.system_fingerprint === undefined || typeof res.system_fingerprint === "string") &&
|
|
26
26
|
typeof res?.usage === "object";
|
|
27
27
|
|
|
28
28
|
if (!isValidOutput) {
|
|
@@ -1,4 +1,9 @@
|
|
|
1
|
-
import type {
|
|
1
|
+
import type {
|
|
2
|
+
ChatCompletionOutput,
|
|
3
|
+
TextGenerationInput,
|
|
4
|
+
TextGenerationOutput,
|
|
5
|
+
TextGenerationOutputFinishReason,
|
|
6
|
+
} from "@huggingface/tasks";
|
|
2
7
|
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
|
3
8
|
import type { BaseArgs, Options } from "../../types";
|
|
4
9
|
import { toArray } from "../../utils/toArray";
|
|
@@ -6,6 +11,16 @@ import { request } from "../custom/request";
|
|
|
6
11
|
|
|
7
12
|
export type { TextGenerationInput, TextGenerationOutput };
|
|
8
13
|
|
|
14
|
+
/**
 * Raw response of Together's text-completion API for the `text-generation` task.
 *
 * It reuses every field of the chat-completion response except `choices`. In a text
 * completion, each choice carries plain generated `text` instead of a chat message.
 *
 * NOTE(review): the name is misspelled ("Togeteher"). It is kept as-is because
 * `textGeneration` in this file refers to it by this exact name.
 */
interface TogeteherTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
	choices: {
		/** Position of this choice in the `choices` array. */
		index: number;
		/** Text produced for the prompt; `textGeneration` returns it as `generated_text`. */
		text: string;
		finish_reason: TextGenerationOutputFinishReason;
		seed: number;
		/** Shape not modelled here. */
		logprobs: unknown;
	}[];
}
|
|
23
|
+
|
|
9
24
|
/**
|
|
10
25
|
* Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
|
|
11
26
|
*/
|
|
@@ -13,15 +28,34 @@ export async function textGeneration(
|
|
|
13
28
|
args: BaseArgs & TextGenerationInput,
|
|
14
29
|
options?: Options
|
|
15
30
|
): Promise<TextGenerationOutput> {
|
|
16
|
-
|
|
17
|
-
|
|
31
|
+
if (args.provider === "together") {
|
|
32
|
+
args.prompt = args.inputs;
|
|
33
|
+
const raw = await request<TogeteherTextCompletionOutput>(args, {
|
|
18
34
|
...options,
|
|
19
35
|
taskHint: "text-generation",
|
|
20
|
-
})
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
36
|
+
});
|
|
37
|
+
const isValidOutput =
|
|
38
|
+
typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
|
|
39
|
+
if (!isValidOutput) {
|
|
40
|
+
throw new InferenceOutputError("Expected ChatCompletionOutput");
|
|
41
|
+
}
|
|
42
|
+
const completion = raw.choices[0];
|
|
43
|
+
return {
|
|
44
|
+
generated_text: completion.text,
|
|
45
|
+
};
|
|
46
|
+
} else {
|
|
47
|
+
const res = toArray(
|
|
48
|
+
await request<TextGenerationOutput | TextGenerationOutput[]>(args, {
|
|
49
|
+
...options,
|
|
50
|
+
taskHint: "text-generation",
|
|
51
|
+
})
|
|
52
|
+
);
|
|
53
|
+
|
|
54
|
+
const isValidOutput =
|
|
55
|
+
Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
|
|
56
|
+
if (!isValidOutput) {
|
|
57
|
+
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
|
|
58
|
+
}
|
|
59
|
+
return (res as TextGenerationOutput[])?.[0];
|
|
25
60
|
}
|
|
26
|
-
return res?.[0];
|
|
27
61
|
}
|
package/src/types.ts
CHANGED
|
@@ -1,6 +1,11 @@
|
|
|
1
1
|
import type { PipelineType } from "@huggingface/tasks";
|
|
2
2
|
import type { ChatCompletionInput } from "@huggingface/tasks";
|
|
3
3
|
|
|
4
|
+
/**
 * Hugging Face Hub model id in "owner/name" form, e.g. "meta-llama/Llama-3.3-70B-Instruct".
 *
 * A plain `string` alias, kept so that signatures show which ids are Hub ids
 * (as opposed to ids used by an external inference provider).
 */
export type ModelId = string;
|
|
8
|
+
|
|
4
9
|
export interface Options {
|
|
5
10
|
/**
|
|
6
11
|
* (Default: true) Boolean. If a request 503s and wait_for_model is set to false, the request will be retried with the same parameters but with wait_for_model set to true.
|
|
@@ -40,22 +45,28 @@ export interface Options {
|
|
|
40
45
|
|
|
41
46
|
export type InferenceTask = Exclude<PipelineType, "other">;
|
|
42
47
|
|
|
48
|
+
/**
 * Every inference provider a request can be routed to through `BaseArgs.provider`.
 *
 * Declared `as const` so that each element keeps its string-literal type. That lets
 * the `InferenceProvider` union below be derived from this list instead of being
 * maintained separately.
 */
export const INFERENCE_PROVIDERS = [
	"fal-ai",
	"replicate",
	"sambanova",
	"together",
	"hf-inference",
] as const;

/** One of the provider names listed in `INFERENCE_PROVIDERS`. */
export type InferenceProvider = typeof INFERENCE_PROVIDERS[number];
|
|
50
|
+
|
|
43
51
|
export interface BaseArgs {
|
|
44
52
|
/**
|
|
45
53
|
* The access token to use. Without it, you'll get rate-limited quickly.
|
|
46
54
|
*
|
|
47
55
|
* Can be created for free in hf.co/settings/token
|
|
56
|
+
*
|
|
57
|
+
* You can also pass an external Inference provider's key if you intend to call a compatible provider like Sambanova, Together, Replicate...
|
|
48
58
|
*/
|
|
49
59
|
accessToken?: string;
|
|
60
|
+
|
|
50
61
|
/**
|
|
51
|
-
* The model to use.
|
|
62
|
+
* The HF model to use.
|
|
52
63
|
*
|
|
53
64
|
* If not specified, will call huggingface.co/api/tasks to get the default model for the task.
|
|
54
65
|
*
|
|
55
66
|
* /!\ Legacy behavior allows this to be an URL, but this is deprecated and will be removed in the future.
|
|
56
67
|
* Use the `endpointUrl` parameter instead.
|
|
57
68
|
*/
|
|
58
|
-
model?:
|
|
69
|
+
model?: ModelId;
|
|
59
70
|
|
|
60
71
|
/**
|
|
61
72
|
* The URL of the endpoint to use. If not specified, will call huggingface.co/api/tasks to get the default endpoint for the task.
|
|
@@ -63,6 +74,13 @@ export interface BaseArgs {
|
|
|
63
74
|
* If specified, will use this URL instead of the default one.
|
|
64
75
|
*/
|
|
65
76
|
endpointUrl?: string;
|
|
77
|
+
|
|
78
|
+
/**
|
|
79
|
+
* Set an Inference provider to run this model on.
|
|
80
|
+
*
|
|
81
|
+
* Defaults to the first provider in your user settings that is compatible with this model.
|
|
82
|
+
*/
|
|
83
|
+
provider?: InferenceProvider;
|
|
66
84
|
}
|
|
67
85
|
|
|
68
86
|
export type RequestArgs = BaseArgs &
|