@huggingface/inference 2.8.0 → 3.0.0

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (48):
  1. package/LICENSE +1 -1
  2. package/README.md +39 -16
  3. package/dist/index.cjs +364 -134
  4. package/dist/index.js +359 -134
  5. package/dist/src/config.d.ts +3 -0
  6. package/dist/src/config.d.ts.map +1 -0
  7. package/dist/src/index.d.ts +5 -0
  8. package/dist/src/index.d.ts.map +1 -1
  9. package/dist/src/lib/getDefaultTask.d.ts +0 -1
  10. package/dist/src/lib/getDefaultTask.d.ts.map +1 -1
  11. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  12. package/dist/src/providers/fal-ai.d.ts +6 -0
  13. package/dist/src/providers/fal-ai.d.ts.map +1 -0
  14. package/dist/src/providers/replicate.d.ts +6 -0
  15. package/dist/src/providers/replicate.d.ts.map +1 -0
  16. package/dist/src/providers/sambanova.d.ts +6 -0
  17. package/dist/src/providers/sambanova.d.ts.map +1 -0
  18. package/dist/src/providers/together.d.ts +12 -0
  19. package/dist/src/providers/together.d.ts.map +1 -0
  20. package/dist/src/providers/types.d.ts +4 -0
  21. package/dist/src/providers/types.d.ts.map +1 -0
  22. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  23. package/dist/src/tasks/custom/request.d.ts +1 -1
  24. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  25. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  26. package/dist/src/tasks/cv/textToImage.d.ts +8 -0
  27. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  28. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  29. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  30. package/dist/src/types.d.ts +16 -2
  31. package/dist/src/types.d.ts.map +1 -1
  32. package/package.json +2 -2
  33. package/src/config.ts +2 -0
  34. package/src/index.ts +5 -0
  35. package/src/lib/getDefaultTask.ts +1 -1
  36. package/src/lib/makeRequestOptions.ts +199 -59
  37. package/src/providers/fal-ai.ts +15 -0
  38. package/src/providers/replicate.ts +16 -0
  39. package/src/providers/sambanova.ts +23 -0
  40. package/src/providers/together.ts +58 -0
  41. package/src/providers/types.ts +6 -0
  42. package/src/tasks/audio/automaticSpeechRecognition.ts +10 -1
  43. package/src/tasks/custom/request.ts +12 -6
  44. package/src/tasks/custom/streamingRequest.ts +18 -3
  45. package/src/tasks/cv/textToImage.ts +44 -1
  46. package/src/tasks/nlp/chatCompletion.ts +2 -2
  47. package/src/tasks/nlp/textGeneration.ts +43 -9
  48. package/src/types.ts +20 -2
package/dist/index.js CHANGED
@@ -40,131 +40,164 @@ __export(tasks_exports, {
40
40
  zeroShotImageClassification: () => zeroShotImageClassification
41
41
  });
42
42
 
43
- // src/utils/pick.ts
44
- function pick(o, props) {
45
- return Object.assign(
46
- {},
47
- ...props.map((prop) => {
48
- if (o[prop] !== void 0) {
49
- return { [prop]: o[prop] };
50
- }
51
- })
52
- );
53
- }
43
+ // src/config.ts
44
+ var HF_HUB_URL = "https://huggingface.co";
45
+ var HF_INFERENCE_API_URL = "https://api-inference.huggingface.co";
54
46
 
55
- // src/utils/typedInclude.ts
56
- function typedInclude(arr, v) {
57
- return arr.includes(v);
58
- }
47
+ // src/providers/fal-ai.ts
48
+ var FAL_AI_API_BASE_URL = "https://fal.run";
49
+ var FAL_AI_SUPPORTED_MODEL_IDS = {
50
+ "text-to-image": {
51
+ "black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
52
+ "black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev"
53
+ },
54
+ "automatic-speech-recognition": {
55
+ "openai/whisper-large-v3": "fal-ai/whisper"
56
+ }
57
+ };
59
58
 
