@huggingface/inference 2.8.1 → 3.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. package/LICENSE +1 -1
  2. package/README.md +45 -17
  3. package/dist/index.cjs +388 -134
  4. package/dist/index.js +383 -134
  5. package/dist/src/config.d.ts +3 -0
  6. package/dist/src/config.d.ts.map +1 -0
  7. package/dist/src/index.d.ts +5 -0
  8. package/dist/src/index.d.ts.map +1 -1
  9. package/dist/src/lib/getDefaultTask.d.ts +0 -1
  10. package/dist/src/lib/getDefaultTask.d.ts.map +1 -1
  11. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  12. package/dist/src/providers/fal-ai.d.ts +6 -0
  13. package/dist/src/providers/fal-ai.d.ts.map +1 -0
  14. package/dist/src/providers/replicate.d.ts +6 -0
  15. package/dist/src/providers/replicate.d.ts.map +1 -0
  16. package/dist/src/providers/sambanova.d.ts +6 -0
  17. package/dist/src/providers/sambanova.d.ts.map +1 -0
  18. package/dist/src/providers/together.d.ts +12 -0
  19. package/dist/src/providers/together.d.ts.map +1 -0
  20. package/dist/src/providers/types.d.ts +4 -0
  21. package/dist/src/providers/types.d.ts.map +1 -0
  22. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  23. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  24. package/dist/src/tasks/custom/request.d.ts +1 -1
  25. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  26. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  27. package/dist/src/tasks/cv/textToImage.d.ts +8 -0
  28. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  29. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  30. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  31. package/dist/src/types.d.ts +16 -2
  32. package/dist/src/types.d.ts.map +1 -1
  33. package/package.json +2 -2
  34. package/src/config.ts +2 -0
  35. package/src/index.ts +5 -0
  36. package/src/lib/getDefaultTask.ts +1 -1
  37. package/src/lib/makeRequestOptions.ts +201 -59
  38. package/src/providers/fal-ai.ts +23 -0
  39. package/src/providers/replicate.ts +16 -0
  40. package/src/providers/sambanova.ts +23 -0
  41. package/src/providers/together.ts +60 -0
  42. package/src/providers/types.ts +6 -0
  43. package/src/tasks/audio/automaticSpeechRecognition.ts +10 -1
  44. package/src/tasks/audio/textToSpeech.ts +17 -2
  45. package/src/tasks/custom/request.ts +12 -6
  46. package/src/tasks/custom/streamingRequest.ts +18 -3
  47. package/src/tasks/cv/textToImage.ts +44 -1
  48. package/src/tasks/nlp/chatCompletion.ts +2 -2
  49. package/src/tasks/nlp/textGeneration.ts +43 -9
  50. package/src/types.ts +20 -2
package/dist/index.cjs CHANGED
@@ -20,9 +20,14 @@ var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: tru
20
20
  // src/index.ts
21
21
  var src_exports = {};
