@huggingface/inference 2.8.0 → 3.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +1 -1
- package/README.md +39 -16
- package/dist/index.cjs +364 -134
- package/dist/index.js +359 -134
- package/dist/src/config.d.ts +3 -0
- package/dist/src/config.d.ts.map +1 -0
- package/dist/src/index.d.ts +5 -0
- package/dist/src/index.d.ts.map +1 -1
- package/dist/src/lib/getDefaultTask.d.ts +0 -1
- package/dist/src/lib/getDefaultTask.d.ts.map +1 -1
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/fal-ai.d.ts +6 -0
- package/dist/src/providers/fal-ai.d.ts.map +1 -0
- package/dist/src/providers/replicate.d.ts +6 -0
- package/dist/src/providers/replicate.d.ts.map +1 -0
- package/dist/src/providers/sambanova.d.ts +6 -0
- package/dist/src/providers/sambanova.d.ts.map +1 -0
- package/dist/src/providers/together.d.ts +12 -0
- package/dist/src/providers/together.d.ts.map +1 -0
- package/dist/src/providers/types.d.ts +4 -0
- package/dist/src/providers/types.d.ts.map +1 -0
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts +1 -1
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts +8 -0
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/types.d.ts +16 -2
- package/dist/src/types.d.ts.map +1 -1
- package/package.json +2 -2
- package/src/config.ts +2 -0
- package/src/index.ts +5 -0
- package/src/lib/getDefaultTask.ts +1 -1
- package/src/lib/makeRequestOptions.ts +199 -59
- package/src/providers/fal-ai.ts +15 -0
- package/src/providers/replicate.ts +16 -0
- package/src/providers/sambanova.ts +23 -0
- package/src/providers/together.ts +58 -0
- package/src/providers/types.ts +6 -0
- package/src/tasks/audio/automaticSpeechRecognition.ts +10 -1
- package/src/tasks/custom/request.ts +12 -6
- package/src/tasks/custom/streamingRequest.ts +18 -3
- package/src/tasks/cv/textToImage.ts +44 -1
- package/src/tasks/nlp/chatCompletion.ts +2 -2
- package/src/tasks/nlp/textGeneration.ts +43 -9
- package/src/types.ts +20 -2
package/dist/index.js
CHANGED
|
@@ -40,131 +40,164 @@ __export(tasks_exports, {
|
|
|
40
40
|
zeroShotImageClassification: () => zeroShotImageClassification
|
|
41
41
|
});
|
|
42
42
|
|
|
43
|
-
// src/
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
{},
|
|
47
|
-
...props.map((prop) => {
|
|
48
|
-
if (o[prop] !== void 0) {
|
|
49
|
-
return { [prop]: o[prop] };
|
|
50
|
-
}
|
|
51
|
-
})
|
|
52
|
-
);
|
|
53
|
-
}
|
|
43
|
+
// src/config.ts
|
|
44
|
+
var HF_HUB_URL = "https://huggingface.co";
|
|
45
|
+
var HF_INFERENCE_API_URL = "https://api-inference.huggingface.co";
|
|
54
46
|
|
|
55
|
-
// src/
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
47
|
+
// src/providers/fal-ai.ts
|
|
48
|
+
var FAL_AI_API_BASE_URL = "https://fal.run";
|
|
49
|
+
var FAL_AI_SUPPORTED_MODEL_IDS = {
|
|
50
|
+
"text-to-image": {
|
|
51
|
+
"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
|
|
52
|
+
"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev"
|
|
53
|
+
},
|
|
54
|
+
"automatic-speech-recognition": {
|
|
55
|
+
"openai/whisper-large-v3": "fal-ai/whisper"
|
|
56
|
+
}
|
|
57
|
+
};
|
|
59
58
|
|
|
60
|
-
// src/
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
59
|
+
// src/providers/replicate.ts
|
|
60
|
+
var REPLICATE_API_BASE_URL = "https://api.replicate.com";
|
|
61
|
+
var REPLICATE_SUPPORTED_MODEL_IDS = {
|
|
62
|
+
"text-to-image": {
|
|
63
|
+
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
|
|
64
|
+
"ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637"
|
|
65
|
+
}
|
|
66
|
+
// "text-to-speech": {
|
|
67
|
+
// "SWivid/F5-TTS": "x-lance/f5-tts:87faf6dd7a692dd82043f662e76369cab126a2cf1937e25a9d41e0b834fd230e"
|
|
68
|
+
// },
|
|
69
|
+
};
|
|
70
|
+
|
|
71
|
+
// src/providers/sambanova.ts
|
|
72
|
+
var SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
|
|
73
|
+
var SAMBANOVA_SUPPORTED_MODEL_IDS = {
|
|
74
|
+
/** Chat completion / conversational */
|
|
75
|
+
conversational: {
|
|
76
|
+
"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
|
|
77
|
+
"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
|
|
78
|
+
"Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
|
|
79
|
+
"meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct",
|
|
80
|
+
"meta-llama/Llama-3.2-1B-Instruct": "Meta-Llama-3.2-1B-Instruct",
|
|
81
|
+
"meta-llama/Llama-3.2-3B-Instruct": "Meta-Llama-3.2-3B-Instruct",
|
|
82
|
+
"meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
|
|
83
|
+
"meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
|
|
84
|
+
"meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
|
|
85
|
+
"meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
|
|
86
|
+
"meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
|
|
87
|
+
"meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B"
|
|
88
|
+
}
|
|
89
|
+
};
|
|
90
|
+
|
|
91
|
+
// src/providers/together.ts
|
|
92
|
+
var TOGETHER_API_BASE_URL = "https://api.together.xyz";
|
|
93
|
+
var TOGETHER_SUPPORTED_MODEL_IDS = {
|
|
94
|
+
"text-to-image": {
|
|
95
|
+
"black-forest-labs/FLUX.1-Canny-dev": "black-forest-labs/FLUX.1-canny",
|
|
96
|
+
"black-forest-labs/FLUX.1-Depth-dev": "black-forest-labs/FLUX.1-depth",
|
|
97
|
+
"black-forest-labs/FLUX.1-dev": "black-forest-labs/FLUX.1-dev",
|
|
98
|
+
"black-forest-labs/FLUX.1-Redux-dev": "black-forest-labs/FLUX.1-redux",
|
|
99
|
+
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/FLUX.1-pro",
|
|
100
|
+
"stabilityai/stable-diffusion-xl-base-1.0": "stabilityai/stable-diffusion-xl-base-1.0"
|
|
101
|
+
},
|
|
102
|
+
conversational: {
|
|
103
|
+
"databricks/dbrx-instruct": "databricks/dbrx-instruct",
|
|
104
|
+
"deepseek-ai/deepseek-llm-67b-chat": "deepseek-ai/deepseek-llm-67b-chat",
|
|
105
|
+
"google/gemma-2-9b-it": "google/gemma-2-9b-it",
|
|
106
|
+
"google/gemma-2b-it": "google/gemma-2-27b-it",
|
|
107
|
+
"llava-hf/llava-v1.6-mistral-7b-hf": "llava-hf/llava-v1.6-mistral-7b-hf",
|
|
108
|
+
"meta-llama/Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
|
|
109
|
+
"meta-llama/Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
|
|
110
|
+
"meta-llama/Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
|
|
111
|
+
"meta-llama/Llama-3.2-11B-Vision-Instruct": "meta-llama/Llama-Vision-Free",
|
|
112
|
+
"meta-llama/Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
|
|
113
|
+
"meta-llama/Llama-3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
|
|
114
|
+
"meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
|
|
115
|
+
"meta-llama/Meta-Llama-3-70B-Instruct": "meta-llama/Llama-3-70b-chat-hf",
|
|
116
|
+
"meta-llama/Meta-Llama-3-8B-Instruct": "togethercomputer/Llama-3-8b-chat-hf-int4",
|
|
117
|
+
"meta-llama/Meta-Llama-3.1-405B-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
|
|
118
|
+
"meta-llama/Meta-Llama-3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
|
|
119
|
+
"meta-llama/Meta-Llama-3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K",
|
|
120
|
+
"microsoft/WizardLM-2-8x22B": "microsoft/WizardLM-2-8x22B",
|
|
121
|
+
"mistralai/Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
|
|
122
|
+
"mistralai/Mixtral-8x22B-Instruct-v0.1": "mistralai/Mixtral-8x22B-Instruct-v0.1",
|
|
123
|
+
"mistralai/Mixtral-8x7B-Instruct-v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
|
124
|
+
"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
|
|
125
|
+
"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
|
|
126
|
+
"Qwen/Qwen2-72B-Instruct": "Qwen/Qwen2-72B-Instruct",
|
|
127
|
+
"Qwen/Qwen2.5-72B-Instruct": "Qwen/Qwen2.5-72B-Instruct-Turbo",
|
|
128
|
+
"Qwen/Qwen2.5-7B-Instruct": "Qwen/Qwen2.5-7B-Instruct-Turbo",
|
|
129
|
+
"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen/Qwen2.5-Coder-32B-Instruct",
|
|
130
|
+
"Qwen/QwQ-32B-Preview": "Qwen/QwQ-32B-Preview",
|
|
131
|
+
"scb10x/llama-3-typhoon-v1.5-8b-instruct": "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct",
|
|
132
|
+
"scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": "scb10x/scb10x-llama3-typhoon-v1-5x-4f316"
|
|
133
|
+
},
|
|
134
|
+
"text-generation": {
|
|
135
|
+
"meta-llama/Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B",
|
|
136
|
+
"mistralai/Mixtral-8x7B-v0.1": "mistralai/Mixtral-8x7B-v0.1"
|
|
137
|
+
}
|
|
138
|
+
};
|
|
66
139
|
|
|
67
140
|
// src/lib/isUrl.ts
|
|
68
141
|
function isUrl(modelOrUrl) {
|
|
69
142
|
return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
|
|
70
143
|
}
|
|
71
144
|
|
|
72
|
-
// src/lib/getDefaultTask.ts
|
|
73
|
-
var taskCache = /* @__PURE__ */ new Map();
|
|
74
|
-
var CACHE_DURATION = 10 * 60 * 1e3;
|
|
75
|
-
var MAX_CACHE_ITEMS = 1e3;
|
|
76
|
-
var HF_HUB_URL = "https://huggingface.co";
|
|
77
|
-
async function getDefaultTask(model, accessToken, options) {
|
|
78
|
-
if (isUrl(model)) {
|
|
79
|
-
return null;
|
|
80
|
-
}
|
|
81
|
-
const key = `${model}:${accessToken}`;
|
|
82
|
-
let cachedTask = taskCache.get(key);
|
|
83
|
-
if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
|
|
84
|
-
taskCache.delete(key);
|
|
85
|
-
cachedTask = void 0;
|
|
86
|
-
}
|
|
87
|
-
if (cachedTask === void 0) {
|
|
88
|
-
const modelTask = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
|
|
89
|
-
headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
|
|
90
|
-
}).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
|
|
91
|
-
if (!modelTask) {
|
|
92
|
-
return null;
|
|
93
|
-
}
|
|
94
|
-
cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
|
|
95
|
-
taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
|
|
96
|
-
if (taskCache.size > MAX_CACHE_ITEMS) {
|
|
97
|
-
taskCache.delete(taskCache.keys().next().value);
|
|
98
|
-
}
|
|
99
|
-
}
|
|
100
|
-
return cachedTask.task;
|
|
101
|
-
}
|
|
102
|
-
|
|
103
145
|
// src/lib/makeRequestOptions.ts
|
|
104
|
-
var
|
|
146
|
+
var HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_HUB_URL}/api/inference-proxy/{{PROVIDER}}`;
|
|
105
147
|
var tasks = null;
|
|
106
148
|
async function makeRequestOptions(args, options) {
|
|
107
|
-
const { accessToken, endpointUrl, ...otherArgs } = args;
|
|
108
|
-
|
|
109
|
-
const {
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
taskHint,
|
|
113
|
-
wait_for_model,
|
|
114
|
-
use_cache,
|
|
115
|
-
dont_load_model,
|
|
116
|
-
chatCompletion: chatCompletion2
|
|
117
|
-
} = options ?? {};
|
|
118
|
-
const headers = {};
|
|
119
|
-
if (accessToken) {
|
|
120
|
-
headers["Authorization"] = `Bearer ${accessToken}`;
|
|
149
|
+
const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...otherArgs } = args;
|
|
150
|
+
const provider = maybeProvider ?? "hf-inference";
|
|
151
|
+
const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion: chatCompletion2 } = options ?? {};
|
|
152
|
+
if (endpointUrl && provider !== "hf-inference") {
|
|
153
|
+
throw new Error(`Cannot use endpointUrl with a third-party provider.`);
|
|
121
154
|
}
|
|
122
|
-
if (
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
155
|
+
if (forceTask && provider !== "hf-inference") {
|
|
156
|
+
throw new Error(`Cannot use forceTask with a third-party provider.`);
|
|
157
|
+
}
|
|
158
|
+
if (maybeModel && isUrl(maybeModel)) {
|
|
159
|
+
throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
|
|
127
160
|
}
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
if (
|
|
131
|
-
model =
|
|
161
|
+
let model;
|
|
162
|
+
if (!maybeModel) {
|
|
163
|
+
if (taskHint) {
|
|
164
|
+
model = mapModel({ model: await loadDefaultModel(taskHint), provider, taskHint, chatCompletion: chatCompletion2 });
|
|
165
|
+
} else {
|
|
166
|
+
throw new Error("No model provided, and no default model found for this task");
|
|
132
167
|
}
|
|
168
|
+
} else {
|
|
169
|
+
model = mapModel({ model: maybeModel, provider, taskHint, chatCompletion: chatCompletion2 });
|
|
133
170
|
}
|
|
134
|
-
|
|
135
|
-
|
|
171
|
+
const authMethod = accessToken ? accessToken.startsWith("hf_") ? "hf-token" : "provider-key" : includeCredentials === "include" ? "credentials-include" : "none";
|
|
172
|
+
const url = endpointUrl ? chatCompletion2 ? endpointUrl + `/v1/chat/completions` : endpointUrl : makeUrl({
|
|
173
|
+
authMethod,
|
|
174
|
+
chatCompletion: chatCompletion2 ?? false,
|
|
175
|
+
forceTask,
|
|
176
|
+
model,
|
|
177
|
+
provider: provider ?? "hf-inference",
|
|
178
|
+
taskHint
|
|
179
|
+
});
|
|
180
|
+
const headers = {};
|
|
181
|
+
if (accessToken) {
|
|
182
|
+
headers["Authorization"] = provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
|
|
136
183
|
}
|
|
137
184
|
const binary = "data" in args && !!args.data;
|
|
138
185
|
if (!binary) {
|
|
139
186
|
headers["Content-Type"] = "application/json";
|
|
140
187
|
}
|
|
141
|
-
if (
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
if (use_cache === false) {
|
|
145
|
-
headers["X-Use-Cache"] = "false";
|
|
146
|
-
}
|
|
147
|
-
if (dont_load_model) {
|
|
148
|
-
headers["X-Load-Model"] = "0";
|
|
149
|
-
}
|
|
150
|
-
let url = (() => {
|
|
151
|
-
if (endpointUrl && isUrl(model)) {
|
|
152
|
-
throw new TypeError("Both model and endpointUrl cannot be URLs");
|
|
153
|
-
}
|
|
154
|
-
if (isUrl(model)) {
|
|
155
|
-
console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead");
|
|
156
|
-
return model;
|
|
188
|
+
if (provider === "hf-inference") {
|
|
189
|
+
if (wait_for_model) {
|
|
190
|
+
headers["X-Wait-For-Model"] = "true";
|
|
157
191
|
}
|
|
158
|
-
if (
|
|
159
|
-
|
|
192
|
+
if (use_cache === false) {
|
|
193
|
+
headers["X-Use-Cache"] = "false";
|
|
160
194
|
}
|
|
161
|
-
if (
|
|
162
|
-
|
|
195
|
+
if (dont_load_model) {
|
|
196
|
+
headers["X-Load-Model"] = "0";
|
|
163
197
|
}
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
url += "/v1/chat/completions";
|
|
198
|
+
}
|
|
199
|
+
if (provider === "replicate") {
|
|
200
|
+
headers["Prefer"] = "wait";
|
|
168
201
|
}
|
|
169
202
|
let credentials;
|
|
170
203
|
if (typeof includeCredentials === "string") {
|
|
@@ -172,17 +205,110 @@ async function makeRequestOptions(args, options) {
|
|
|
172
205
|
} else if (includeCredentials === true) {
|
|
173
206
|
credentials = "include";
|
|
174
207
|
}
|
|
208
|
+
if (provider === "replicate" && model.includes(":")) {
|
|
209
|
+
const version = model.split(":")[1];
|
|
210
|
+
otherArgs.version = version;
|
|
211
|
+
}
|
|
175
212
|
const info = {
|
|
176
213
|
headers,
|
|
177
214
|
method: "POST",
|
|
178
215
|
body: binary ? args.data : JSON.stringify({
|
|
179
|
-
...otherArgs
|
|
216
|
+
...otherArgs,
|
|
217
|
+
...chatCompletion2 || provider === "together" ? { model } : void 0
|
|
180
218
|
}),
|
|
181
|
-
...credentials
|
|
219
|
+
...credentials ? { credentials } : void 0,
|
|
182
220
|
signal: options?.signal
|
|
183
221
|
};
|
|
184
222
|
return { url, info };
|
|
185
223
|
}
|
|
224
|
+
function mapModel(params) {
|
|
225
|
+
if (params.provider === "hf-inference") {
|
|
226
|
+
return params.model;
|
|
227
|
+
}
|
|
228
|
+
if (!params.taskHint) {
|
|
229
|
+
throw new Error("taskHint must be specified when using a third-party provider");
|
|
230
|
+
}
|
|
231
|
+
const task = params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint;
|
|
232
|
+
const model = (() => {
|
|
233
|
+
switch (params.provider) {
|
|
234
|
+
case "fal-ai":
|
|
235
|
+
return FAL_AI_SUPPORTED_MODEL_IDS[task]?.[params.model];
|
|
236
|
+
case "replicate":
|
|
237
|
+
return REPLICATE_SUPPORTED_MODEL_IDS[task]?.[params.model];
|
|
238
|
+
case "sambanova":
|
|
239
|
+
return SAMBANOVA_SUPPORTED_MODEL_IDS[task]?.[params.model];
|
|
240
|
+
case "together":
|
|
241
|
+
return TOGETHER_SUPPORTED_MODEL_IDS[task]?.[params.model];
|
|
242
|
+
}
|
|
243
|
+
})();
|
|
244
|
+
if (!model) {
|
|
245
|
+
throw new Error(`Model ${params.model} is not supported for task ${task} and provider ${params.provider}`);
|
|
246
|
+
}
|
|
247
|
+
return model;
|
|
248
|
+
}
|
|
249
|
+
function makeUrl(params) {
|
|
250
|
+
if (params.authMethod === "none" && params.provider !== "hf-inference") {
|
|
251
|
+
throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
|
|
252
|
+
}
|
|
253
|
+
const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
|
|
254
|
+
switch (params.provider) {
|
|
255
|
+
case "fal-ai": {
|
|
256
|
+
const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FAL_AI_API_BASE_URL;
|
|
257
|
+
return `${baseUrl}/${params.model}`;
|
|
258
|
+
}
|
|
259
|
+
case "replicate": {
|
|
260
|
+
const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : REPLICATE_API_BASE_URL;
|
|
261
|
+
if (params.model.includes(":")) {
|
|
262
|
+
return `${baseUrl}/v1/predictions`;
|
|
263
|
+
}
|
|
264
|
+
return `${baseUrl}/v1/models/${params.model}/predictions`;
|
|
265
|
+
}
|
|
266
|
+
case "sambanova": {
|
|
267
|
+
const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : SAMBANOVA_API_BASE_URL;
|
|
268
|
+
if (params.taskHint === "text-generation" && params.chatCompletion) {
|
|
269
|
+
return `${baseUrl}/v1/chat/completions`;
|
|
270
|
+
}
|
|
271
|
+
return baseUrl;
|
|
272
|
+
}
|
|
273
|
+
case "together": {
|
|
274
|
+
const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : TOGETHER_API_BASE_URL;
|
|
275
|
+
if (params.taskHint === "text-to-image") {
|
|
276
|
+
return `${baseUrl}/v1/images/generations`;
|
|
277
|
+
}
|
|
278
|
+
if (params.taskHint === "text-generation") {
|
|
279
|
+
if (params.chatCompletion) {
|
|
280
|
+
return `${baseUrl}/v1/chat/completions`;
|
|
281
|
+
}
|
|
282
|
+
return `${baseUrl}/v1/completions`;
|
|
283
|
+
}
|
|
284
|
+
return baseUrl;
|
|
285
|
+
}
|
|
286
|
+
default: {
|
|
287
|
+
const url = params.forceTask ? `${HF_INFERENCE_API_URL}/pipeline/${params.forceTask}/${params.model}` : `${HF_INFERENCE_API_URL}/models/${params.model}`;
|
|
288
|
+
if (params.taskHint === "text-generation" && params.chatCompletion) {
|
|
289
|
+
return url + `/v1/chat/completions`;
|
|
290
|
+
}
|
|
291
|
+
return url;
|
|
292
|
+
}
|
|
293
|
+
}
|
|
294
|
+
}
|
|
295
|
+
async function loadDefaultModel(task) {
|
|
296
|
+
if (!tasks) {
|
|
297
|
+
tasks = await loadTaskInfo();
|
|
298
|
+
}
|
|
299
|
+
const taskInfo = tasks[task];
|
|
300
|
+
if ((taskInfo?.models.length ?? 0) <= 0) {
|
|
301
|
+
throw new Error(`No default model defined for task ${task}, please define the model explicitly.`);
|
|
302
|
+
}
|
|
303
|
+
return taskInfo.models[0].id;
|
|
304
|
+
}
|
|
305
|
+
async function loadTaskInfo() {
|
|
306
|
+
const res = await fetch(`${HF_HUB_URL}/api/tasks`);
|
|
307
|
+
if (!res.ok) {
|
|
308
|
+
throw new Error("Failed to load tasks definitions from Hugging Face Hub.");
|
|
309
|
+
}
|
|
310
|
+
return await res.json();
|
|
311
|
+
}
|
|
186
312
|
|
|
187
313
|
// src/tasks/custom/request.ts
|
|
188
314
|
async function request(args, options) {
|
|
@@ -195,16 +321,22 @@ async function request(args, options) {
|
|
|
195
321
|
});
|
|
196
322
|
}
|
|
197
323
|
if (!response.ok) {
|
|
198
|
-
|
|
324
|
+
const contentType = response.headers.get("Content-Type");
|
|
325
|
+
if (["application/json", "application/problem+json"].some((ct) => contentType?.startsWith(ct))) {
|
|
199
326
|
const output = await response.json();
|
|
200
327
|
if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
|
|
201
|
-
throw new Error(
|
|
328
|
+
throw new Error(
|
|
329
|
+
`Server ${args.model} does not seem to support chat completion. Error: ${JSON.stringify(output.error)}`
|
|
330
|
+
);
|
|
202
331
|
}
|
|
203
|
-
if (output.error) {
|
|
204
|
-
throw new Error(output.error);
|
|
332
|
+
if (output.error || output.detail) {
|
|
333
|
+
throw new Error(JSON.stringify(output.error ?? output.detail));
|
|
334
|
+
} else {
|
|
335
|
+
throw new Error(output);
|
|
205
336
|
}
|
|
206
337
|
}
|
|
207
|
-
|
|
338
|
+
const message = contentType?.startsWith("text/plain;") ? await response.text() : void 0;
|
|
339
|
+
throw new Error(message ?? "An error occurred while fetching the blob");
|
|
208
340
|
}
|
|
209
341
|
if (response.headers.get("Content-Type")?.startsWith("application/json")) {
|
|
210
342
|
return await response.json();
|
|
@@ -327,9 +459,12 @@ async function* streamingRequest(args, options) {
|
|
|
327
459
|
if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
|
|
328
460
|
throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
|
|
329
461
|
}
|
|
330
|
-
if (output.error) {
|
|
462
|
+
if (typeof output.error === "string") {
|
|
331
463
|
throw new Error(output.error);
|
|
332
464
|
}
|
|
465
|
+
if (output.error && "message" in output.error && typeof output.error.message === "string") {
|
|
466
|
+
throw new Error(output.error.message);
|
|
467
|
+
}
|
|
333
468
|
}
|
|
334
469
|
throw new Error(`Server response contains error: ${response.status}`);
|
|
335
470
|
}
|
|
@@ -358,8 +493,9 @@ async function* streamingRequest(args, options) {
|
|
|
358
493
|
try {
|
|
359
494
|
while (true) {
|
|
360
495
|
const { done, value } = await reader.read();
|
|
361
|
-
if (done)
|
|
496
|
+
if (done) {
|
|
362
497
|
return;
|
|
498
|
+
}
|
|
363
499
|
onChunk(value);
|
|
364
500
|
for (const event of events) {
|
|
365
501
|
if (event.data.length > 0) {
|
|
@@ -368,7 +504,8 @@ async function* streamingRequest(args, options) {
|
|
|
368
504
|
}
|
|
369
505
|
const data = JSON.parse(event.data);
|
|
370
506
|
if (typeof data === "object" && data !== null && "error" in data) {
|
|
371
|
-
|
|
507
|
+
const errorStr = typeof data.error === "string" ? data.error : typeof data.error === "object" && data.error && "message" in data.error && typeof data.error.message === "string" ? data.error.message : JSON.stringify(data.error);
|
|
508
|
+
throw new Error(`Error forwarded from backend: ` + errorStr);
|
|
372
509
|
}
|
|
373
510
|
yield data;
|
|
374
511
|
}
|
|
@@ -403,8 +540,29 @@ async function audioClassification(args, options) {
|
|
|
403
540
|
return res;
|
|
404
541
|
}
|
|
405
542
|
|
|
543
|
+
// src/utils/base64FromBytes.ts
|
|
544
|
+
function base64FromBytes(arr) {
|
|
545
|
+
if (globalThis.Buffer) {
|
|
546
|
+
return globalThis.Buffer.from(arr).toString("base64");
|
|
547
|
+
} else {
|
|
548
|
+
const bin = [];
|
|
549
|
+
arr.forEach((byte) => {
|
|
550
|
+
bin.push(String.fromCharCode(byte));
|
|
551
|
+
});
|
|
552
|
+
return globalThis.btoa(bin.join(""));
|
|
553
|
+
}
|
|
554
|
+
}
|
|
555
|
+
|
|
406
556
|
// src/tasks/audio/automaticSpeechRecognition.ts
|
|
407
557
|
async function automaticSpeechRecognition(args, options) {
|
|
558
|
+
if (args.provider === "fal-ai") {
|
|
559
|
+
const contentType = args.data instanceof Blob ? args.data.type : "audio/mpeg";
|
|
560
|
+
const base64audio = base64FromBytes(
|
|
561
|
+
new Uint8Array(args.data instanceof ArrayBuffer ? args.data : await args.data.arrayBuffer())
|
|
562
|
+
);
|
|
563
|
+
args.audio_url = `data:${contentType};base64,${base64audio}`;
|
|
564
|
+
delete args.data;
|
|
565
|
+
}
|
|
408
566
|
const res = await request(args, {
|
|
409
567
|
...options,
|
|
410
568
|
taskHint: "automatic-speech-recognition"
|
|
@@ -501,10 +659,35 @@ async function objectDetection(args, options) {
|
|
|
501
659
|
|
|
502
660
|
// src/tasks/cv/textToImage.ts
|
|
503
661
|
async function textToImage(args, options) {
|
|
662
|
+
if (args.provider === "together" || args.provider === "fal-ai") {
|
|
663
|
+
args.prompt = args.inputs;
|
|
664
|
+
args.inputs = "";
|
|
665
|
+
args.response_format = "base64";
|
|
666
|
+
} else if (args.provider === "replicate") {
|
|
667
|
+
args.input = { prompt: args.inputs };
|
|
668
|
+
delete args.inputs;
|
|
669
|
+
}
|
|
504
670
|
const res = await request(args, {
|
|
505
671
|
...options,
|
|
506
672
|
taskHint: "text-to-image"
|
|
507
673
|
});
|
|
674
|
+
if (res && typeof res === "object") {
|
|
675
|
+
if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
|
|
676
|
+
const image = await fetch(res.images[0].url);
|
|
677
|
+
return await image.blob();
|
|
678
|
+
}
|
|
679
|
+
if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
|
|
680
|
+
const base64Data = res.data[0].b64_json;
|
|
681
|
+
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
|
|
682
|
+
const blob = await base64Response.blob();
|
|
683
|
+
return blob;
|
|
684
|
+
}
|
|
685
|
+
if ("output" in res && Array.isArray(res.output)) {
|
|
686
|
+
const urlResponse = await fetch(res.output[0]);
|
|
687
|
+
const blob = await urlResponse.blob();
|
|
688
|
+
return blob;
|
|
689
|
+
}
|
|
690
|
+
}
|
|
508
691
|
const isValidOutput = res && res instanceof Blob;
|
|
509
692
|
if (!isValidOutput) {
|
|
510
693
|
throw new InferenceOutputError("Expected Blob");
|
|
@@ -512,19 +695,6 @@ async function textToImage(args, options) {
|
|
|
512
695
|
return res;
|
|
513
696
|
}
|
|
514
697
|
|
|
515
|
-
// src/utils/base64FromBytes.ts
|
|
516
|
-
function base64FromBytes(arr) {
|
|
517
|
-
if (globalThis.Buffer) {
|
|
518
|
-
return globalThis.Buffer.from(arr).toString("base64");
|
|
519
|
-
} else {
|
|
520
|
-
const bin = [];
|
|
521
|
-
arr.forEach((byte) => {
|
|
522
|
-
bin.push(String.fromCharCode(byte));
|
|
523
|
-
});
|
|
524
|
-
return globalThis.btoa(bin.join(""));
|
|
525
|
-
}
|
|
526
|
-
}
|
|
527
|
-
|
|
528
698
|
// src/tasks/cv/imageToImage.ts
|
|
529
699
|
async function imageToImage(args, options) {
|
|
530
700
|
let reqArgs;
|
|
@@ -576,6 +746,36 @@ async function zeroShotImageClassification(args, options) {
|
|
|
576
746
|
return res;
|
|
577
747
|
}
|
|
578
748
|
|
|
749
|
+
// src/lib/getDefaultTask.ts
|
|
750
|
+
var taskCache = /* @__PURE__ */ new Map();
|
|
751
|
+
var CACHE_DURATION = 10 * 60 * 1e3;
|
|
752
|
+
var MAX_CACHE_ITEMS = 1e3;
|
|
753
|
+
async function getDefaultTask(model, accessToken, options) {
|
|
754
|
+
if (isUrl(model)) {
|
|
755
|
+
return null;
|
|
756
|
+
}
|
|
757
|
+
const key = `${model}:${accessToken}`;
|
|
758
|
+
let cachedTask = taskCache.get(key);
|
|
759
|
+
if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
|
|
760
|
+
taskCache.delete(key);
|
|
761
|
+
cachedTask = void 0;
|
|
762
|
+
}
|
|
763
|
+
if (cachedTask === void 0) {
|
|
764
|
+
const modelTask = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
|
|
765
|
+
headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
|
|
766
|
+
}).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
|
|
767
|
+
if (!modelTask) {
|
|
768
|
+
return null;
|
|
769
|
+
}
|
|
770
|
+
cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
|
|
771
|
+
taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
|
|
772
|
+
if (taskCache.size > MAX_CACHE_ITEMS) {
|
|
773
|
+
taskCache.delete(taskCache.keys().next().value);
|
|
774
|
+
}
|
|
775
|
+
}
|
|
776
|
+
return cachedTask.task;
|
|
777
|
+
}
|
|
778
|
+
|
|
579
779
|
// src/tasks/nlp/featureExtraction.ts
|
|
580
780
|
async function featureExtraction(args, options) {
|
|
581
781
|
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : void 0;
|
|
@@ -697,17 +897,33 @@ function toArray(obj) {
|
|
|
697
897
|
|
|
698
898
|
// src/tasks/nlp/textGeneration.ts
|
|
699
899
|
async function textGeneration(args, options) {
|
|
700
|
-
|
|
701
|
-
|
|
900
|
+
if (args.provider === "together") {
|
|
901
|
+
args.prompt = args.inputs;
|
|
902
|
+
const raw = await request(args, {
|
|
702
903
|
...options,
|
|
703
904
|
taskHint: "text-generation"
|
|
704
|
-
})
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
905
|
+
});
|
|
906
|
+
const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
|
|
907
|
+
if (!isValidOutput) {
|
|
908
|
+
throw new InferenceOutputError("Expected ChatCompletionOutput");
|
|
909
|
+
}
|
|
910
|
+
const completion = raw.choices[0];
|
|
911
|
+
return {
|
|
912
|
+
generated_text: completion.text
|
|
913
|
+
};
|
|
914
|
+
} else {
|
|
915
|
+
const res = toArray(
|
|
916
|
+
await request(args, {
|
|
917
|
+
...options,
|
|
918
|
+
taskHint: "text-generation"
|
|
919
|
+
})
|
|
920
|
+
);
|
|
921
|
+
const isValidOutput = Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
|
|
922
|
+
if (!isValidOutput) {
|
|
923
|
+
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
|
|
924
|
+
}
|
|
925
|
+
return res?.[0];
|
|
709
926
|
}
|
|
710
|
-
return res?.[0];
|
|
711
927
|
}
|
|
712
928
|
|
|
713
929
|
// src/tasks/nlp/textGenerationStream.ts
|
|
@@ -774,7 +990,8 @@ async function chatCompletion(args, options) {
|
|
|
774
990
|
taskHint: "text-generation",
|
|
775
991
|
chatCompletion: true
|
|
776
992
|
});
|
|
777
|
-
const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" &&
|
|
993
|
+
const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai does not output a system_fingerprint
|
|
994
|
+
(res.system_fingerprint === void 0 || typeof res.system_fingerprint === "string") && typeof res?.usage === "object";
|
|
778
995
|
if (!isValidOutput) {
|
|
779
996
|
throw new InferenceOutputError("Expected ChatCompletionOutput");
|
|
780
997
|
}
|
|
@@ -907,10 +1124,18 @@ var HfInferenceEndpoint = class {
|
|
|
907
1124
|
}
|
|
908
1125
|
}
|
|
909
1126
|
};
|
|
1127
|
+
|
|
1128
|
+
// src/types.ts
|
|
1129
|
+
var INFERENCE_PROVIDERS = ["fal-ai", "replicate", "sambanova", "together", "hf-inference"];
|
|
910
1130
|
export {
|
|
1131
|
+
FAL_AI_SUPPORTED_MODEL_IDS,
|
|
911
1132
|
HfInference,
|
|
912
1133
|
HfInferenceEndpoint,
|
|
1134
|
+
INFERENCE_PROVIDERS,
|
|
913
1135
|
InferenceOutputError,
|
|
1136
|
+
REPLICATE_SUPPORTED_MODEL_IDS,
|
|
1137
|
+
SAMBANOVA_SUPPORTED_MODEL_IDS,
|
|
1138
|
+
TOGETHER_SUPPORTED_MODEL_IDS,
|
|
914
1139
|
audioClassification,
|
|
915
1140
|
audioToAudio,
|
|
916
1141
|
automaticSpeechRecognition,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"config.d.ts","sourceRoot":"","sources":["../../src/config.ts"],"names":[],"mappings":"AAAA,eAAO,MAAM,UAAU,2BAA2B,CAAC;AACnD,eAAO,MAAM,oBAAoB,yCAAyC,CAAC"}
|
package/dist/src/index.d.ts
CHANGED
|
@@ -1,5 +1,10 @@
|
|
|
1
|
+
export type { ProviderMapping } from "./providers/types";
|
|
1
2
|
export { HfInference, HfInferenceEndpoint } from "./HfInference";
|
|
2
3
|
export { InferenceOutputError } from "./lib/InferenceOutputError";
|
|
4
|
+
export { FAL_AI_SUPPORTED_MODEL_IDS } from "./providers/fal-ai";
|
|
5
|
+
export { REPLICATE_SUPPORTED_MODEL_IDS } from "./providers/replicate";
|
|
6
|
+
export { SAMBANOVA_SUPPORTED_MODEL_IDS } from "./providers/sambanova";
|
|
7
|
+
export { TOGETHER_SUPPORTED_MODEL_IDS } from "./providers/together";
|
|
3
8
|
export * from "./types";
|
|
4
9
|
export * from "./tasks";
|
|
5
10
|
//# sourceMappingURL=index.d.ts.map
|
package/dist/src/index.d.ts.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,WAAW,EAAE,mBAAmB,EAAE,MAAM,eAAe,CAAC;AACjE,OAAO,EAAE,oBAAoB,EAAE,MAAM,4BAA4B,CAAC;AAClE,cAAc,SAAS,CAAC;AACxB,cAAc,SAAS,CAAC"}
|
|
1
|
+
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../src/index.ts"],"names":[],"mappings":"AAAA,YAAY,EAAE,eAAe,EAAE,MAAM,mBAAmB,CAAC;AACzD,OAAO,EAAE,WAAW,EAAE,mBAAmB,EAAE,MAAM,eAAe,CAAC;AACjE,OAAO,EAAE,oBAAoB,EAAE,MAAM,4BAA4B,CAAC;AAClE,OAAO,EAAE,0BAA0B,EAAE,MAAM,oBAAoB,CAAC;AAChE,OAAO,EAAE,6BAA6B,EAAE,MAAM,uBAAuB,CAAC;AACtE,OAAO,EAAE,6BAA6B,EAAE,MAAM,uBAAuB,CAAC;AACtE,OAAO,EAAE,4BAA4B,EAAE,MAAM,sBAAsB,CAAC;AACpE,cAAc,SAAS,CAAC;AACxB,cAAc,SAAS,CAAC"}
|
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"getDefaultTask.d.ts","sourceRoot":"","sources":["../../../src/lib/getDefaultTask.ts"],"names":[],"mappings":"
|
|
1
|
+
{"version":3,"file":"getDefaultTask.d.ts","sourceRoot":"","sources":["../../../src/lib/getDefaultTask.ts"],"names":[],"mappings":"AAYA,MAAM,WAAW,kBAAkB;IAClC,KAAK,CAAC,EAAE,OAAO,KAAK,CAAC;CACrB;AAED;;;;;GAKG;AACH,wBAAsB,cAAc,CACnC,KAAK,EAAE,MAAM,EACb,WAAW,EAAE,MAAM,GAAG,SAAS,EAC/B,OAAO,CAAC,EAAE,kBAAkB,GAC1B,OAAO,CAAC,MAAM,GAAG,IAAI,CAAC,CAkCxB"}
|