@huggingface/inference 2.8.1 → 3.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. package/LICENSE +1 -1
  2. package/README.md +45 -17
  3. package/dist/index.cjs +388 -134
  4. package/dist/index.js +383 -134
  5. package/dist/src/config.d.ts +3 -0
  6. package/dist/src/config.d.ts.map +1 -0
  7. package/dist/src/index.d.ts +5 -0
  8. package/dist/src/index.d.ts.map +1 -1
  9. package/dist/src/lib/getDefaultTask.d.ts +0 -1
  10. package/dist/src/lib/getDefaultTask.d.ts.map +1 -1
  11. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  12. package/dist/src/providers/fal-ai.d.ts +6 -0
  13. package/dist/src/providers/fal-ai.d.ts.map +1 -0
  14. package/dist/src/providers/replicate.d.ts +6 -0
  15. package/dist/src/providers/replicate.d.ts.map +1 -0
  16. package/dist/src/providers/sambanova.d.ts +6 -0
  17. package/dist/src/providers/sambanova.d.ts.map +1 -0
  18. package/dist/src/providers/together.d.ts +12 -0
  19. package/dist/src/providers/together.d.ts.map +1 -0
  20. package/dist/src/providers/types.d.ts +4 -0
  21. package/dist/src/providers/types.d.ts.map +1 -0
  22. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  23. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  24. package/dist/src/tasks/custom/request.d.ts +1 -1
  25. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  26. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  27. package/dist/src/tasks/cv/textToImage.d.ts +8 -0
  28. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  29. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  30. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  31. package/dist/src/types.d.ts +16 -2
  32. package/dist/src/types.d.ts.map +1 -1
  33. package/package.json +2 -2
  34. package/src/config.ts +2 -0
  35. package/src/index.ts +5 -0
  36. package/src/lib/getDefaultTask.ts +1 -1
  37. package/src/lib/makeRequestOptions.ts +201 -59
  38. package/src/providers/fal-ai.ts +23 -0
  39. package/src/providers/replicate.ts +16 -0
  40. package/src/providers/sambanova.ts +23 -0
  41. package/src/providers/together.ts +60 -0
  42. package/src/providers/types.ts +6 -0
  43. package/src/tasks/audio/automaticSpeechRecognition.ts +10 -1
  44. package/src/tasks/audio/textToSpeech.ts +17 -2
  45. package/src/tasks/custom/request.ts +12 -6
  46. package/src/tasks/custom/streamingRequest.ts +18 -3
  47. package/src/tasks/cv/textToImage.ts +44 -1
  48. package/src/tasks/nlp/chatCompletion.ts +2 -2
  49. package/src/tasks/nlp/textGeneration.ts +43 -9
  50. package/src/types.ts +20 -2
package/dist/index.js CHANGED
@@ -40,131 +40,175 @@ __export(tasks_exports, {
40
40
  zeroShotImageClassification: () => zeroShotImageClassification
41
41
  });
42
42
 
43
- // src/utils/pick.ts
44
- function pick(o, props) {
45
- return Object.assign(
46
- {},
47
- ...props.map((prop) => {
48
- if (o[prop] !== void 0) {
49
- return { [prop]: o[prop] };
50
- }
51
- })
52
- );
53
- }
43
+ // src/config.ts
44
+ var HF_HUB_URL = "https://huggingface.co";
45
+ var HF_INFERENCE_API_URL = "https://api-inference.huggingface.co";
54
46
 