60
- // src/utils/omit.ts
61
- function omit(o, props) {
62
- const propsArr = Array.isArray(props) ? props : [props];
63
- const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
64
- return pick(o, letsKeep);
65
- }
59
+ // src/providers/replicate.ts
60
+ var REPLICATE_API_BASE_URL = "https://api.replicate.com";
61
+ var REPLICATE_SUPPORTED_MODEL_IDS = {
62
+ "text-to-image": {
63
+ "black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
64
+ "ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637"
65
+ }
66
+ // "text-to-speech": {
67
+ // "SWivid/F5-TTS": "x-lance/f5-tts:87faf6dd7a692dd82043f662e76369cab126a2cf1937e25a9d41e0b834fd230e"
68
+ // },
69
+ };
70
+
71
+ // src/providers/sambanova.ts
72
+ var SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
73
+ var SAMBANOVA_SUPPORTED_MODEL_IDS = {
74
+ /** Chat completion / conversational */
75
+ conversational: {
76
+ "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
77
+ "Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
78
+ "Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
79
+ "meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct",
80
+ "meta-llama/Llama-3.2-1B-Instruct": "Meta-Llama-3.2-1B-Instruct",
81
+ "meta-llama/Llama-3.2-3B-Instruct": "Meta-Llama-3.2-3B-Instruct",
82
+ "meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
83
+ "meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
84
+ "meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
85
+ "meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
86
+ "meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
87
+ "meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B"
88
+ }
89
+ };
90
+
91
+ // src/providers/together.ts
92
+ var TOGETHER_API_BASE_URL = "https://api.together.xyz";
93
+ var TOGETHER_SUPPORTED_MODEL_IDS = {
94
+ "text-to-image": {
95
+ "black-forest-labs/FLUX.1-Canny-dev": "black-forest-labs/FLUX.1-canny",
96
+ "black-forest-labs/FLUX.1-Depth-dev": "black-forest-labs/FLUX.1-depth",
97
+ "black-forest-labs/FLUX.1-dev": "black-forest-labs/FLUX.1-dev",
98
+ "black-forest-labs/FLUX.1-Redux-dev": "black-forest-labs/FLUX.1-redux",
99
+ "black-forest-labs/FLUX.1-schnell": "black-forest-labs/FLUX.1-pro",
100
+ "stabilityai/stable-diffusion-xl-base-1.0": "stabilityai/stable-diffusion-xl-base-1.0"
101
+ },
102
+ conversational: {
103
+ "databricks/dbrx-instruct": "databricks/dbrx-instruct",
104
+ "deepseek-ai/deepseek-llm-67b-chat": "deepseek-ai/deepseek-llm-67b-chat",
105
+ "google/gemma-2-9b-it": "google/gemma-2-9b-it",
106
+ "google/gemma-2b-it": "google/gemma-2-27b-it",
107
+ "llava-hf/llava-v1.6-mistral-7b-hf": "llava-hf/llava-v1.6-mistral-7b-hf",
108
+ "meta-llama/Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
109
+ "meta-llama/Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
110
+ "meta-llama/Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
111
+ "meta-llama/Llama-3.2-11B-Vision-Instruct": "meta-llama/Llama-Vision-Free",
112
+ "meta-llama/Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
113
+ "meta-llama/Llama-3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
114
+ "meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
115
+ "meta-llama/Meta-Llama-3-70B-Instruct": "meta-llama/Llama-3-70b-chat-hf",
116
+ "meta-llama/Meta-Llama-3-8B-Instruct": "togethercomputer/Llama-3-8b-chat-hf-int4",
117
+ "meta-llama/Meta-Llama-3.1-405B-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
118
+ "meta-llama/Meta-Llama-3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
119
+ "meta-llama/Meta-Llama-3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K",
120
+ "microsoft/WizardLM-2-8x22B": "microsoft/WizardLM-2-8x22B",
121
+ "mistralai/Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
122
+ "mistralai/Mixtral-8x22B-Instruct-v0.1": "mistralai/Mixtral-8x22B-Instruct-v0.1",
123
+ "mistralai/Mixtral-8x7B-Instruct-v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1",
124
+ "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
125
+ "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
126
+ "Qwen/Qwen2-72B-Instruct": "Qwen/Qwen2-72B-Instruct",
127
+ "Qwen/Qwen2.5-72B-Instruct": "Qwen/Qwen2.5-72B-Instruct-Turbo",
128
+ "Qwen/Qwen2.5-7B-Instruct": "Qwen/Qwen2.5-7B-Instruct-Turbo",
129
+ "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen/Qwen2.5-Coder-32B-Instruct",
130
+ "Qwen/QwQ-32B-Preview": "Qwen/QwQ-32B-Preview",
131
+ "scb10x/llama-3-typhoon-v1.5-8b-instruct": "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct",
132
+ "scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": "scb10x/scb10x-llama3-typhoon-v1-5x-4f316"
133
+ },
134
+ "text-generation": {
135
+ "meta-llama/Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B",
136
+ "mistralai/Mixtral-8x7B-v0.1": "mistralai/Mixtral-8x7B-v0.1"
137
+ }
138
+ };
66
139
 