22
22
  __export(src_exports, {
23
+ FAL_AI_SUPPORTED_MODEL_IDS: () => FAL_AI_SUPPORTED_MODEL_IDS,
23
24
  HfInference: () => HfInference,
24
25
  HfInferenceEndpoint: () => HfInferenceEndpoint,
26
+ INFERENCE_PROVIDERS: () => INFERENCE_PROVIDERS,
25
27
  InferenceOutputError: () => InferenceOutputError,
28
+ REPLICATE_SUPPORTED_MODEL_IDS: () => REPLICATE_SUPPORTED_MODEL_IDS,
29
+ SAMBANOVA_SUPPORTED_MODEL_IDS: () => SAMBANOVA_SUPPORTED_MODEL_IDS,
30
+ TOGETHER_SUPPORTED_MODEL_IDS: () => TOGETHER_SUPPORTED_MODEL_IDS,
26
31
  audioClassification: () => audioClassification,
27
32
  audioToAudio: () => audioToAudio,
28
33
  automaticSpeechRecognition: () => automaticSpeechRecognition,
@@ -93,131 +98,175 @@ __export(tasks_exports, {
93
98
  zeroShotImageClassification: () => zeroShotImageClassification
94
99
  });
95
100
 
96
- // src/utils/pick.ts
97
- function pick(o, props) {
98
- return Object.assign(
99
- {},
100
- ...props.map((prop) => {
101
- if (o[prop] !== void 0) {
102
- return { [prop]: o[prop] };
103
- }
104
- })
105
- );
106
- }
101
+ // src/config.ts
102
+ var HF_HUB_URL = "https://huggingface.co";
103
+ var HF_INFERENCE_API_URL = "https://api-inference.huggingface.co";
107
104
 
108
- // src/utils/typedInclude.ts
109
- function typedInclude(arr, v) {
110
- return arr.includes(v);
111
- }
105
+ // src/providers/fal-ai.ts
106
+ var FAL_AI_API_BASE_URL = "https://fal.run";
107
+ var FAL_AI_SUPPORTED_MODEL_IDS = {
108
+ "text-to-image": {
109
+ "black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
110
+ "black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
111
+ "playgroundai/playground-v2.5-1024px-aesthetic": "fal-ai/playground-v25",
112
+ "ByteDance/SDXL-Lightning": "fal-ai/lightning-models",
113
+ "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "fal-ai/pixart-sigma",
114
+ "stabilityai/stable-diffusion-3-medium": "fal-ai/stable-diffusion-v3-medium",
115
+ "Warlord-K/Sana-1024": "fal-ai/sana",
116
+ "fal/AuraFlow-v0.2": "fal-ai/aura-flow",
117
+ "stabilityai/stable-diffusion-3.5-large": "fal-ai/stable-diffusion-v35-large",
118
+ "Kwai-Kolors/Kolors": "fal-ai/kolors"
119
+ },
120
+ "automatic-speech-recognition": {
121
+ "openai/whisper-large-v3": "fal-ai/whisper"
122
+ }
123
+ };
112
124
 
113
- // src/utils/omit.ts
114
- function omit(o, props) {
115
- const propsArr = Array.isArray(props) ? props : [props];
116
- const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
117
- return pick(o, letsKeep);
118
- }
125
+ // src/providers/replicate.ts
126
+ var REPLICATE_API_BASE_URL = "https://api.replicate.com";
127
+ var REPLICATE_SUPPORTED_MODEL_IDS = {
128
+ "text-to-image": {
129
+ "black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
130
+ "ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637"
131
+ },
132
+ "text-to-speech": {
133
+ "OuteAI/OuteTTS-0.3-500M": "jbilcke/oute-tts:39a59319327b27327fa3095149c5a746e7f2aee18c75055c3368237a6503cd26"
134
+ }
135
+ };
136
+
137
+ // src/providers/sambanova.ts
138
+ var SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
139
+ var SAMBANOVA_SUPPORTED_MODEL_IDS = {
140
+ /** Chat completion / conversational */
141
+ conversational: {
142
+ "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
143
+ "Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
144
+ "Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
145
+ "meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct",
146
+ "meta-llama/Llama-3.2-1B-Instruct": "Meta-Llama-3.2-1B-Instruct",
147
+ "meta-llama/Llama-3.2-3B-Instruct": "Meta-Llama-3.2-3B-Instruct",
148
+ "meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
149
+ "meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
150
+ "meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
151
+ "meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
152
+ "meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
153
+ "meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B"
154
+ }
155
+ };
156
+
157
+ // src/providers/together.ts
158
+ var TOGETHER_API_BASE_URL = "https://api.together.xyz";
159
+ var TOGETHER_SUPPORTED_MODEL_IDS = {
160
+ "text-to-image": {
161
+ "black-forest-labs/FLUX.1-Canny-dev": "black-forest-labs/FLUX.1-canny",
162
+ "black-forest-labs/FLUX.1-Depth-dev": "black-forest-labs/FLUX.1-depth",
163
+ "black-forest-labs/FLUX.1-dev": "black-forest-labs/FLUX.1-dev",
164
+ "black-forest-labs/FLUX.1-Redux-dev": "black-forest-labs/FLUX.1-redux",
165
+ "black-forest-labs/FLUX.1-schnell": "black-forest-labs/FLUX.1-pro",
166
+ "stabilityai/stable-diffusion-xl-base-1.0": "stabilityai/stable-diffusion-xl-base-1.0"
167
+ },
168
+ conversational: {
169
+ "databricks/dbrx-instruct": "databricks/dbrx-instruct",
170
+ "deepseek-ai/DeepSeek-R1": "deepseek-ai/DeepSeek-R1",
171
+ "deepseek-ai/DeepSeek-V3": "deepseek-ai/DeepSeek-V3",
172
+ "deepseek-ai/deepseek-llm-67b-chat": "deepseek-ai/deepseek-llm-67b-chat",
173
+ "google/gemma-2-9b-it": "google/gemma-2-9b-it",
174
+ "google/gemma-2b-it": "google/gemma-2-27b-it",
175
+ "llava-hf/llava-v1.6-mistral-7b-hf": "llava-hf/llava-v1.6-mistral-7b-hf",
176
+ "meta-llama/Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
177
+ "meta-llama/Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
178
+ "meta-llama/Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
179
+ "meta-llama/Llama-3.2-11B-Vision-Instruct": "meta-llama/Llama-Vision-Free",
180
+ "meta-llama/Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
181
+ "meta-llama/Llama-3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
182
+ "meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
183
+ "meta-llama/Meta-Llama-3-70B-Instruct": "meta-llama/Llama-3-70b-chat-hf",
184
+ "meta-llama/Meta-Llama-3-8B-Instruct": "togethercomputer/Llama-3-8b-chat-hf-int4",
185
+ "meta-llama/Meta-Llama-3.1-405B-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
186
+ "meta-llama/Meta-Llama-3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
187
+ "meta-llama/Meta-Llama-3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K",
188
+ "microsoft/WizardLM-2-8x22B": "microsoft/WizardLM-2-8x22B",
189
+ "mistralai/Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
190
+ "mistralai/Mixtral-8x22B-Instruct-v0.1": "mistralai/Mixtral-8x22B-Instruct-v0.1",
191
+ "mistralai/Mixtral-8x7B-Instruct-v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1",
192
+ "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
193
+ "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
194
+ "Qwen/Qwen2-72B-Instruct": "Qwen/Qwen2-72B-Instruct",
195
+ "Qwen/Qwen2.5-72B-Instruct": "Qwen/Qwen2.5-72B-Instruct-Turbo",
196
+ "Qwen/Qwen2.5-7B-Instruct": "Qwen/Qwen2.5-7B-Instruct-Turbo",
197
+ "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen/Qwen2.5-Coder-32B-Instruct",
198
+ "Qwen/QwQ-32B-Preview": "Qwen/QwQ-32B-Preview",
199
+ "scb10x/llama-3-typhoon-v1.5-8b-instruct": "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct",
200
+ "scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": "scb10x/scb10x-llama3-typhoon-v1-5x-4f316"
201
+ },
202
+ "text-generation": {
203
+ "meta-llama/Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B",
204
+ "mistralai/Mixtral-8x7B-v0.1": "mistralai/Mixtral-8x7B-v0.1"
205
+ }
206
+ };
119
207
 
120
208
  // src/lib/isUrl.ts
121
209
  function isUrl(modelOrUrl) {
122
210
  return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
123
211
  }
124
212
 
125
- // src/lib/getDefaultTask.ts
126
- var taskCache = /* @__PURE__ */ new Map();
127
- var CACHE_DURATION = 10 * 60 * 1e3;
128
- var MAX_CACHE_ITEMS = 1e3;
129
- var HF_HUB_URL = "https://huggingface.co";
130
- async function getDefaultTask(model, accessToken, options) {
131
- if (isUrl(model)) {
132
- return null;
133
- }
134
- const key = `${model}:${accessToken}`;
135
- let cachedTask = taskCache.get(key);
136
- if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
137
- taskCache.delete(key);
138
- cachedTask = void 0;
139
- }
140
- if (cachedTask === void 0) {
141
- const modelTask = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
142
- headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
143
- }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
144
- if (!modelTask) {
145
- return null;
146
- }
147
- cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
148
- taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
149
- if (taskCache.size > MAX_CACHE_ITEMS) {
150
- taskCache.delete(taskCache.keys().next().value);
151
- }
152
- }
153
- return cachedTask.task;
154
- }
155
-
156
213
  // src/lib/makeRequestOptions.ts