55
- // src/utils/typedInclude.ts
56
- function typedInclude(arr, v) {
57
- return arr.includes(v);
58
- }
47
+ // src/providers/fal-ai.ts
48
+ var FAL_AI_API_BASE_URL = "https://fal.run";
49
+ var FAL_AI_SUPPORTED_MODEL_IDS = {
50
+ "text-to-image": {
51
+ "black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
52
+ "black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
53
+ "playgroundai/playground-v2.5-1024px-aesthetic": "fal-ai/playground-v25",
54
+ "ByteDance/SDXL-Lightning": "fal-ai/lightning-models",
55
+ "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "fal-ai/pixart-sigma",
56
+ "stabilityai/stable-diffusion-3-medium": "fal-ai/stable-diffusion-v3-medium",
57
+ "Warlord-K/Sana-1024": "fal-ai/sana",
58
+ "fal/AuraFlow-v0.2": "fal-ai/aura-flow",
59
+ "stabilityai/stable-diffusion-3.5-large": "fal-ai/stable-diffusion-v35-large",
60
+ "Kwai-Kolors/Kolors": "fal-ai/kolors"
61
+ },
62
+ "automatic-speech-recognition": {
63
+ "openai/whisper-large-v3": "fal-ai/whisper"
64
+ }
65
+ };
59
66
 
60
- // src/utils/omit.ts
61
- function omit(o, props) {
62
- const propsArr = Array.isArray(props) ? props : [props];
63
- const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
64
- return pick(o, letsKeep);
65
- }
67
+ // src/providers/replicate.ts
68
+ var REPLICATE_API_BASE_URL = "https://api.replicate.com";
69
+ var REPLICATE_SUPPORTED_MODEL_IDS = {
70
+ "text-to-image": {
71
+ "black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
72
+ "ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637"
73
+ },
74
+ "text-to-speech": {
75
+ "OuteAI/OuteTTS-0.3-500M": "jbilcke/oute-tts:39a59319327b27327fa3095149c5a746e7f2aee18c75055c3368237a6503cd26"
76
+ }
77
+ };
78
+
79
+ // src/providers/sambanova.ts
80
+ var SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
81
+ var SAMBANOVA_SUPPORTED_MODEL_IDS = {
82
+ /** Chat completion / conversational */
83
+ conversational: {
84
+ "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
85
+ "Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
86
+ "Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
87
+ "meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct",
88
+ "meta-llama/Llama-3.2-1B-Instruct": "Meta-Llama-3.2-1B-Instruct",
89
+ "meta-llama/Llama-3.2-3B-Instruct": "Meta-Llama-3.2-3B-Instruct",
90
+ "meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
91
+ "meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
92
+ "meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
93
+ "meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
94
+ "meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
95
+ "meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B"
96
+ }
97
+ };
98
+
99
+ // src/providers/together.ts
100
+ var TOGETHER_API_BASE_URL = "https://api.together.xyz";
101
+ var TOGETHER_SUPPORTED_MODEL_IDS = {
102
+ "text-to-image": {
103
+ "black-forest-labs/FLUX.1-Canny-dev": "black-forest-labs/FLUX.1-canny",
104
+ "black-forest-labs/FLUX.1-Depth-dev": "black-forest-labs/FLUX.1-depth",
105
+ "black-forest-labs/FLUX.1-dev": "black-forest-labs/FLUX.1-dev",
106
+ "black-forest-labs/FLUX.1-Redux-dev": "black-forest-labs/FLUX.1-redux",
107
+ "black-forest-labs/FLUX.1-schnell": "black-forest-labs/FLUX.1-pro",
108
+ "stabilityai/stable-diffusion-xl-base-1.0": "stabilityai/stable-diffusion-xl-base-1.0"
109
+ },
110
+ conversational: {
111
+ "databricks/dbrx-instruct": "databricks/dbrx-instruct",
112
+ "deepseek-ai/DeepSeek-R1": "deepseek-ai/DeepSeek-R1",
113
+ "deepseek-ai/DeepSeek-V3": "deepseek-ai/DeepSeek-V3",
114
+ "deepseek-ai/deepseek-llm-67b-chat": "deepseek-ai/deepseek-llm-67b-chat",
115
+ "google/gemma-2-9b-it": "google/gemma-2-9b-it",
116
+ "google/gemma-2b-it": "google/gemma-2-27b-it",
117
+ "llava-hf/llava-v1.6-mistral-7b-hf": "llava-hf/llava-v1.6-mistral-7b-hf",
118
+ "meta-llama/Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
119
+ "meta-llama/Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
120
+ "meta-llama/Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
121
+ "meta-llama/Llama-3.2-11B-Vision-Instruct": "meta-llama/Llama-Vision-Free",
122
+ "meta-llama/Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
123
+ "meta-llama/Llama-3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
124
+ "meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
125
+ "meta-llama/Meta-Llama-3-70B-Instruct": "meta-llama/Llama-3-70b-chat-hf",
126
+ "meta-llama/Meta-Llama-3-8B-Instruct": "togethercomputer/Llama-3-8b-chat-hf-int4",
127
+ "meta-llama/Meta-Llama-3.1-405B-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
128
+ "meta-llama/Meta-Llama-3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
129
+ "meta-llama/Meta-Llama-3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K",
130
+ "microsoft/WizardLM-2-8x22B": "microsoft/WizardLM-2-8x22B",
131
+ "mistralai/Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
132
+ "mistralai/Mixtral-8x22B-Instruct-v0.1": "mistralai/Mixtral-8x22B-Instruct-v0.1",
133
+ "mistralai/Mixtral-8x7B-Instruct-v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1",
134
+ "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
135
+ "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
136
+ "Qwen/Qwen2-72B-Instruct": "Qwen/Qwen2-72B-Instruct",
137
+ "Qwen/Qwen2.5-72B-Instruct": "Qwen/Qwen2.5-72B-Instruct-Turbo",
138
+ "Qwen/Qwen2.5-7B-Instruct": "Qwen/Qwen2.5-7B-Instruct-Turbo",
139
+ "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen/Qwen2.5-Coder-32B-Instruct",
140
+ "Qwen/QwQ-32B-Preview": "Qwen/QwQ-32B-Preview",
141
+ "scb10x/llama-3-typhoon-v1.5-8b-instruct": "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct",
142
+ "scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": "scb10x/scb10x-llama3-typhoon-v1-5x-4f316"
143
+ },
144
+ "text-generation": {
145
+ "meta-llama/Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B",
146
+ "mistralai/Mixtral-8x7B-v0.1": "mistralai/Mixtral-8x7B-v0.1"
147
+ }
148
+ };
66
149
 