67
140
  // src/lib/isUrl.ts
68
141
  function isUrl(modelOrUrl) {
69
142
  return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
70
143
  }
71
144
 
72
- // src/lib/getDefaultTask.ts
73
- var taskCache = /* @__PURE__ */ new Map();
74
- var CACHE_DURATION = 10 * 60 * 1e3;
75
- var MAX_CACHE_ITEMS = 1e3;
76
- var HF_HUB_URL = "https://huggingface.co";
77
- async function getDefaultTask(model, accessToken, options) {
78
- if (isUrl(model)) {
79
- return null;
80
- }
81
- const key = `${model}:${accessToken}`;
82
- let cachedTask = taskCache.get(key);
83
- if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
84
- taskCache.delete(key);
85
- cachedTask = void 0;
86
- }
87
- if (cachedTask === void 0) {
88
- const modelTask = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
89
- headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
90
- }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
91
- if (!modelTask) {
92
- return null;
93
- }
94
- cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
95
- taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
96
- if (taskCache.size > MAX_CACHE_ITEMS) {
97
- taskCache.delete(taskCache.keys().next().value);
98
- }
99
- }
100
- return cachedTask.task;
101
- }
102
-
103
145
  // src/lib/makeRequestOptions.ts
104
- var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
146
+ var HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_HUB_URL}/api/inference-proxy/{{PROVIDER}}`;
105
147
  var tasks = null;
106
148
  async function makeRequestOptions(args, options) {
107
- const { accessToken, endpointUrl, ...otherArgs } = args;
108
- let { model } = args;
109
- const {
110
- forceTask: task,
111
- includeCredentials,
112
- taskHint,
113
- wait_for_model,
114
- use_cache,
115
- dont_load_model,
116
- chatCompletion: chatCompletion2
117
- } = options ?? {};
118
- const headers = {};
119
- if (accessToken) {
120
- headers["Authorization"] = `Bearer ${accessToken}`;
149
+ const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...otherArgs } = args;
150
+ const provider = maybeProvider ?? "hf-inference";
151
+ const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion: chatCompletion2 } = options ?? {};
152
+ if (endpointUrl && provider !== "hf-inference") {
153
+ throw new Error(`Cannot use endpointUrl with a third-party provider.`);
121
154
  }
122
- if (!model && !tasks && taskHint) {
123
- const res = await fetch(`${HF_HUB_URL}/api/tasks`);
124
- if (res.ok) {
125
- tasks = await res.json();
126
- }
155
+ if (forceTask && provider !== "hf-inference") {
156
+ throw new Error(`Cannot use forceTask with a third-party provider.`);
157
+ }
158
+ if (maybeModel && isUrl(maybeModel)) {
159
+ throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
127
160
  }
128
- if (!model && tasks && taskHint) {
129
- const taskInfo = tasks[taskHint];
130
- if (taskInfo) {
131
- model = taskInfo.models[0].id;
161
+ let model;
162
+ if (!maybeModel) {
163
+ if (taskHint) {
164
+ model = mapModel({ model: await loadDefaultModel(taskHint), provider, taskHint, chatCompletion: chatCompletion2 });
165
+ } else {
166
+ throw new Error("No model provided, and no default model found for this task");
132
167
  }
168
+ } else {
169
+ model = mapModel({ model: maybeModel, provider, taskHint, chatCompletion: chatCompletion2 });
133
170
  }
134
- if (!model) {
135
- throw new Error("No model provided, and no default model found for this task");
171
+ const authMethod = accessToken ? accessToken.startsWith("hf_") ? "hf-token" : "provider-key" : includeCredentials === "include" ? "credentials-include" : "none";
172
+ const url = endpointUrl ? chatCompletion2 ? endpointUrl + `/v1/chat/completions` : endpointUrl : makeUrl({
173
+ authMethod,
174
+ chatCompletion: chatCompletion2 ?? false,
175
+ forceTask,
176
+ model,
177
+ provider: provider ?? "hf-inference",
178
+ taskHint
179
+ });
180
+ const headers = {};
181
+ if (accessToken) {
182
+ headers["Authorization"] = provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
136
183
  }
137
184
  const binary = "data" in args && !!args.data;
138
185
  if (!binary) {
139
186
  headers["Content-Type"] = "application/json";
140
187
  }
141
- if (wait_for_model) {
142
- headers["X-Wait-For-Model"] = "true";
143
- }
144
- if (use_cache === false) {
145
- headers["X-Use-Cache"] = "false";
146
- }
147
- if (dont_load_model) {
148
- headers["X-Load-Model"] = "0";
149
- }
150
- let url = (() => {
151
- if (endpointUrl && isUrl(model)) {
152
- throw new TypeError("Both model and endpointUrl cannot be URLs");
153
- }
154
- if (isUrl(model)) {
155
- console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead");
156
- return model;
188
+ if (provider === "hf-inference") {
189
+ if (wait_for_model) {
190
+ headers["X-Wait-For-Model"] = "true";
157
191
  }
158
- if (endpointUrl) {
159
- return endpointUrl;
192
+ if (use_cache === false) {
193
+ headers["X-Use-Cache"] = "false";
160
194
  }
161
- if (task) {
162
- return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
195
+ if (dont_load_model) {
196
+ headers["X-Load-Model"] = "0";
163
197
  }
164
- return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
165
- })();
166
- if (chatCompletion2 && !url.endsWith("/chat/completions")) {
167
- url += "/v1/chat/completions";
198
+ }
199
+ if (provider === "replicate") {
200
+ headers["Prefer"] = "wait";
168
201
  }
169
202
  let credentials;
170
203
  if (typeof includeCredentials === "string") {
@@ -172,17 +205,110 @@ async function makeRequestOptions(args, options) {
172
205
  } else if (includeCredentials === true) {
173
206
  credentials = "include";
174
207
  }
208
+ if (provider === "replicate" && model.includes(":")) {
209
+ const version = model.split(":")[1];
210
+ otherArgs.version = version;
211
+ }
175
212
  const info = {
176
213
  headers,
177
214
  method: "POST",
178
215
  body: binary ? args.data : JSON.stringify({
179
- ...otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs
216
+ ...otherArgs,
217
+ ...chatCompletion2 || provider === "together" ? { model } : void 0
180
218
  }),
181
- ...credentials && { credentials },
219
+ ...credentials ? { credentials } : void 0,
182
220
  signal: options?.signal
183
221
  };
184
222
  return { url, info };
185
223
  }
224
+ function mapModel(params) {
225
+ if (params.provider === "hf-inference") {
226
+ return params.model;
227
+ }
228
+ if (!params.taskHint) {
229
+ throw new Error("taskHint must be specified when using a third-party provider");
230
+ }
231
+ const task = params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint;
232
+ const model = (() => {
233
+ switch (params.provider) {
234
+ case "fal-ai":
235
+ return FAL_AI_SUPPORTED_MODEL_IDS[task]?.[params.model];
236
+ case "replicate":
237
+ return REPLICATE_SUPPORTED_MODEL_IDS[task]?.[params.model];
238
+ case "sambanova":
239
+ return SAMBANOVA_SUPPORTED_MODEL_IDS[task]?.[params.model];
240
+ case "together":
241
+ return TOGETHER_SUPPORTED_MODEL_IDS[task]?.[params.model];
242
+ }
243
+ })();
244
+ if (!model) {
245
+ throw new Error(`Model ${params.model} is not supported for task ${task} and provider ${params.provider}`);
246
+ }
247
+ return model;
248
+ }
249
+ function makeUrl(params) {
250
+ if (params.authMethod === "none" && params.provider !== "hf-inference") {
251
+ throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
252
+ }
253
+ const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
254
+ switch (params.provider) {
255
+ case "fal-ai": {
256
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FAL_AI_API_BASE_URL;
257
+ return `${baseUrl}/${params.model}`;
258
+ }
259
+ case "replicate": {
260
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : REPLICATE_API_BASE_URL;
261
+ if (params.model.includes(":")) {
262
+ return `${baseUrl}/v1/predictions`;
263
+ }
264
+ return `${baseUrl}/v1/models/${params.model}/predictions`;
265
+ }
266
+ case "sambanova": {
267
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : SAMBANOVA_API_BASE_URL;
268
+ if (params.taskHint === "text-generation" && params.chatCompletion) {
269
+ return `${baseUrl}/v1/chat/completions`;
270
+ }
271
+ return baseUrl;
272
+ }
273
+ case "together": {
274
+ const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : TOGETHER_API_BASE_URL;
275
+ if (params.taskHint === "text-to-image") {
276
+ return `${baseUrl}/v1/images/generations`;
277
+ }
278
+ if (params.taskHint === "text-generation") {
279
+ if (params.chatCompletion) {
280
+ return `${baseUrl}/v1/chat/completions`;
281
+ }
282
+ return `${baseUrl}/v1/completions`;
283
+ }
284
+ return baseUrl;
285
+ }
286
+ default: {
287
+ const url = params.forceTask ? `${HF_INFERENCE_API_URL}/pipeline/${params.forceTask}/${params.model}` : `${HF_INFERENCE_API_URL}/models/${params.model}`;
288
+ if (params.taskHint === "text-generation" && params.chatCompletion) {
289
+ return url + `/v1/chat/completions`;
290
+ }
291
+ return url;
292
+ }
293
+ }
294
+ }
295
+ async function loadDefaultModel(task) {
296
+ if (!tasks) {
297
+ tasks = await loadTaskInfo();
298
+ }
299
+ const taskInfo = tasks[task];
300
+ if ((taskInfo?.models.length ?? 0) <= 0) {
301
+ throw new Error(`No default model defined for task ${task}, please define the model explicitly.`);
302
+ }
303
+ return taskInfo.models[0].id;
304
+ }
305
+ async function loadTaskInfo() {
306
+ const res = await fetch(`${HF_HUB_URL}/api/tasks`);
307
+ if (!res.ok) {
308
+ throw new Error("Failed to load tasks definitions from Hugging Face Hub.");
309
+ }
310
+ return await res.json();
311
+ }
186
312
 
187
313
  // src/tasks/custom/request.ts
188
314
  async function request(args, options) {
@@ -195,16 +321,22 @@ async function request(args, options) {
195
321
  });
196
322
  }
197
323
  if (!response.ok) {
198
- if (response.headers.get("Content-Type")?.startsWith("application/json")) {
324
+ const contentType = response.headers.get("Content-Type");
325
+ if (["application/json", "application/problem+json"].some((ct) => contentType?.startsWith(ct))) {
199
326
  const output = await response.json();
200
327
  if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
201
- throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
328
+ throw new Error(
329
+ `Server ${args.model} does not seem to support chat completion. Error: ${JSON.stringify(output.error)}`
330
+ );
202
331
  }
203
- if (output.error) {
204
- throw new Error(output.error);
332
+ if (output.error || output.detail) {
333
+ throw new Error(JSON.stringify(output.error ?? output.detail));
334
+ } else {
335
+ throw new Error(output);
205
336
  }
206
337
  }
207
- throw new Error("An error occurred while fetching the blob");
338
+ const message = contentType?.startsWith("text/plain;") ? await response.text() : void 0;
339
+ throw new Error(message ?? "An error occurred while fetching the blob");
208
340
  }
209
341
  if (response.headers.get("Content-Type")?.startsWith("application/json")) {
210
342
  return await response.json();
@@ -327,9 +459,12 @@ async function* streamingRequest(args, options) {
327
459
  if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
328
460
  throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
329
461
  }
330
- if (output.error) {
462
+ if (typeof output.error === "string") {
331
463
  throw new Error(output.error);
332
464
  }
465
+ if (output.error && "message" in output.error && typeof output.error.message === "string") {
466
+ throw new Error(output.error.message);
467
+ }
333
468
  }
334
469
  throw new Error(`Server response contains error: ${response.status}`);
335
470
  }
@@ -358,8 +493,9 @@ async function* streamingRequest(args, options) {
358
493
  try {
359
494
  while (true) {
360
495
  const { done, value } = await reader.read();
361
- if (done)
496
+ if (done) {
362
497
  return;
498
+ }
363
499
  onChunk(value);
364
500
  for (const event of events) {
365
501
  if (event.data.length > 0) {
@@ -368,7 +504,8 @@ async function* streamingRequest(args, options) {
368
504
  }
369
505
  const data = JSON.parse(event.data);
370
506
  if (typeof data === "object" && data !== null && "error" in data) {
371
- throw new Error(data.error);
507
+ const errorStr = typeof data.error === "string" ? data.error : typeof data.error === "object" && data.error && "message" in data.error && typeof data.error.message === "string" ? data.error.message : JSON.stringify(data.error);
508
+ throw new Error(`Error forwarded from backend: ` + errorStr);
372
509
  }
373
510
  yield data;
374
511
  }
@@ -403,8 +540,29 @@ async function audioClassification(args, options) {
403
540
  return res;
404
541
  }
405
542
 
543
+ // src/utils/base64FromBytes.ts
544
+ function base64FromBytes(arr) {
545
+ if (globalThis.Buffer) {
546
+ return globalThis.Buffer.from(arr).toString("base64");
547
+ } else {
548
+ const bin = [];
549
+ arr.forEach((byte) => {
550
+ bin.push(String.fromCharCode(byte));
551
+ });
552
+ return globalThis.btoa(bin.join(""));
553
+ }
554
+ }
555
+
406
556
  // src/tasks/audio/automaticSpeechRecognition.ts
407
557
  async function automaticSpeechRecognition(args, options) {
558
+ if (args.provider === "fal-ai") {
559
+ const contentType = args.data instanceof Blob ? args.data.type : "audio/mpeg";
560
+ const base64audio = base64FromBytes(
561
+ new Uint8Array(args.data instanceof ArrayBuffer ? args.data : await args.data.arrayBuffer())
562
+ );
563
+ args.audio_url = `data:${contentType};base64,${base64audio}`;
564
+ delete args.data;
565
+ }
408
566
  const res = await request(args, {
409
567
  ...options,
410
568
  taskHint: "automatic-speech-recognition"
@@ -501,10 +659,35 @@ async function objectDetection(args, options) {
501
659
 
502
660
  // src/tasks/cv/textToImage.ts
503
661
  async function textToImage(args, options) {
662
+ if (args.provider === "together" || args.provider === "fal-ai") {
663
+ args.prompt = args.inputs;
664
+ args.inputs = "";
665
+ args.response_format = "base64";
666
+ } else if (args.provider === "replicate") {
667
+ args.input = { prompt: args.inputs };
668
+ delete args.inputs;
669
+ }
504
670
  const res = await request(args, {
505
671
  ...options,
506
672
  taskHint: "text-to-image"
507
673
  });
674
+ if (res && typeof res === "object") {
675
+ if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
676
+ const image = await fetch(res.images[0].url);
677
+ return await image.blob();
678
+ }
679
+ if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
680
+ const base64Data = res.data[0].b64_json;
681
+ const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
682
+ const blob = await base64Response.blob();
683
+ return blob;
684
+ }
685
+ if ("output" in res && Array.isArray(res.output)) {
686
+ const urlResponse = await fetch(res.output[0]);
687
+ const blob = await urlResponse.blob();
688
+ return blob;
689
+ }
690
+ }
508
691
  const isValidOutput = res && res instanceof Blob;
509
692
  if (!isValidOutput) {
510
693
  throw new InferenceOutputError("Expected Blob");
@@ -512,19 +695,6 @@ async function textToImage(args, options) {
512
695
  return res;
513
696
  }
514
697
 
515
- // src/utils/base64FromBytes.ts
516
- function base64FromBytes(arr) {
517
- if (globalThis.Buffer) {
518
- return globalThis.Buffer.from(arr).toString("base64");
519
- } else {
520
- const bin = [];
521
- arr.forEach((byte) => {
522
- bin.push(String.fromCharCode(byte));
523
- });
524
- return globalThis.btoa(bin.join(""));
525
- }
526
- }
527
-
528
698
  // src/tasks/cv/imageToImage.ts
529
699
  async function imageToImage(args, options) {
530
700
  let reqArgs;
@@ -576,6 +746,36 @@ async function zeroShotImageClassification(args, options) {
576
746
  return res;
577
747
  }
578
748
 
749
+ // src/lib/getDefaultTask.ts
750
+ var taskCache = /* @__PURE__ */ new Map();
751
+ var CACHE_DURATION = 10 * 60 * 1e3;
752
+ var MAX_CACHE_ITEMS = 1e3;
753
+ async function getDefaultTask(model, accessToken, options) {
754
+ if (isUrl(model)) {
755
+ return null;
756
+ }
757
+ const key = `${model}:${accessToken}`;
758
+ let cachedTask = taskCache.get(key);
759
+ if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
760
+ taskCache.delete(key);
761
+ cachedTask = void 0;
762
+ }
763
+ if (cachedTask === void 0) {
764
+ const modelTask = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
765
+ headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
766
+ }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
767
+ if (!modelTask) {
768
+ return null;
769
+ }
770
+ cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
771
+ taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
772
+ if (taskCache.size > MAX_CACHE_ITEMS) {
773
+ taskCache.delete(taskCache.keys().next().value);
774
+ }
775
+ }
776
+ return cachedTask.task;
777
+ }
778
+
579
779
  // src/tasks/nlp/featureExtraction.ts
580
780
  async function featureExtraction(args, options) {
581
781
  const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : void 0;
@@ -697,17 +897,33 @@ function toArray(obj) {
697
897
 
698
898
  // src/tasks/nlp/textGeneration.ts
699
899
  async function textGeneration(args, options) {
700
- const res = toArray(
701
- await request(args, {
900
+ if (args.provider === "together") {
901
+ args.prompt = args.inputs;
902
+ const raw = await request(args, {
702
903
  ...options,
703
904
  taskHint: "text-generation"
704
- })
705
- );
706
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string");
707
- if (!isValidOutput) {
708
- throw new InferenceOutputError("Expected Array<{generated_text: string}>");
905
+ });
906
+ const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
907
+ if (!isValidOutput) {
908
+ throw new InferenceOutputError("Expected ChatCompletionOutput");
909
+ }
910
+ const completion = raw.choices[0];
911
+ return {
912
+ generated_text: completion.text
913
+ };
914
+ } else {
915
+ const res = toArray(
916
+ await request(args, {
917
+ ...options,
918
+ taskHint: "text-generation"
919
+ })
920
+ );
921
+ const isValidOutput = Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
922
+ if (!isValidOutput) {
923
+ throw new InferenceOutputError("Expected Array<{generated_text: string}>");
924
+ }
925
+ return res?.[0];
709
926
  }
710
- return res?.[0];
711
927
  }
712
928
 
713
929
  // src/tasks/nlp/textGenerationStream.ts
@@ -774,7 +990,8 @@ async function chatCompletion(args, options) {
774
990
  taskHint: "text-generation",
775
991
  chatCompletion: true
776
992
  });
777
- const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && typeof res?.system_fingerprint === "string" && typeof res?.usage === "object";
993
+ const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai does not output a system_fingerprint
994
+ (res.system_fingerprint === void 0 || typeof res.system_fingerprint === "string") && typeof res?.usage === "object";
778
995
  if (!isValidOutput) {
779
996
  throw new InferenceOutputError("Expected ChatCompletionOutput");
780
997
  }
@@ -907,10 +1124,18 @@ var HfInferenceEndpoint = class {
907
1124
  }
908
1125
  }
909
1126
  };
1127
+
1128
+ // src/types.ts
1129
+ var INFERENCE_PROVIDERS = ["fal-ai", "replicate", "sambanova", "together", "hf-inference"];
910
1130
  export {
1131
+ FAL_AI_SUPPORTED_MODEL_IDS,
911
1132
  HfInference,
912
1133
  HfInferenceEndpoint,
1134
+ INFERENCE_PROVIDERS,
913
1135
  InferenceOutputError,
1136
+ REPLICATE_SUPPORTED_MODEL_IDS,
1137
+ SAMBANOVA_SUPPORTED_MODEL_IDS,
1138
+ TOGETHER_SUPPORTED_MODEL_IDS,
914
1139
  audioClassification,
915
1140
  audioToAudio,
916
1141
  automaticSpeechRecognition,
package/dist/src/config.d.ts ADDED
@@ -0,0 +1,3 @@
1
+ export declare const HF_HUB_URL = "https://huggingface.co";
2
+ export declare const HF_INFERENCE_API_URL = "https://api-inference.huggingface.co";
3
+ //# sourceMappingURL=config.d.ts.map
package/dist/src/config.d.ts.map ADDED
@@ -0,0 +1 @@
1
+ {"version":3,"file":"config.d.ts","sourceRoot":"","sources":["../../src/config.ts"],"names":[],"mappings":"AAAA,eAAO,MAAM,UAAU,2BAA2B,CAAC;AACnD,eAAO,MAAM,oBAAoB,yCAAyC,CAAC"}
package/dist/src/index.d.ts CHANGED
@@ -1,5 +1,10 @@
1
+ export type { ProviderMapping } from "./providers/types";
1
2
  export { HfInference, HfInferenceEndpoint } from "./HfInference";
2
3
  export { InferenceOutputError } from "./lib/InferenceOutputError";
4
+ export { FAL_AI_SUPPORTED_MODEL_IDS } from "./providers/fal-ai";
5
+ export { REPLICATE_SUPPORTED_MODEL_IDS } from "./providers/replicate";
6
+ export { SAMBANOVA_SUPPORTED_MODEL_IDS } from "./providers/sambanova";
7
+ export { TOGETHER_SUPPORTED_MODEL_IDS } from "./providers/together";
3
8
  export * from "./types";
4
9
  export * from "./tasks";
5
10
  //# sourceMappingURL=index.d.ts.map
package/dist/src/index.d.ts.map CHANGED
@@ -1 +1 @@
1
- {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,WAAW,EAAE,mBAAmB,EAAE,MAAM,eAAe,CAAC;AACjE,OAAO,EAAE,oBAAoB,EAAE,MAAM,4BAA4B,CAAC;AAClE,cAAc,SAAS,CAAC;AACxB,cAAc,SAAS,CAAC"}
1
+ {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../src/index.ts"],"names":[],"mappings":"AAAA,YAAY,EAAE,eAAe,EAAE,MAAM,mBAAmB,CAAC;AACzD,OAAO,EAAE,WAAW,EAAE,mBAAmB,EAAE,MAAM,eAAe,CAAC;AACjE,OAAO,EAAE,oBAAoB,EAAE,MAAM,4BAA4B,CAAC;AAClE,OAAO,EAAE,0BAA0B,EAAE,MAAM,oBAAoB,CAAC;AAChE,OAAO,EAAE,6BAA6B,EAAE,MAAM,uBAAuB,CAAC;AACtE,OAAO,EAAE,6BAA6B,EAAE,MAAM,uBAAuB,CAAC;AACtE,OAAO,EAAE,4BAA4B,EAAE,MAAM,sBAAsB,CAAC;AACpE,cAAc,SAAS,CAAC;AACxB,cAAc,SAAS,CAAC"}
package/dist/src/lib/getDefaultTask.d.ts CHANGED
@@ -1,4 +1,3 @@
1
- export declare const HF_HUB_URL = "https://huggingface.co";
2
1
  export interface DefaultTaskOptions {
3
2
  fetch?: typeof fetch;
4
3
  }
package/dist/src/lib/getDefaultTask.d.ts.map CHANGED
@@ -1 +1 @@
1
- {"version":3,"file":"getDefaultTask.d.ts","sourceRoot":"","sources":["../../../src/lib/getDefaultTask.ts"],"names":[],"mappings":"AAUA,eAAO,MAAM,UAAU,2BAA2B,CAAC;AAEnD,MAAM,WAAW,kBAAkB;IAClC,KAAK,CAAC,EAAE,OAAO,KAAK,CAAC;CACrB;AAED;;;;;GAKG;AACH,wBAAsB,cAAc,CACnC,KAAK,EAAE,MAAM,EACb,WAAW,EAAE,MAAM,GAAG,SAAS,EAC/B,OAAO,CAAC,EAAE,kBAAkB,GAC1B,OAAO,CAAC,MAAM,GAAG,IAAI,CAAC,CAkCxB"}
1
+ {"version":3,"file":"getDefaultTask.d.ts","sourceRoot":"","sources":["../../../src/lib/getDefaultTask.ts"],"names":[],"mappings":"AAYA,MAAM,WAAW,kBAAkB;IAClC,KAAK,CAAC,EAAE,OAAO,KAAK,CAAC;CACrB;AAED;;;;;GAKG;AACH,wBAAsB,cAAc,CACnC,KAAK,EAAE,MAAM,EACb,WAAW,EAAE,MAAM,GAAG,SAAS,EAC/B,OAAO,CAAC,EAAE,kBAAkB,GAC1B,OAAO,CAAC,MAAM,GAAG,IAAI,CAAC,CAkCxB"}