157
- var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
214
+ var HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_HUB_URL}/api/inference-proxy/{{PROVIDER}}`;
158
215
  var tasks = null;
159
216
  async function makeRequestOptions(args, options) {
160
- const { accessToken, endpointUrl, ...otherArgs } = args;
161
- let { model } = args;
162
- const {
163
- forceTask: task,
164
- includeCredentials,
165
- taskHint,
166
- wait_for_model,
167
- use_cache,
168
- dont_load_model,
169
- chatCompletion: chatCompletion2
170
- } = options ?? {};
171
- const headers = {};
172
- if (accessToken) {
173
- headers["Authorization"] = `Bearer ${accessToken}`;
217
+ const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
218
+ let otherArgs = remainingArgs;
219
+ const provider = maybeProvider ?? "hf-inference";
220
+ const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion: chatCompletion2 } = options ?? {};
221
+ if (endpointUrl && provider !== "hf-inference") {
222
+ throw new Error(`Cannot use endpointUrl with a third-party provider.`);
174
223
  }
175
- if (!model && !tasks && taskHint) {
176
- const res = await fetch(`${HF_HUB_URL}/api/tasks`);
177
- if (res.ok) {
178
- tasks = await res.json();
179
- }
224
+ if (forceTask && provider !== "hf-inference") {
225
+ throw new Error(`Cannot use forceTask with a third-party provider.`);
226
+ }
227
+ if (maybeModel && isUrl(maybeModel)) {
228
+ throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
180
229
  }
181
- if (!model && tasks && taskHint) {
182
- const taskInfo = tasks[taskHint];
183
- if (taskInfo) {
184
- model = taskInfo.models[0].id;
230
+ let model;
231
+ if (!maybeModel) {
232
+ if (taskHint) {
233
+ model = mapModel({ model: await loadDefaultModel(taskHint), provider, taskHint, chatCompletion: chatCompletion2 });
234
+ } else {
235
+ throw new Error("No model provided, and no default model found for this task");
185
236
  }
237
+ } else {
238
+ model = mapModel({ model: maybeModel, provider, taskHint, chatCompletion: chatCompletion2 });
186
239
  }
187
- if (!model) {
188
- throw new Error("No model provided, and no default model found for this task");
240
+ const authMethod = accessToken ? accessToken.startsWith("hf_") ? "hf-token" : "provider-key" : includeCredentials === "include" ? "credentials-include" : "none";
241
+ const url = endpointUrl ? chatCompletion2 ? endpointUrl + `/v1/chat/completions` : endpointUrl : makeUrl({
242
+ authMethod,
243
+ chatCompletion: chatCompletion2 ?? false,
244
+ forceTask,
245
+ model,
246
+ provider: provider ?? "hf-inference",
247
+ taskHint
248
+ });
249
+ const headers = {};
250
+ if (accessToken) {
251
+ headers["Authorization"] = provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
189
252
  }
190
253
  const binary = "data" in args && !!args.data;
191
254
  if (!binary) {
192
255
  headers["Content-Type"] = "application/json";
193
256
  }
194
- if (wait_for_model) {
195
- headers["X-Wait-For-Model"] = "true";
196
- }
197
- if (use_cache === false) {
198
- headers["X-Use-Cache"] = "false";
199
- }
200
- if (dont_load_model) {
201
- headers["X-Load-Model"] = "0";
202
- }
203
- let url = (() => {
204
- if (endpointUrl && isUrl(model)) {
205
- throw new TypeError("Both model and endpointUrl cannot be URLs");
257
+ if (provider === "hf-inference") {
258
+ if (wait_for_model) {
259
+ headers["X-Wait-For-Model"] = "true";
206
260
  }
207
- if (isUrl(model)) {
208
- console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead");
209
- return model;
261
+ if (use_cache === false) {
262
+ headers["X-Use-Cache"] = "false";
210
263
  }
211
- if (endpointUrl) {
212
- return endpointUrl;
264
+ if (dont_load_model) {
265
+ headers["X-Load-Model"] = "0";
213
266
  }
214
- if (task) {
215
- return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
216
- }
217
- return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
218
- })();
219
- if (chatCompletion2 && !url.endsWith("/chat/completions")) {
220
- url += "/v1/chat/completions";
267
+ }
268
+ if (provider === "replicate") {
269
+ headers["Prefer"] = "wait";
221
270
  }
222
271
  let credentials;
223
272
  if (typeof includeCredentials === "string") {
@@ -225,17 +274,110 @@ async function makeRequestOptions(args, options) {
225
274
  } else if (includeCredentials === true) {
226
275
  credentials = "include";
227
276
  }
277
+ if (provider === "replicate") {
278
+ const version = model.includes(":") ? model.split(":")[1] : void 0;
279
+ otherArgs = { input: otherArgs, version };
280
+ }
228
281
  const info = {
229
282
  headers,
230
283
  method: "POST",
231
284
  body: binary ? args.data : JSON.stringify({
232
- ...otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs
285
+ ...otherArgs,
286
+ ...chatCompletion2 || provider === "together" ? { model } : void 0
233
287
  }),
234
- ...credentials && { credentials },
288
+ ...credentials ? { credentials } : void 0,
235
289
  signal: options?.signal
236
290
  };
237
291
  return { url, info };
238
292
  }
293
+ function mapModel(params) {
294
+ if (params.provider === "hf-inference") {
295
+ return params.model;
296
+ }
297
+ if (!params.taskHint) {
298
+ throw new Error("taskHint must be specified when using a third-party provider");
299
+ }
300
+ const task = params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint;
301
+ const model = (() => {
302
+ switch (params.provider) {
303
+ case "fal-ai":
304
+ return FAL_AI_SUPPORTED_MODEL_IDS[task]?.[params.model];
305
+ case "replicate":
306
+ return REPLICATE_SUPPORTED_MODEL_IDS[task]?.[params.model];
307
+ case "sambanova":
308
+ return SAMBANOVA_SUPPORTED_MODEL_IDS[task]?.[params.model];
309
+ case "together":
310
+ return TOGETHER_SUPPORTED_MODEL_IDS[task]?.[params.model];
311
+ }
312
+ })();
313
+ if (!model) {
314
+ throw new Error(`Model ${params.model} is not supported for task ${task} and provider ${params.provider}`);
315
+ }
316
+ return model;
317
+ }
318
+ function makeUrl(params) {
319
+ if (params.authMethod === "none" && params.provider !== "hf-inference") {
320
+ throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
321
+ }
322
+ const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
323
+ switch (params.provider) {
324
+ case "fal-ai": {
325
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FAL_AI_API_BASE_URL;
326
+ return `${baseUrl}/${params.model}`;
327
+ }
328
+ case "replicate": {
329
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : REPLICATE_API_BASE_URL;
330
+ if (params.model.includes(":")) {
331
+ return `${baseUrl}/v1/predictions`;
332
+ }
333
+ return `${baseUrl}/v1/models/${params.model}/predictions`;
334
+ }
335
+ case "sambanova": {
336
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : SAMBANOVA_API_BASE_URL;
337
+ if (params.taskHint === "text-generation" && params.chatCompletion) {
338
+ return `${baseUrl}/v1/chat/completions`;
339
+ }
340
+ return baseUrl;
341
+ }
342
+ case "together": {
343
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : TOGETHER_API_BASE_URL;
344
+ if (params.taskHint === "text-to-image") {
345
+ return `${baseUrl}/v1/images/generations`;
346
+ }
347
+ if (params.taskHint === "text-generation") {
348
+ if (params.chatCompletion) {
349
+ return `${baseUrl}/v1/chat/completions`;
350
+ }
351
+ return `${baseUrl}/v1/completions`;
352
+ }
353
+ return baseUrl;
354
+ }
355
+ default: {
356
+ const url = params.forceTask ? `${HF_INFERENCE_API_URL}/pipeline/${params.forceTask}/${params.model}` : `${HF_INFERENCE_API_URL}/models/${params.model}`;
357
+ if (params.taskHint === "text-generation" && params.chatCompletion) {
358
+ return url + `/v1/chat/completions`;
359
+ }
360
+ return url;
361
+ }
362
+ }
363
+ }
364
+ async function loadDefaultModel(task) {
365
+ if (!tasks) {
366
+ tasks = await loadTaskInfo();
367
+ }
368
+ const taskInfo = tasks[task];
369
+ if ((taskInfo?.models.length ?? 0) <= 0) {
370
+ throw new Error(`No default model defined for task ${task}, please define the model explicitly.`);
371
+ }
372
+ return taskInfo.models[0].id;
373
+ }
374
+ async function loadTaskInfo() {
375
+ const res = await fetch(`${HF_HUB_URL}/api/tasks`);
376
+ if (!res.ok) {
377
+ throw new Error("Failed to load tasks definitions from Hugging Face Hub.");
378
+ }
379
+ return await res.json();
380
+ }
239
381
 
240
382
  // src/tasks/custom/request.ts
241
383
  async function request(args, options) {
@@ -248,16 +390,22 @@ async function request(args, options) {
248
390
  });
249
391
  }
250
392
  if (!response.ok) {
251
- if (response.headers.get("Content-Type")?.startsWith("application/json")) {
393
+ const contentType = response.headers.get("Content-Type");
394
+ if (["application/json", "application/problem+json"].some((ct) => contentType?.startsWith(ct))) {
252
395
  const output = await response.json();
253
396
  if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
254
- throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
397
+ throw new Error(
398
+ `Server ${args.model} does not seem to support chat completion. Error: ${JSON.stringify(output.error)}`
399
+ );
255
400
  }
256
- if (output.error) {
257
- throw new Error(JSON.stringify(output.error));
401
+ if (output.error || output.detail) {
402
+ throw new Error(JSON.stringify(output.error ?? output.detail));
403
+ } else {
404
+ throw new Error(output);
258
405
  }
259
406
  }
260
- throw new Error("An error occurred while fetching the blob");
407
+ const message = contentType?.startsWith("text/plain;") ? await response.text() : void 0;
408
+ throw new Error(message ?? "An error occurred while fetching the blob");
261
409
  }
262
410
  if (response.headers.get("Content-Type")?.startsWith("application/json")) {
263
411
  return await response.json();
@@ -380,9 +528,12 @@ async function* streamingRequest(args, options) {
380
528
  if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
381
529
  throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
382
530
  }
383
- if (output.error) {
531
+ if (typeof output.error === "string") {
384
532
  throw new Error(output.error);
385
533
  }
534
+ if (output.error && "message" in output.error && typeof output.error.message === "string") {
535
+ throw new Error(output.error.message);
536
+ }
386
537
  }
387
538
  throw new Error(`Server response contains error: ${response.status}`);
388
539
  }
@@ -411,8 +562,9 @@ async function* streamingRequest(args, options) {
411
562
  try {
412
563
  while (true) {
413
564
  const { done, value } = await reader.read();
414
- if (done)
565
+ if (done) {
415
566
  return;
567
+ }
416
568
  onChunk(value);
417
569
  for (const event of events) {
418
570
  if (event.data.length > 0) {
@@ -421,7 +573,8 @@ async function* streamingRequest(args, options) {
421
573
  }
422
574
  const data = JSON.parse(event.data);
423
575
  if (typeof data === "object" && data !== null && "error" in data) {
424
- throw new Error(data.error);
576
+ const errorStr = typeof data.error === "string" ? data.error : typeof data.error === "object" && data.error && "message" in data.error && typeof data.error.message === "string" ? data.error.message : JSON.stringify(data.error);
577
+ throw new Error(`Error forwarded from backend: ` + errorStr);
425
578
  }
426
579
  yield data;
427
580
  }
@@ -456,8 +609,29 @@ async function audioClassification(args, options) {
456
609
  return res;
457
610
  }
458
611
 
612
+ // src/utils/base64FromBytes.ts
613
+ function base64FromBytes(arr) {
614
+ if (globalThis.Buffer) {
615
+ return globalThis.Buffer.from(arr).toString("base64");
616
+ } else {
617
+ const bin = [];
618
+ arr.forEach((byte) => {
619
+ bin.push(String.fromCharCode(byte));
620
+ });
621
+ return globalThis.btoa(bin.join(""));
622
+ }
623
+ }
624
+
459
625
  // src/tasks/audio/automaticSpeechRecognition.ts
460
626
  async function automaticSpeechRecognition(args, options) {
627
+ if (args.provider === "fal-ai") {
628
+ const contentType = args.data instanceof Blob ? args.data.type : "audio/mpeg";
629
+ const base64audio = base64FromBytes(
630
+ new Uint8Array(args.data instanceof ArrayBuffer ? args.data : await args.data.arrayBuffer())
631
+ );
632
+ args.audio_url = `data:${contentType};base64,${base64audio}`;
633
+ delete args.data;
634
+ }
461
635
  const res = await request(args, {
462
636
  ...options,
463
637
  taskHint: "automatic-speech-recognition"
@@ -475,6 +649,19 @@ async function textToSpeech(args, options) {
475
649
  ...options,
476
650
  taskHint: "text-to-speech"
477
651
  });
652
+ if (res && typeof res === "object") {
653
+ if ("output" in res) {
654
+ if (typeof res.output === "string") {
655
+ const urlResponse = await fetch(res.output);
656
+ const blob = await urlResponse.blob();
657
+ return blob;
658
+ } else if (Array.isArray(res.output)) {
659
+ const urlResponse = await fetch(res.output[0]);
660
+ const blob = await urlResponse.blob();
661
+ return blob;
662
+ }
663
+ }
664
+ }
478
665
  const isValidOutput = res && res instanceof Blob;
479
666
  if (!isValidOutput) {
480
667
  throw new InferenceOutputError("Expected Blob");
@@ -554,10 +741,35 @@ async function objectDetection(args, options) {
554
741
 
555
742
  // src/tasks/cv/textToImage.ts
556
743
  async function textToImage(args, options) {
744
+ if (args.provider === "together" || args.provider === "fal-ai") {
745
+ args.prompt = args.inputs;
746
+ delete args.inputs;
747
+ args.response_format = "base64";
748
+ } else if (args.provider === "replicate") {
749
+ args.prompt = args.inputs;
750
+ delete args.inputs;
751
+ }
557
752
  const res = await request(args, {
558
753
  ...options,
559
754
  taskHint: "text-to-image"
560
755
  });
756
+ if (res && typeof res === "object") {
757
+ if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
758
+ const image = await fetch(res.images[0].url);
759
+ return await image.blob();
760
+ }
761
+ if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
762
+ const base64Data = res.data[0].b64_json;
763
+ const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
764
+ const blob = await base64Response.blob();
765
+ return blob;
766
+ }
767
+ if ("output" in res && Array.isArray(res.output)) {
768
+ const urlResponse = await fetch(res.output[0]);
769
+ const blob = await urlResponse.blob();
770
+ return blob;
771
+ }
772
+ }
561
773
  const isValidOutput = res && res instanceof Blob;
562
774
  if (!isValidOutput) {
563
775
  throw new InferenceOutputError("Expected Blob");
@@ -565,19 +777,6 @@ async function textToImage(args, options) {
565
777
  return res;
566
778
  }
567
779
 
568
- // src/utils/base64FromBytes.ts
569
- function base64FromBytes(arr) {
570
- if (globalThis.Buffer) {
571
- return globalThis.Buffer.from(arr).toString("base64");
572
- } else {
573
- const bin = [];
574
- arr.forEach((byte) => {
575
- bin.push(String.fromCharCode(byte));
576
- });
577
- return globalThis.btoa(bin.join(""));
578
- }
579
- }
580
-
581
780
  // src/tasks/cv/imageToImage.ts
582
781
  async function imageToImage(args, options) {
583
782
  let reqArgs;
@@ -629,6 +828,36 @@ async function zeroShotImageClassification(args, options) {
629
828
  return res;
630
829
  }
631
830
 
831
+ // src/lib/getDefaultTask.ts
832
+ var taskCache = /* @__PURE__ */ new Map();
833
+ var CACHE_DURATION = 10 * 60 * 1e3;
834
+ var MAX_CACHE_ITEMS = 1e3;
835
+ async function getDefaultTask(model, accessToken, options) {
836
+ if (isUrl(model)) {
837
+ return null;
838
+ }
839
+ const key = `${model}:${accessToken}`;
840
+ let cachedTask = taskCache.get(key);
841
+ if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
842
+ taskCache.delete(key);
843
+ cachedTask = void 0;
844
+ }
845
+ if (cachedTask === void 0) {
846
+ const modelTask = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
847
+ headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
848
+ }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
849
+ if (!modelTask) {
850
+ return null;
851
+ }
852
+ cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
853
+ taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
854
+ if (taskCache.size > MAX_CACHE_ITEMS) {
855
+ taskCache.delete(taskCache.keys().next().value);
856
+ }
857
+ }
858
+ return cachedTask.task;
859
+ }
860
+
632
861
  // src/tasks/nlp/featureExtraction.ts
633
862
  async function featureExtraction(args, options) {
634
863
  const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : void 0;
@@ -750,17 +979,33 @@ function toArray(obj) {
750
979
 
751
980
  // src/tasks/nlp/textGeneration.ts
752
981
  async function textGeneration(args, options) {
753
- const res = toArray(
754
- await request(args, {
982
+ if (args.provider === "together") {
983
+ args.prompt = args.inputs;
984
+ const raw = await request(args, {
755
985
  ...options,
756
986
  taskHint: "text-generation"
757
- })
758
- );
759
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string");
760
- if (!isValidOutput) {
761
- throw new InferenceOutputError("Expected Array<{generated_text: string}>");
987
+ });
988
+ const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
989
+ if (!isValidOutput) {
990
+ throw new InferenceOutputError("Expected ChatCompletionOutput");
991
+ }
992
+ const completion = raw.choices[0];
993
+ return {
994
+ generated_text: completion.text
995
+ };
996
+ } else {
997
+ const res = toArray(
998
+ await request(args, {
999
+ ...options,
1000
+ taskHint: "text-generation"
1001
+ })
1002
+ );
1003
+ const isValidOutput = Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
1004
+ if (!isValidOutput) {
1005
+ throw new InferenceOutputError("Expected Array<{generated_text: string}>");
1006
+ }
1007
+ return res?.[0];
762
1008
  }
763
- return res?.[0];
764
1009
  }
765
1010
 
766
1011
  // src/tasks/nlp/textGenerationStream.ts
@@ -827,7 +1072,8 @@ async function chatCompletion(args, options) {
827
1072
  taskHint: "text-generation",
828
1073
  chatCompletion: true
829
1074
  });
830
- const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && typeof res?.system_fingerprint === "string" && typeof res?.usage === "object";
1075
+ const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai does not output a system_fingerprint
1076
+ (res.system_fingerprint === void 0 || typeof res.system_fingerprint === "string") && typeof res?.usage === "object";
831
1077
  if (!isValidOutput) {
832
1078
  throw new InferenceOutputError("Expected ChatCompletionOutput");
833
1079
  }
@@ -960,11 +1206,19 @@ var HfInferenceEndpoint = class {
960
1206
  }
961
1207
  }
962
1208
  };
1209
+
1210
+ // src/types.ts
1211
+ var INFERENCE_PROVIDERS = ["fal-ai", "replicate", "sambanova", "together", "hf-inference"];
963
1212
  // Annotate the CommonJS export names for ESM import in node:
964
1213
  0 && (module.exports = {
1214
+ FAL_AI_SUPPORTED_MODEL_IDS,
965
1215
  HfInference,
966
1216
  HfInferenceEndpoint,
1217
+ INFERENCE_PROVIDERS,
967
1218
  InferenceOutputError,
1219
+ REPLICATE_SUPPORTED_MODEL_IDS,
1220
+ SAMBANOVA_SUPPORTED_MODEL_IDS,
1221
+ TOGETHER_SUPPORTED_MODEL_IDS,
968
1222
  audioClassification,
969
1223
  audioToAudio,
970
1224
  automaticSpeechRecognition,