@huggingface/inference 2.8.0 → 3.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48):
  1. package/LICENSE +1 -1
  2. package/README.md +39 -16
  3. package/dist/index.cjs +364 -134
  4. package/dist/index.js +359 -134
  5. package/dist/src/config.d.ts +3 -0
  6. package/dist/src/config.d.ts.map +1 -0
  7. package/dist/src/index.d.ts +5 -0
  8. package/dist/src/index.d.ts.map +1 -1
  9. package/dist/src/lib/getDefaultTask.d.ts +0 -1
  10. package/dist/src/lib/getDefaultTask.d.ts.map +1 -1
  11. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  12. package/dist/src/providers/fal-ai.d.ts +6 -0
  13. package/dist/src/providers/fal-ai.d.ts.map +1 -0
  14. package/dist/src/providers/replicate.d.ts +6 -0
  15. package/dist/src/providers/replicate.d.ts.map +1 -0
  16. package/dist/src/providers/sambanova.d.ts +6 -0
  17. package/dist/src/providers/sambanova.d.ts.map +1 -0
  18. package/dist/src/providers/together.d.ts +12 -0
  19. package/dist/src/providers/together.d.ts.map +1 -0
  20. package/dist/src/providers/types.d.ts +4 -0
  21. package/dist/src/providers/types.d.ts.map +1 -0
  22. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  23. package/dist/src/tasks/custom/request.d.ts +1 -1
  24. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  25. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  26. package/dist/src/tasks/cv/textToImage.d.ts +8 -0
  27. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  28. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  29. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  30. package/dist/src/types.d.ts +16 -2
  31. package/dist/src/types.d.ts.map +1 -1
  32. package/package.json +2 -2
  33. package/src/config.ts +2 -0
  34. package/src/index.ts +5 -0
  35. package/src/lib/getDefaultTask.ts +1 -1
  36. package/src/lib/makeRequestOptions.ts +199 -59
  37. package/src/providers/fal-ai.ts +15 -0
  38. package/src/providers/replicate.ts +16 -0
  39. package/src/providers/sambanova.ts +23 -0
  40. package/src/providers/together.ts +58 -0
  41. package/src/providers/types.ts +6 -0
  42. package/src/tasks/audio/automaticSpeechRecognition.ts +10 -1
  43. package/src/tasks/custom/request.ts +12 -6
  44. package/src/tasks/custom/streamingRequest.ts +18 -3
  45. package/src/tasks/cv/textToImage.ts +44 -1
  46. package/src/tasks/nlp/chatCompletion.ts +2 -2
  47. package/src/tasks/nlp/textGeneration.ts +43 -9
  48. package/src/types.ts +20 -2
package/dist/index.cjs CHANGED
@@ -20,9 +20,14 @@ var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: tru
20
20
  // src/index.ts
21
21
  var src_exports = {};