67
150
  // src/lib/isUrl.ts
68
151
  function isUrl(modelOrUrl) {
69
152
  return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
70
153
  }
71
154
 
72
- // src/lib/getDefaultTask.ts
73
- var taskCache = /* @__PURE__ */ new Map();
74
- var CACHE_DURATION = 10 * 60 * 1e3;
75
- var MAX_CACHE_ITEMS = 1e3;
76
- var HF_HUB_URL = "https://huggingface.co";
77
- async function getDefaultTask(model, accessToken, options) {
78
- if (isUrl(model)) {
79
- return null;
80
- }
81
- const key = `${model}:${accessToken}`;
82
- let cachedTask = taskCache.get(key);
83
- if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
84
- taskCache.delete(key);
85
- cachedTask = void 0;
86
- }
87
- if (cachedTask === void 0) {
88
- const modelTask = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
89
- headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
90
- }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
91
- if (!modelTask) {
92
- return null;
93
- }
94
- cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
95
- taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
96
- if (taskCache.size > MAX_CACHE_ITEMS) {
97
- taskCache.delete(taskCache.keys().next().value);
98
- }
99
- }
100
- return cachedTask.task;
101
- }
102
-
103
155
  // src/lib/makeRequestOptions.ts
104
- var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
156
+ var HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_HUB_URL}/api/inference-proxy/{{PROVIDER}}`;
105
157
  var tasks = null;
106
158
  async function makeRequestOptions(args, options) {
107
- const { accessToken, endpointUrl, ...otherArgs } = args;
108
- let { model } = args;
109
- const {
110
- forceTask: task,
111
- includeCredentials,
112
- taskHint,
113
- wait_for_model,
114
- use_cache,
115
- dont_load_model,
116
- chatCompletion: chatCompletion2
117
- } = options ?? {};
118
- const headers = {};
119
- if (accessToken) {
120
- headers["Authorization"] = `Bearer ${accessToken}`;
159
+ const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
160
+ let otherArgs = remainingArgs;
161
+ const provider = maybeProvider ?? "hf-inference";
162
+ const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion: chatCompletion2 } = options ?? {};
163
+ if (endpointUrl && provider !== "hf-inference") {
164
+ throw new Error(`Cannot use endpointUrl with a third-party provider.`);
121
165
  }
122
- if (!model && !tasks && taskHint) {
123
- const res = await fetch(`${HF_HUB_URL}/api/tasks`);
124
- if (res.ok) {
125
- tasks = await res.json();
126
- }
166
+ if (forceTask && provider !== "hf-inference") {
167
+ throw new Error(`Cannot use forceTask with a third-party provider.`);
168
+ }
169
+ if (maybeModel && isUrl(maybeModel)) {
170
+ throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
127
171
  }
128
- if (!model && tasks && taskHint) {
129
- const taskInfo = tasks[taskHint];
130
- if (taskInfo) {
131
- model = taskInfo.models[0].id;
172
+ let model;
173
+ if (!maybeModel) {
174
+ if (taskHint) {
175
+ model = mapModel({ model: await loadDefaultModel(taskHint), provider, taskHint, chatCompletion: chatCompletion2 });
176
+ } else {
177
+ throw new Error("No model provided, and no default model found for this task");
132
178
  }
179
+ } else {
180
+ model = mapModel({ model: maybeModel, provider, taskHint, chatCompletion: chatCompletion2 });
133
181
  }
134
- if (!model) {
135
- throw new Error("No model provided, and no default model found for this task");
182
+ const authMethod = accessToken ? accessToken.startsWith("hf_") ? "hf-token" : "provider-key" : includeCredentials === "include" ? "credentials-include" : "none";
183
+ const url = endpointUrl ? chatCompletion2 ? endpointUrl + `/v1/chat/completions` : endpointUrl : makeUrl({
184
+ authMethod,
185
+ chatCompletion: chatCompletion2 ?? false,
186
+ forceTask,
187
+ model,
188
+ provider: provider ?? "hf-inference",
189
+ taskHint
190
+ });
191
+ const headers = {};
192
+ if (accessToken) {
193
+ headers["Authorization"] = provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
136
194
  }
137
195
  const binary = "data" in args && !!args.data;
138
196
  if (!binary) {
139
197
  headers["Content-Type"] = "application/json";
140
198
  }
141
- if (wait_for_model) {
142
- headers["X-Wait-For-Model"] = "true";
143
- }
144
- if (use_cache === false) {
145
- headers["X-Use-Cache"] = "false";
146
- }
147
- if (dont_load_model) {
148
- headers["X-Load-Model"] = "0";
149
- }
150
- let url = (() => {
151
- if (endpointUrl && isUrl(model)) {
152
- throw new TypeError("Both model and endpointUrl cannot be URLs");
199
+ if (provider === "hf-inference") {
200
+ if (wait_for_model) {
201
+ headers["X-Wait-For-Model"] = "true";
153
202
  }
154
- if (isUrl(model)) {
155
- console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead");
156
- return model;
203
+ if (use_cache === false) {
204
+ headers["X-Use-Cache"] = "false";
157
205
  }
158
- if (endpointUrl) {
159
- return endpointUrl;
206
+ if (dont_load_model) {
207
+ headers["X-Load-Model"] = "0";
160
208
  }
161
- if (task) {
162
- return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
163
- }
164
- return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
165
- })();
166
- if (chatCompletion2 && !url.endsWith("/chat/completions")) {
167
- url += "/v1/chat/completions";
209
+ }
210
+ if (provider === "replicate") {
211
+ headers["Prefer"] = "wait";
168
212
  }
169
213
  let credentials;
170
214
  if (typeof includeCredentials === "string") {
@@ -172,17 +216,110 @@ async function makeRequestOptions(args, options) {
172
216
  } else if (includeCredentials === true) {
173
217
  credentials = "include";
174
218
  }
219
+ if (provider === "replicate") {
220
+ const version = model.includes(":") ? model.split(":")[1] : void 0;
221
+ otherArgs = { input: otherArgs, version };
222
+ }
175
223
  const info = {
176
224
  headers,
177
225
  method: "POST",
178
226
  body: binary ? args.data : JSON.stringify({
179
- ...otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs
227
+ ...otherArgs,
228
+ ...chatCompletion2 || provider === "together" ? { model } : void 0
180
229
  }),
181
- ...credentials && { credentials },
230
+ ...credentials ? { credentials } : void 0,
182
231
  signal: options?.signal
183
232
  };
184
233
  return { url, info };
185
234
  }
235
+ function mapModel(params) {
236
+ if (params.provider === "hf-inference") {
237
+ return params.model;
238
+ }
239
+ if (!params.taskHint) {
240
+ throw new Error("taskHint must be specified when using a third-party provider");
241
+ }
242
+ const task = params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint;
243
+ const model = (() => {
244
+ switch (params.provider) {
245
+ case "fal-ai":
246
+ return FAL_AI_SUPPORTED_MODEL_IDS[task]?.[params.model];
247
+ case "replicate":
248
+ return REPLICATE_SUPPORTED_MODEL_IDS[task]?.[params.model];
249
+ case "sambanova":
250
+ return SAMBANOVA_SUPPORTED_MODEL_IDS[task]?.[params.model];
251
+ case "together":
252
+ return TOGETHER_SUPPORTED_MODEL_IDS[task]?.[params.model];
253
+ }
254
+ })();
255
+ if (!model) {
256
+ throw new Error(`Model ${params.model} is not supported for task ${task} and provider ${params.provider}`);
257
+ }
258
+ return model;
259
+ }
260
+ function makeUrl(params) {
261
+ if (params.authMethod === "none" && params.provider !== "hf-inference") {
262
+ throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
263
+ }
264
+ const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
265
+ switch (params.provider) {
266
+ case "fal-ai": {
267
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FAL_AI_API_BASE_URL;
268
+ return `${baseUrl}/${params.model}`;
269
+ }
270
+ case "replicate": {
271
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : REPLICATE_API_BASE_URL;
272
+ if (params.model.includes(":")) {
273
+ return `${baseUrl}/v1/predictions`;
274
+ }
275
+ return `${baseUrl}/v1/models/${params.model}/predictions`;
276
+ }
277
+ case "sambanova": {
278
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : SAMBANOVA_API_BASE_URL;
279
+ if (params.taskHint === "text-generation" && params.chatCompletion) {
280
+ return `${baseUrl}/v1/chat/completions`;
281
+ }
282
+ return baseUrl;
283
+ }
284
+ case "together": {
285
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : TOGETHER_API_BASE_URL;
286
+ if (params.taskHint === "text-to-image") {
287
+ return `${baseUrl}/v1/images/generations`;
288
+ }
289
+ if (params.taskHint === "text-generation") {
290
+ if (params.chatCompletion) {
291
+ return `${baseUrl}/v1/chat/completions`;
292
+ }
293
+ return `${baseUrl}/v1/completions`;
294
+ }
295
+ return baseUrl;
296
+ }
297
+ default: {
298
+ const url = params.forceTask ? `${HF_INFERENCE_API_URL}/pipeline/${params.forceTask}/${params.model}` : `${HF_INFERENCE_API_URL}/models/${params.model}`;
299
+ if (params.taskHint === "text-generation" && params.chatCompletion) {
300
+ return url + `/v1/chat/completions`;
301
+ }
302
+ return url;
303
+ }
304
+ }
305
+ }
306
+ async function loadDefaultModel(task) {
307
+ if (!tasks) {
308
+ tasks = await loadTaskInfo();
309
+ }
310
+ const taskInfo = tasks[task];
311
+ if ((taskInfo?.models.length ?? 0) <= 0) {
312
+ throw new Error(`No default model defined for task ${task}, please define the model explicitly.`);
313
+ }
314
+ return taskInfo.models[0].id;
315
+ }
316
+ async function loadTaskInfo() {
317
+ const res = await fetch(`${HF_HUB_URL}/api/tasks`);
318
+ if (!res.ok) {
319
+ throw new Error("Failed to load tasks definitions from Hugging Face Hub.");
320
+ }
321
+ return await res.json();
322
+ }
186
323
 
187
324
  // src/tasks/custom/request.ts
188
325
  async function request(args, options) {
@@ -195,16 +332,22 @@ async function request(args, options) {
195
332
  });
196
333
  }
197
334
  if (!response.ok) {
198
- if (response.headers.get("Content-Type")?.startsWith("application/json")) {
335
+ const contentType = response.headers.get("Content-Type");
336
+ if (["application/json", "application/problem+json"].some((ct) => contentType?.startsWith(ct))) {
199
337
  const output = await response.json();
200
338
  if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
201
- throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
339
+ throw new Error(
340
+ `Server ${args.model} does not seem to support chat completion. Error: ${JSON.stringify(output.error)}`
341
+ );
202
342
  }
203
- if (output.error) {
204
- throw new Error(JSON.stringify(output.error));
343
+ if (output.error || output.detail) {
344
+ throw new Error(JSON.stringify(output.error ?? output.detail));
345
+ } else {
346
+ throw new Error(output);
205
347
  }
206
348
  }
207
- throw new Error("An error occurred while fetching the blob");
349
+ const message = contentType?.startsWith("text/plain;") ? await response.text() : void 0;
350
+ throw new Error(message ?? "An error occurred while fetching the blob");
208
351
  }
209
352
  if (response.headers.get("Content-Type")?.startsWith("application/json")) {
210
353
  return await response.json();
@@ -327,9 +470,12 @@ async function* streamingRequest(args, options) {
327
470
  if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
328
471
  throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
329
472
  }
330
- if (output.error) {
473
+ if (typeof output.error === "string") {
331
474
  throw new Error(output.error);
332
475
  }
476
+ if (output.error && "message" in output.error && typeof output.error.message === "string") {
477
+ throw new Error(output.error.message);
478
+ }
333
479
  }
334
480
  throw new Error(`Server response contains error: ${response.status}`);
335
481
  }
@@ -358,8 +504,9 @@ async function* streamingRequest(args, options) {
358
504
  try {
359
505
  while (true) {
360
506
  const { done, value } = await reader.read();
361
- if (done)
507
+ if (done) {
362
508
  return;
509
+ }
363
510
  onChunk(value);
364
511
  for (const event of events) {
365
512
  if (event.data.length > 0) {
@@ -368,7 +515,8 @@ async function* streamingRequest(args, options) {
368
515
  }
369
516
  const data = JSON.parse(event.data);
370
517
  if (typeof data === "object" && data !== null && "error" in data) {
371
- throw new Error(data.error);
518
+ const errorStr = typeof data.error === "string" ? data.error : typeof data.error === "object" && data.error && "message" in data.error && typeof data.error.message === "string" ? data.error.message : JSON.stringify(data.error);
519
+ throw new Error(`Error forwarded from backend: ` + errorStr);
372
520
  }
373
521
  yield data;
374
522
  }
@@ -403,8 +551,29 @@ async function audioClassification(args, options) {
403
551
  return res;
404
552
  }
405
553
 
554
+ // src/utils/base64FromBytes.ts
555
+ function base64FromBytes(arr) {
556
+ if (globalThis.Buffer) {
557
+ return globalThis.Buffer.from(arr).toString("base64");
558
+ } else {
559
+ const bin = [];
560
+ arr.forEach((byte) => {
561
+ bin.push(String.fromCharCode(byte));
562
+ });
563
+ return globalThis.btoa(bin.join(""));
564
+ }
565
+ }
566
+
406
567
  // src/tasks/audio/automaticSpeechRecognition.ts
407
568
  async function automaticSpeechRecognition(args, options) {
569
+ if (args.provider === "fal-ai") {
570
+ const contentType = args.data instanceof Blob ? args.data.type : "audio/mpeg";
571
+ const base64audio = base64FromBytes(
572
+ new Uint8Array(args.data instanceof ArrayBuffer ? args.data : await args.data.arrayBuffer())
573
+ );
574
+ args.audio_url = `data:${contentType};base64,${base64audio}`;
575
+ delete args.data;
576
+ }
408
577
  const res = await request(args, {
409
578
  ...options,
410
579
  taskHint: "automatic-speech-recognition"
@@ -422,6 +591,19 @@ async function textToSpeech(args, options) {
422
591
  ...options,
423
592
  taskHint: "text-to-speech"
424
593
  });
594
+ if (res && typeof res === "object") {
595
+ if ("output" in res) {
596
+ if (typeof res.output === "string") {
597
+ const urlResponse = await fetch(res.output);
598
+ const blob = await urlResponse.blob();
599
+ return blob;
600
+ } else if (Array.isArray(res.output)) {
601
+ const urlResponse = await fetch(res.output[0]);
602
+ const blob = await urlResponse.blob();
603
+ return blob;
604
+ }
605
+ }
606
+ }
425
607
  const isValidOutput = res && res instanceof Blob;
426
608
  if (!isValidOutput) {
427
609
  throw new InferenceOutputError("Expected Blob");
@@ -501,10 +683,35 @@ async function objectDetection(args, options) {
501
683
 
502
684
  // src/tasks/cv/textToImage.ts
503
685
  async function textToImage(args, options) {
686
+ if (args.provider === "together" || args.provider === "fal-ai") {
687
+ args.prompt = args.inputs;
688
+ delete args.inputs;
689
+ args.response_format = "base64";
690
+ } else if (args.provider === "replicate") {
691
+ args.prompt = args.inputs;
692
+ delete args.inputs;
693
+ }
504
694
  const res = await request(args, {
505
695
  ...options,
506
696
  taskHint: "text-to-image"
507
697
  });
698
+ if (res && typeof res === "object") {
699
+ if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
700
+ const image = await fetch(res.images[0].url);
701
+ return await image.blob();
702
+ }
703
+ if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
704
+ const base64Data = res.data[0].b64_json;
705
+ const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
706
+ const blob = await base64Response.blob();
707
+ return blob;
708
+ }
709
+ if ("output" in res && Array.isArray(res.output)) {
710
+ const urlResponse = await fetch(res.output[0]);
711
+ const blob = await urlResponse.blob();
712
+ return blob;
713
+ }
714
+ }
508
715
  const isValidOutput = res && res instanceof Blob;
509
716
  if (!isValidOutput) {
510
717
  throw new InferenceOutputError("Expected Blob");
@@ -512,19 +719,6 @@ async function textToImage(args, options) {
512
719
  return res;
513
720
  }
514
721
 
515
- // src/utils/base64FromBytes.ts
516
- function base64FromBytes(arr) {
517
- if (globalThis.Buffer) {
518
- return globalThis.Buffer.from(arr).toString("base64");
519
- } else {
520
- const bin = [];
521
- arr.forEach((byte) => {
522
- bin.push(String.fromCharCode(byte));
523
- });
524
- return globalThis.btoa(bin.join(""));
525
- }
526
- }
527
-
528
722
  // src/tasks/cv/imageToImage.ts
529
723
  async function imageToImage(args, options) {
530
724
  let reqArgs;
@@ -576,6 +770,36 @@ async function zeroShotImageClassification(args, options) {
576
770
  return res;
577
771
  }
578
772
 
773
+ // src/lib/getDefaultTask.ts
774
+ var taskCache = /* @__PURE__ */ new Map();
775
+ var CACHE_DURATION = 10 * 60 * 1e3;
776
+ var MAX_CACHE_ITEMS = 1e3;
777
+ async function getDefaultTask(model, accessToken, options) {
778
+ if (isUrl(model)) {
779
+ return null;
780
+ }
781
+ const key = `${model}:${accessToken}`;
782
+ let cachedTask = taskCache.get(key);
783
+ if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
784
+ taskCache.delete(key);
785
+ cachedTask = void 0;
786
+ }
787
+ if (cachedTask === void 0) {
788
+ const modelTask = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
789
+ headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
790
+ }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
791
+ if (!modelTask) {
792
+ return null;
793
+ }
794
+ cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
795
+ taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
796
+ if (taskCache.size > MAX_CACHE_ITEMS) {
797
+ taskCache.delete(taskCache.keys().next().value);
798
+ }
799
+ }
800
+ return cachedTask.task;
801
+ }
802
+
579
803
  // src/tasks/nlp/featureExtraction.ts
580
804
  async function featureExtraction(args, options) {
581
805
  const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : void 0;
@@ -697,17 +921,33 @@ function toArray(obj) {
697
921
 
698
922
  // src/tasks/nlp/textGeneration.ts
699
923
  async function textGeneration(args, options) {
700
- const res = toArray(
701
- await request(args, {
924
+ if (args.provider === "together") {
925
+ args.prompt = args.inputs;
926
+ const raw = await request(args, {
702
927
  ...options,
703
928
  taskHint: "text-generation"
704
- })
705
- );
706
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string");
707
- if (!isValidOutput) {
708
- throw new InferenceOutputError("Expected Array<{generated_text: string}>");
929
+ });
930
+ const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
931
+ if (!isValidOutput) {
932
+ throw new InferenceOutputError("Expected ChatCompletionOutput");
933
+ }
934
+ const completion = raw.choices[0];
935
+ return {
936
+ generated_text: completion.text
937
+ };
938
+ } else {
939
+ const res = toArray(
940
+ await request(args, {
941
+ ...options,
942
+ taskHint: "text-generation"
943
+ })
944
+ );
945
+ const isValidOutput = Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
946
+ if (!isValidOutput) {
947
+ throw new InferenceOutputError("Expected Array<{generated_text: string}>");
948
+ }
949
+ return res?.[0];
709
950
  }
710
- return res?.[0];
711
951
  }
712
952
 
713
953
  // src/tasks/nlp/textGenerationStream.ts
@@ -774,7 +1014,8 @@ async function chatCompletion(args, options) {
774
1014
  taskHint: "text-generation",
775
1015
  chatCompletion: true
776
1016
  });
777
- const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && typeof res?.system_fingerprint === "string" && typeof res?.usage === "object";
1017
+ const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai does not output a system_fingerprint
1018
+ (res.system_fingerprint === void 0 || typeof res.system_fingerprint === "string") && typeof res?.usage === "object";
778
1019
  if (!isValidOutput) {
779
1020
  throw new InferenceOutputError("Expected ChatCompletionOutput");
780
1021
  }
@@ -907,10 +1148,18 @@ var HfInferenceEndpoint = class {
907
1148
  }
908
1149
  }
909
1150
  };
1151
+
1152
+ // src/types.ts
1153
+ var INFERENCE_PROVIDERS = ["fal-ai", "replicate", "sambanova", "together", "hf-inference"];
910
1154
  export {
1155
+ FAL_AI_SUPPORTED_MODEL_IDS,
911
1156
  HfInference,
912
1157
  HfInferenceEndpoint,
1158
+ INFERENCE_PROVIDERS,
913
1159
  InferenceOutputError,
1160
+ REPLICATE_SUPPORTED_MODEL_IDS,
1161
+ SAMBANOVA_SUPPORTED_MODEL_IDS,
1162
+ TOGETHER_SUPPORTED_MODEL_IDS,
914
1163
  audioClassification,
915
1164
  audioToAudio,
916
1165
  automaticSpeechRecognition,
package/dist/src/config.d.ts ADDED
@@ -0,0 +1,3 @@
1
+ export declare const HF_HUB_URL = "https://huggingface.co";
2
+ export declare const HF_INFERENCE_API_URL = "https://api-inference.huggingface.co";
3
+ //# sourceMappingURL=config.d.ts.map
package/dist/src/config.d.ts.map ADDED
@@ -0,0 +1 @@
1
+ {"version":3,"file":"config.d.ts","sourceRoot":"","sources":["../../src/config.ts"],"names":[],"mappings":"AAAA,eAAO,MAAM,UAAU,2BAA2B,CAAC;AACnD,eAAO,MAAM,oBAAoB,yCAAyC,CAAC"}