22
22
  __export(src_exports, {
23
+ FAL_AI_SUPPORTED_MODEL_IDS: () => FAL_AI_SUPPORTED_MODEL_IDS,
23
24
  HfInference: () => HfInference,
24
25
  HfInferenceEndpoint: () => HfInferenceEndpoint,
26
+ INFERENCE_PROVIDERS: () => INFERENCE_PROVIDERS,
25
27
  InferenceOutputError: () => InferenceOutputError,
28
+ REPLICATE_SUPPORTED_MODEL_IDS: () => REPLICATE_SUPPORTED_MODEL_IDS,
29
+ SAMBANOVA_SUPPORTED_MODEL_IDS: () => SAMBANOVA_SUPPORTED_MODEL_IDS,
30
+ TOGETHER_SUPPORTED_MODEL_IDS: () => TOGETHER_SUPPORTED_MODEL_IDS,
26
31
  audioClassification: () => audioClassification,
27
32
  audioToAudio: () => audioToAudio,
28
33
  automaticSpeechRecognition: () => automaticSpeechRecognition,
@@ -93,131 +98,164 @@ __export(tasks_exports, {
93
98
  zeroShotImageClassification: () => zeroShotImageClassification
94
99
  });
95
100
 
96
- // src/utils/pick.ts
97
- function pick(o, props) {
98
- return Object.assign(
99
- {},
100
- ...props.map((prop) => {
101
- if (o[prop] !== void 0) {
102
- return { [prop]: o[prop] };
103
- }
104
- })
105
- );
106
- }
101
+ // src/config.ts
102
+ var HF_HUB_URL = "https://huggingface.co";
103
+ var HF_INFERENCE_API_URL = "https://api-inference.huggingface.co";
107
104
 
108
- // src/utils/typedInclude.ts
109
- function typedInclude(arr, v) {
110
- return arr.includes(v);
111
- }
105
+ // src/providers/fal-ai.ts
106
+ var FAL_AI_API_BASE_URL = "https://fal.run";
107
+ var FAL_AI_SUPPORTED_MODEL_IDS = {
108
+ "text-to-image": {
109
+ "black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
110
+ "black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev"
111
+ },
112
+ "automatic-speech-recognition": {
113
+ "openai/whisper-large-v3": "fal-ai/whisper"
114
+ }
115
+ };
112
116
 
113
- // src/utils/omit.ts
114
- function omit(o, props) {
115
- const propsArr = Array.isArray(props) ? props : [props];
116
- const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
117
- return pick(o, letsKeep);
118
- }
117
+ // src/providers/replicate.ts
118
+ var REPLICATE_API_BASE_URL = "https://api.replicate.com";
119
+ var REPLICATE_SUPPORTED_MODEL_IDS = {
120
+ "text-to-image": {
121
+ "black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
122
+ "ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637"
123
+ }
124
+ // "text-to-speech": {
125
+ // "SWivid/F5-TTS": "x-lance/f5-tts:87faf6dd7a692dd82043f662e76369cab126a2cf1937e25a9d41e0b834fd230e"
126
+ // },
127
+ };
128
+
129
+ // src/providers/sambanova.ts
130
+ var SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
131
+ var SAMBANOVA_SUPPORTED_MODEL_IDS = {
132
+ /** Chat completion / conversational */
133
+ conversational: {
134
+ "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
135
+ "Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
136
+ "Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
137
+ "meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct",
138
+ "meta-llama/Llama-3.2-1B-Instruct": "Meta-Llama-3.2-1B-Instruct",
139
+ "meta-llama/Llama-3.2-3B-Instruct": "Meta-Llama-3.2-3B-Instruct",
140
+ "meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
141
+ "meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
142
+ "meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
143
+ "meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
144
+ "meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
145
+ "meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B"
146
+ }
147
+ };
148
+
149
+ // src/providers/together.ts
150
+ var TOGETHER_API_BASE_URL = "https://api.together.xyz";
151
+ var TOGETHER_SUPPORTED_MODEL_IDS = {
152
+ "text-to-image": {
153
+ "black-forest-labs/FLUX.1-Canny-dev": "black-forest-labs/FLUX.1-canny",
154
+ "black-forest-labs/FLUX.1-Depth-dev": "black-forest-labs/FLUX.1-depth",
155
+ "black-forest-labs/FLUX.1-dev": "black-forest-labs/FLUX.1-dev",
156
+ "black-forest-labs/FLUX.1-Redux-dev": "black-forest-labs/FLUX.1-redux",
157
+ "black-forest-labs/FLUX.1-schnell": "black-forest-labs/FLUX.1-pro",
158
+ "stabilityai/stable-diffusion-xl-base-1.0": "stabilityai/stable-diffusion-xl-base-1.0"
159
+ },
160
+ conversational: {
161
+ "databricks/dbrx-instruct": "databricks/dbrx-instruct",
162
+ "deepseek-ai/deepseek-llm-67b-chat": "deepseek-ai/deepseek-llm-67b-chat",
163
+ "google/gemma-2-9b-it": "google/gemma-2-9b-it",
164
+ "google/gemma-2b-it": "google/gemma-2-27b-it",
165
+ "llava-hf/llava-v1.6-mistral-7b-hf": "llava-hf/llava-v1.6-mistral-7b-hf",
166
+ "meta-llama/Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
167
+ "meta-llama/Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
168
+ "meta-llama/Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
169
+ "meta-llama/Llama-3.2-11B-Vision-Instruct": "meta-llama/Llama-Vision-Free",
170
+ "meta-llama/Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
171
+ "meta-llama/Llama-3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
172
+ "meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
173
+ "meta-llama/Meta-Llama-3-70B-Instruct": "meta-llama/Llama-3-70b-chat-hf",
174
+ "meta-llama/Meta-Llama-3-8B-Instruct": "togethercomputer/Llama-3-8b-chat-hf-int4",
175
+ "meta-llama/Meta-Llama-3.1-405B-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
176
+ "meta-llama/Meta-Llama-3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
177
+ "meta-llama/Meta-Llama-3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K",
178
+ "microsoft/WizardLM-2-8x22B": "microsoft/WizardLM-2-8x22B",
179
+ "mistralai/Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
180
+ "mistralai/Mixtral-8x22B-Instruct-v0.1": "mistralai/Mixtral-8x22B-Instruct-v0.1",
181
+ "mistralai/Mixtral-8x7B-Instruct-v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1",
182
+ "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
183
+ "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
184
+ "Qwen/Qwen2-72B-Instruct": "Qwen/Qwen2-72B-Instruct",
185
+ "Qwen/Qwen2.5-72B-Instruct": "Qwen/Qwen2.5-72B-Instruct-Turbo",
186
+ "Qwen/Qwen2.5-7B-Instruct": "Qwen/Qwen2.5-7B-Instruct-Turbo",
187
+ "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen/Qwen2.5-Coder-32B-Instruct",
188
+ "Qwen/QwQ-32B-Preview": "Qwen/QwQ-32B-Preview",
189
+ "scb10x/llama-3-typhoon-v1.5-8b-instruct": "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct",
190
+ "scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": "scb10x/scb10x-llama3-typhoon-v1-5x-4f316"
191
+ },
192
+ "text-generation": {
193
+ "meta-llama/Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B",
194
+ "mistralai/Mixtral-8x7B-v0.1": "mistralai/Mixtral-8x7B-v0.1"
195
+ }
196
+ };
119
197
 
120
198
  // src/lib/isUrl.ts
121
199
  function isUrl(modelOrUrl) {
122
200
  return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
123
201
  }
124
202
 
125
- // src/lib/getDefaultTask.ts
126
- var taskCache = /* @__PURE__ */ new Map();
127
- var CACHE_DURATION = 10 * 60 * 1e3;
128
- var MAX_CACHE_ITEMS = 1e3;
129
- var HF_HUB_URL = "https://huggingface.co";
130
- async function getDefaultTask(model, accessToken, options) {
131
- if (isUrl(model)) {
132
- return null;
133
- }
134
- const key = `${model}:${accessToken}`;
135
- let cachedTask = taskCache.get(key);
136
- if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
137
- taskCache.delete(key);
138
- cachedTask = void 0;
139
- }
140
- if (cachedTask === void 0) {
141
- const modelTask = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
142
- headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
143
- }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
144
- if (!modelTask) {
145
- return null;
146
- }
147
- cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
148
- taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
149
- if (taskCache.size > MAX_CACHE_ITEMS) {
150
- taskCache.delete(taskCache.keys().next().value);
151
- }
152
- }
153
- return cachedTask.task;
154
- }
155
-
156
203
  // src/lib/makeRequestOptions.ts
157
- var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
204
+ var HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_HUB_URL}/api/inference-proxy/{{PROVIDER}}`;
158
205
  var tasks = null;
159
206
  async function makeRequestOptions(args, options) {
160
- const { accessToken, endpointUrl, ...otherArgs } = args;
161
- let { model } = args;
162
- const {
163
- forceTask: task,
164
- includeCredentials,
165
- taskHint,
166
- wait_for_model,
167
- use_cache,
168
- dont_load_model,
169
- chatCompletion: chatCompletion2
170
- } = options ?? {};
171
- const headers = {};
172
- if (accessToken) {
173
- headers["Authorization"] = `Bearer ${accessToken}`;
207
+ const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...otherArgs } = args;
208
+ const provider = maybeProvider ?? "hf-inference";
209
+ const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion: chatCompletion2 } = options ?? {};
210
+ if (endpointUrl && provider !== "hf-inference") {
211
+ throw new Error(`Cannot use endpointUrl with a third-party provider.`);
174
212
  }
175
- if (!model && !tasks && taskHint) {
176
- const res = await fetch(`${HF_HUB_URL}/api/tasks`);
177
- if (res.ok) {
178
- tasks = await res.json();
179
- }
213
+ if (forceTask && provider !== "hf-inference") {
214
+ throw new Error(`Cannot use forceTask with a third-party provider.`);
215
+ }
216
+ if (maybeModel && isUrl(maybeModel)) {
217
+ throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
180
218
  }
181
- if (!model && tasks && taskHint) {
182
- const taskInfo = tasks[taskHint];
183
- if (taskInfo) {
184
- model = taskInfo.models[0].id;
219
+ let model;
220
+ if (!maybeModel) {
221
+ if (taskHint) {
222
+ model = mapModel({ model: await loadDefaultModel(taskHint), provider, taskHint, chatCompletion: chatCompletion2 });
223
+ } else {
224
+ throw new Error("No model provided, and no default model found for this task");
185
225
  }
226
+ } else {
227
+ model = mapModel({ model: maybeModel, provider, taskHint, chatCompletion: chatCompletion2 });
186
228
  }
187
- if (!model) {
188
- throw new Error("No model provided, and no default model found for this task");
229
+ const authMethod = accessToken ? accessToken.startsWith("hf_") ? "hf-token" : "provider-key" : includeCredentials === "include" ? "credentials-include" : "none";
230
+ const url = endpointUrl ? chatCompletion2 ? endpointUrl + `/v1/chat/completions` : endpointUrl : makeUrl({
231
+ authMethod,
232
+ chatCompletion: chatCompletion2 ?? false,
233
+ forceTask,
234
+ model,
235
+ provider: provider ?? "hf-inference",
236
+ taskHint
237
+ });
238
+ const headers = {};
239
+ if (accessToken) {
240
+ headers["Authorization"] = provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
189
241
  }
190
242
  const binary = "data" in args && !!args.data;
191
243
  if (!binary) {
192
244
  headers["Content-Type"] = "application/json";
193
245
  }
194
- if (wait_for_model) {
195
- headers["X-Wait-For-Model"] = "true";
196
- }
197
- if (use_cache === false) {
198
- headers["X-Use-Cache"] = "false";
199
- }
200
- if (dont_load_model) {
201
- headers["X-Load-Model"] = "0";
202
- }
203
- let url = (() => {
204
- if (endpointUrl && isUrl(model)) {
205
- throw new TypeError("Both model and endpointUrl cannot be URLs");
206
- }
207
- if (isUrl(model)) {
208
- console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead");
209
- return model;
246
+ if (provider === "hf-inference") {
247
+ if (wait_for_model) {
248
+ headers["X-Wait-For-Model"] = "true";
210
249
  }
211
- if (endpointUrl) {
212
- return endpointUrl;
250
+ if (use_cache === false) {
251
+ headers["X-Use-Cache"] = "false";
213
252
  }
214
- if (task) {
215
- return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
253
+ if (dont_load_model) {
254
+ headers["X-Load-Model"] = "0";
216
255
  }
217
- return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
218
- })();
219
- if (chatCompletion2 && !url.endsWith("/chat/completions")) {
220
- url += "/v1/chat/completions";
256
+ }
257
+ if (provider === "replicate") {
258
+ headers["Prefer"] = "wait";
221
259
  }
222
260
  let credentials;
223
261
  if (typeof includeCredentials === "string") {
@@ -225,17 +263,110 @@ async function makeRequestOptions(args, options) {
225
263
  } else if (includeCredentials === true) {
226
264
  credentials = "include";
227
265
  }
266
+ if (provider === "replicate" && model.includes(":")) {
267
+ const version = model.split(":")[1];
268
+ otherArgs.version = version;
269
+ }
228
270
  const info = {
229
271
  headers,
230
272
  method: "POST",
231
273
  body: binary ? args.data : JSON.stringify({
232
- ...otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs
274
+ ...otherArgs,
275
+ ...chatCompletion2 || provider === "together" ? { model } : void 0
233
276
  }),
234
- ...credentials && { credentials },
277
+ ...credentials ? { credentials } : void 0,
235
278
  signal: options?.signal
236
279
  };
237
280
  return { url, info };
238
281
  }
282
+ function mapModel(params) {
283
+ if (params.provider === "hf-inference") {
284
+ return params.model;
285
+ }
286
+ if (!params.taskHint) {
287
+ throw new Error("taskHint must be specified when using a third-party provider");
288
+ }
289
+ const task = params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint;
290
+ const model = (() => {
291
+ switch (params.provider) {
292
+ case "fal-ai":
293
+ return FAL_AI_SUPPORTED_MODEL_IDS[task]?.[params.model];
294
+ case "replicate":
295
+ return REPLICATE_SUPPORTED_MODEL_IDS[task]?.[params.model];
296
+ case "sambanova":
297
+ return SAMBANOVA_SUPPORTED_MODEL_IDS[task]?.[params.model];
298
+ case "together":
299
+ return TOGETHER_SUPPORTED_MODEL_IDS[task]?.[params.model];
300
+ }
301
+ })();
302
+ if (!model) {
303
+ throw new Error(`Model ${params.model} is not supported for task ${task} and provider ${params.provider}`);
304
+ }
305
+ return model;
306
+ }
307
+ function makeUrl(params) {
308
+ if (params.authMethod === "none" && params.provider !== "hf-inference") {
309
+ throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
310
+ }
311
+ const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
312
+ switch (params.provider) {
313
+ case "fal-ai": {
314
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FAL_AI_API_BASE_URL;
315
+ return `${baseUrl}/${params.model}`;
316
+ }
317
+ case "replicate": {
318
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : REPLICATE_API_BASE_URL;
319
+ if (params.model.includes(":")) {
320
+ return `${baseUrl}/v1/predictions`;
321
+ }
322
+ return `${baseUrl}/v1/models/${params.model}/predictions`;
323
+ }
324
+ case "sambanova": {
325
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : SAMBANOVA_API_BASE_URL;
326
+ if (params.taskHint === "text-generation" && params.chatCompletion) {
327
+ return `${baseUrl}/v1/chat/completions`;
328
+ }
329
+ return baseUrl;
330
+ }
331
+ case "together": {
332
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : TOGETHER_API_BASE_URL;
333
+ if (params.taskHint === "text-to-image") {
334
+ return `${baseUrl}/v1/images/generations`;
335
+ }
336
+ if (params.taskHint === "text-generation") {
337
+ if (params.chatCompletion) {
338
+ return `${baseUrl}/v1/chat/completions`;
339
+ }
340
+ return `${baseUrl}/v1/completions`;
341
+ }
342
+ return baseUrl;
343
+ }
344
+ default: {
345
+ const url = params.forceTask ? `${HF_INFERENCE_API_URL}/pipeline/${params.forceTask}/${params.model}` : `${HF_INFERENCE_API_URL}/models/${params.model}`;
346
+ if (params.taskHint === "text-generation" && params.chatCompletion) {
347
+ return url + `/v1/chat/completions`;
348
+ }
349
+ return url;
350
+ }
351
+ }
352
+ }
353
+ async function loadDefaultModel(task) {
354
+ if (!tasks) {
355
+ tasks = await loadTaskInfo();
356
+ }
357
+ const taskInfo = tasks[task];
358
+ if ((taskInfo?.models.length ?? 0) <= 0) {
359
+ throw new Error(`No default model defined for task ${task}, please define the model explicitly.`);
360
+ }
361
+ return taskInfo.models[0].id;
362
+ }
363
+ async function loadTaskInfo() {
364
+ const res = await fetch(`${HF_HUB_URL}/api/tasks`);
365
+ if (!res.ok) {
366
+ throw new Error("Failed to load tasks definitions from Hugging Face Hub.");
367
+ }
368
+ return await res.json();
369
+ }
239
370
 
240
371
  // src/tasks/custom/request.ts
241
372
  async function request(args, options) {
@@ -248,16 +379,22 @@ async function request(args, options) {
248
379
  });
249
380
  }
250
381
  if (!response.ok) {
251
- if (response.headers.get("Content-Type")?.startsWith("application/json")) {
382
+ const contentType = response.headers.get("Content-Type");
383
+ if (["application/json", "application/problem+json"].some((ct) => contentType?.startsWith(ct))) {
252
384
  const output = await response.json();
253
385
  if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
254
- throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
386
+ throw new Error(
387
+ `Server ${args.model} does not seem to support chat completion. Error: ${JSON.stringify(output.error)}`
388
+ );
255
389
  }
256
- if (output.error) {
257
- throw new Error(output.error);
390
+ if (output.error || output.detail) {
391
+ throw new Error(JSON.stringify(output.error ?? output.detail));
392
+ } else {
393
+ throw new Error(output);
258
394
  }
259
395
  }
260
- throw new Error("An error occurred while fetching the blob");
396
+ const message = contentType?.startsWith("text/plain;") ? await response.text() : void 0;
397
+ throw new Error(message ?? "An error occurred while fetching the blob");
261
398
  }
262
399
  if (response.headers.get("Content-Type")?.startsWith("application/json")) {
263
400
  return await response.json();
@@ -380,9 +517,12 @@ async function* streamingRequest(args, options) {
380
517
  if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
381
518
  throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
382
519
  }
383
- if (output.error) {
520
+ if (typeof output.error === "string") {
384
521
  throw new Error(output.error);
385
522
  }
523
+ if (output.error && "message" in output.error && typeof output.error.message === "string") {
524
+ throw new Error(output.error.message);
525
+ }
386
526
  }
387
527
  throw new Error(`Server response contains error: ${response.status}`);
388
528
  }
@@ -411,8 +551,9 @@ async function* streamingRequest(args, options) {
411
551
  try {
412
552
  while (true) {
413
553
  const { done, value } = await reader.read();
414
- if (done)
554
+ if (done) {
415
555
  return;
556
+ }
416
557
  onChunk(value);
417
558
  for (const event of events) {
418
559
  if (event.data.length > 0) {
@@ -421,7 +562,8 @@ async function* streamingRequest(args, options) {
421
562
  }
422
563
  const data = JSON.parse(event.data);
423
564
  if (typeof data === "object" && data !== null && "error" in data) {
424
- throw new Error(data.error);
565
+ const errorStr = typeof data.error === "string" ? data.error : typeof data.error === "object" && data.error && "message" in data.error && typeof data.error.message === "string" ? data.error.message : JSON.stringify(data.error);
566
+ throw new Error(`Error forwarded from backend: ` + errorStr);
425
567
  }
426
568
  yield data;
427
569
  }
@@ -456,8 +598,29 @@ async function audioClassification(args, options) {
456
598
  return res;
457
599
  }
458
600
 
601
+ // src/utils/base64FromBytes.ts
602
+ function base64FromBytes(arr) {
603
+ if (globalThis.Buffer) {
604
+ return globalThis.Buffer.from(arr).toString("base64");
605
+ } else {
606
+ const bin = [];
607
+ arr.forEach((byte) => {
608
+ bin.push(String.fromCharCode(byte));
609
+ });
610
+ return globalThis.btoa(bin.join(""));
611
+ }
612
+ }
613
+
459
614
  // src/tasks/audio/automaticSpeechRecognition.ts
460
615
  async function automaticSpeechRecognition(args, options) {
616
+ if (args.provider === "fal-ai") {
617
+ const contentType = args.data instanceof Blob ? args.data.type : "audio/mpeg";
618
+ const base64audio = base64FromBytes(
619
+ new Uint8Array(args.data instanceof ArrayBuffer ? args.data : await args.data.arrayBuffer())
620
+ );
621
+ args.audio_url = `data:${contentType};base64,${base64audio}`;
622
+ delete args.data;
623
+ }
461
624
  const res = await request(args, {
462
625
  ...options,
463
626
  taskHint: "automatic-speech-recognition"
@@ -554,10 +717,35 @@ async function objectDetection(args, options) {
554
717
 
555
718
  // src/tasks/cv/textToImage.ts
556
719
  async function textToImage(args, options) {
720
+ if (args.provider === "together" || args.provider === "fal-ai") {
721
+ args.prompt = args.inputs;
722
+ args.inputs = "";
723
+ args.response_format = "base64";
724
+ } else if (args.provider === "replicate") {
725
+ args.input = { prompt: args.inputs };
726
+ delete args.inputs;
727
+ }
557
728
  const res = await request(args, {
558
729
  ...options,
559
730
  taskHint: "text-to-image"
560
731
  });
732
+ if (res && typeof res === "object") {
733
+ if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
734
+ const image = await fetch(res.images[0].url);
735
+ return await image.blob();
736
+ }
737
+ if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
738
+ const base64Data = res.data[0].b64_json;
739
+ const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
740
+ const blob = await base64Response.blob();
741
+ return blob;
742
+ }
743
+ if ("output" in res && Array.isArray(res.output)) {
744
+ const urlResponse = await fetch(res.output[0]);
745
+ const blob = await urlResponse.blob();
746
+ return blob;
747
+ }
748
+ }
561
749
  const isValidOutput = res && res instanceof Blob;
562
750
  if (!isValidOutput) {
563
751
  throw new InferenceOutputError("Expected Blob");
@@ -565,19 +753,6 @@ async function textToImage(args, options) {
565
753
  return res;
566
754
  }
567
755
 
568
- // src/utils/base64FromBytes.ts
569
- function base64FromBytes(arr) {
570
- if (globalThis.Buffer) {
571
- return globalThis.Buffer.from(arr).toString("base64");
572
- } else {
573
- const bin = [];
574
- arr.forEach((byte) => {
575
- bin.push(String.fromCharCode(byte));
576
- });
577
- return globalThis.btoa(bin.join(""));
578
- }
579
- }
580
-
581
756
  // src/tasks/cv/imageToImage.ts
582
757
  async function imageToImage(args, options) {
583
758
  let reqArgs;
@@ -629,6 +804,36 @@ async function zeroShotImageClassification(args, options) {
629
804
  return res;
630
805
  }
631
806
 
807
+ // src/lib/getDefaultTask.ts
808
+ var taskCache = /* @__PURE__ */ new Map();
809
+ var CACHE_DURATION = 10 * 60 * 1e3;
810
+ var MAX_CACHE_ITEMS = 1e3;
811
+ async function getDefaultTask(model, accessToken, options) {
812
+ if (isUrl(model)) {
813
+ return null;
814
+ }
815
+ const key = `${model}:${accessToken}`;
816
+ let cachedTask = taskCache.get(key);
817
+ if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
818
+ taskCache.delete(key);
819
+ cachedTask = void 0;
820
+ }
821
+ if (cachedTask === void 0) {
822
+ const modelTask = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
823
+ headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
824
+ }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
825
+ if (!modelTask) {
826
+ return null;
827
+ }
828
+ cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
829
+ taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
830
+ if (taskCache.size > MAX_CACHE_ITEMS) {
831
+ taskCache.delete(taskCache.keys().next().value);
832
+ }
833
+ }
834
+ return cachedTask.task;
835
+ }
836
+
632
837
  // src/tasks/nlp/featureExtraction.ts
633
838
  async function featureExtraction(args, options) {
634
839
  const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : void 0;
@@ -750,17 +955,33 @@ function toArray(obj) {
750
955
 
751
956
  // src/tasks/nlp/textGeneration.ts
752
957
  async function textGeneration(args, options) {
753
- const res = toArray(
754
- await request(args, {
958
+ if (args.provider === "together") {
959
+ args.prompt = args.inputs;
960
+ const raw = await request(args, {
755
961
  ...options,
756
962
  taskHint: "text-generation"
757
- })
758
- );
759
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string");
760
- if (!isValidOutput) {
761
- throw new InferenceOutputError("Expected Array<{generated_text: string}>");
963
+ });
964
+ const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
965
+ if (!isValidOutput) {
966
+ throw new InferenceOutputError("Expected ChatCompletionOutput");
967
+ }
968
+ const completion = raw.choices[0];
969
+ return {
970
+ generated_text: completion.text
971
+ };
972
+ } else {
973
+ const res = toArray(
974
+ await request(args, {
975
+ ...options,
976
+ taskHint: "text-generation"
977
+ })
978
+ );
979
+ const isValidOutput = Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
980
+ if (!isValidOutput) {
981
+ throw new InferenceOutputError("Expected Array<{generated_text: string}>");
982
+ }
983
+ return res?.[0];
762
984
  }
763
- return res?.[0];
764
985
  }
765
986
 
766
987
  // src/tasks/nlp/textGenerationStream.ts
@@ -827,7 +1048,8 @@ async function chatCompletion(args, options) {
827
1048
  taskHint: "text-generation",
828
1049
  chatCompletion: true
829
1050
  });
830
- const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && typeof res?.system_fingerprint === "string" && typeof res?.usage === "object";
1051
+ const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai does not output a system_fingerprint
1052
+ (res.system_fingerprint === void 0 || typeof res.system_fingerprint === "string") && typeof res?.usage === "object";
831
1053
  if (!isValidOutput) {
832
1054
  throw new InferenceOutputError("Expected ChatCompletionOutput");
833
1055
  }
@@ -960,11 +1182,19 @@ var HfInferenceEndpoint = class {
960
1182
  }
961
1183
  }
962
1184
  };
1185
+
1186
+ // src/types.ts
1187
+ var INFERENCE_PROVIDERS = ["fal-ai", "replicate", "sambanova", "together", "hf-inference"];
963
1188
  // Annotate the CommonJS export names for ESM import in node:
964
1189
  0 && (module.exports = {
1190
+ FAL_AI_SUPPORTED_MODEL_IDS,
965
1191
  HfInference,
966
1192
  HfInferenceEndpoint,
1193
+ INFERENCE_PROVIDERS,
967
1194
  InferenceOutputError,
1195
+ REPLICATE_SUPPORTED_MODEL_IDS,
1196
+ SAMBANOVA_SUPPORTED_MODEL_IDS,
1197
+ TOGETHER_SUPPORTED_MODEL_IDS,
968
1198
  audioClassification,
969
1199
  audioToAudio,
970
1200
  automaticSpeechRecognition,