@huggingface/inference 3.7.0 → 3.7.1

This diff compares the contents of two publicly available versions of this package, as released to a supported public registry. It is provided for informational purposes only.
Files changed (126)
  1. package/dist/index.cjs +1152 -839
  2. package/dist/index.js +1154 -841
  3. package/dist/src/lib/getProviderHelper.d.ts +37 -0
  4. package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
  5. package/dist/src/lib/makeRequestOptions.d.ts +0 -2
  6. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  7. package/dist/src/providers/black-forest-labs.d.ts +14 -18
  8. package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
  9. package/dist/src/providers/cerebras.d.ts +4 -2
  10. package/dist/src/providers/cerebras.d.ts.map +1 -1
  11. package/dist/src/providers/cohere.d.ts +5 -2
  12. package/dist/src/providers/cohere.d.ts.map +1 -1
  13. package/dist/src/providers/fal-ai.d.ts +50 -3
  14. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  15. package/dist/src/providers/fireworks-ai.d.ts +5 -2
  16. package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
  17. package/dist/src/providers/hf-inference.d.ts +125 -2
  18. package/dist/src/providers/hf-inference.d.ts.map +1 -1
  19. package/dist/src/providers/hyperbolic.d.ts +31 -2
  20. package/dist/src/providers/hyperbolic.d.ts.map +1 -1
  21. package/dist/src/providers/nebius.d.ts +20 -18
  22. package/dist/src/providers/nebius.d.ts.map +1 -1
  23. package/dist/src/providers/novita.d.ts +21 -18
  24. package/dist/src/providers/novita.d.ts.map +1 -1
  25. package/dist/src/providers/openai.d.ts +4 -2
  26. package/dist/src/providers/openai.d.ts.map +1 -1
  27. package/dist/src/providers/providerHelper.d.ts +182 -0
  28. package/dist/src/providers/providerHelper.d.ts.map +1 -0
  29. package/dist/src/providers/replicate.d.ts +23 -19
  30. package/dist/src/providers/replicate.d.ts.map +1 -1
  31. package/dist/src/providers/sambanova.d.ts +4 -2
  32. package/dist/src/providers/sambanova.d.ts.map +1 -1
  33. package/dist/src/providers/together.d.ts +32 -2
  34. package/dist/src/providers/together.d.ts.map +1 -1
  35. package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
  36. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
  37. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  38. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  39. package/dist/src/tasks/audio/utils.d.ts +2 -1
  40. package/dist/src/tasks/audio/utils.d.ts.map +1 -1
  41. package/dist/src/tasks/custom/request.d.ts +0 -2
  42. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  43. package/dist/src/tasks/custom/streamingRequest.d.ts +0 -2
  44. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  45. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
  46. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
  47. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
  48. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
  49. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
  50. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  51. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
  52. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
  53. package/dist/src/tasks/index.d.ts +6 -6
  54. package/dist/src/tasks/index.d.ts.map +1 -1
  55. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  56. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
  57. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  58. package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
  59. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  60. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
  61. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  62. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  63. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
  64. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  65. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
  66. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  67. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  68. package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
  69. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  70. package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
  71. package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
  72. package/dist/src/types.d.ts +3 -13
  73. package/dist/src/types.d.ts.map +1 -1
  74. package/package.json +3 -3
  75. package/src/lib/getProviderHelper.ts +270 -0
  76. package/src/lib/makeRequestOptions.ts +34 -91
  77. package/src/providers/black-forest-labs.ts +73 -22
  78. package/src/providers/cerebras.ts +6 -27
  79. package/src/providers/cohere.ts +9 -28
  80. package/src/providers/fal-ai.ts +195 -77
  81. package/src/providers/fireworks-ai.ts +8 -29
  82. package/src/providers/hf-inference.ts +555 -34
  83. package/src/providers/hyperbolic.ts +107 -29
  84. package/src/providers/nebius.ts +65 -29
  85. package/src/providers/novita.ts +68 -32
  86. package/src/providers/openai.ts +6 -32
  87. package/src/providers/providerHelper.ts +354 -0
  88. package/src/providers/replicate.ts +124 -34
  89. package/src/providers/sambanova.ts +5 -30
  90. package/src/providers/together.ts +92 -28
  91. package/src/snippets/getInferenceSnippets.ts +16 -9
  92. package/src/snippets/templates.exported.ts +1 -1
  93. package/src/tasks/audio/audioClassification.ts +4 -7
  94. package/src/tasks/audio/audioToAudio.ts +3 -26
  95. package/src/tasks/audio/automaticSpeechRecognition.ts +4 -3
  96. package/src/tasks/audio/textToSpeech.ts +5 -29
  97. package/src/tasks/audio/utils.ts +2 -1
  98. package/src/tasks/custom/request.ts +0 -2
  99. package/src/tasks/custom/streamingRequest.ts +0 -2
  100. package/src/tasks/cv/imageClassification.ts +3 -7
  101. package/src/tasks/cv/imageSegmentation.ts +3 -8
  102. package/src/tasks/cv/imageToImage.ts +3 -6
  103. package/src/tasks/cv/imageToText.ts +3 -6
  104. package/src/tasks/cv/objectDetection.ts +3 -18
  105. package/src/tasks/cv/textToImage.ts +9 -137
  106. package/src/tasks/cv/textToVideo.ts +11 -62
  107. package/src/tasks/cv/zeroShotImageClassification.ts +3 -7
  108. package/src/tasks/index.ts +6 -6
  109. package/src/tasks/multimodal/documentQuestionAnswering.ts +3 -19
  110. package/src/tasks/multimodal/visualQuestionAnswering.ts +3 -11
  111. package/src/tasks/nlp/chatCompletion.ts +5 -20
  112. package/src/tasks/nlp/chatCompletionStream.ts +1 -2
  113. package/src/tasks/nlp/featureExtraction.ts +3 -18
  114. package/src/tasks/nlp/fillMask.ts +3 -16
  115. package/src/tasks/nlp/questionAnswering.ts +3 -22
  116. package/src/tasks/nlp/sentenceSimilarity.ts +3 -7
  117. package/src/tasks/nlp/summarization.ts +3 -6
  118. package/src/tasks/nlp/tableQuestionAnswering.ts +3 -27
  119. package/src/tasks/nlp/textClassification.ts +3 -8
  120. package/src/tasks/nlp/textGeneration.ts +12 -79
  121. package/src/tasks/nlp/tokenClassification.ts +3 -18
  122. package/src/tasks/nlp/translation.ts +3 -6
  123. package/src/tasks/nlp/zeroShotClassification.ts +3 -16
  124. package/src/tasks/tabular/tabularClassification.ts +3 -6
  125. package/src/tasks/tabular/tabularRegression.ts +3 -6
  126. package/src/types.ts +3 -14
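The bulk of this release is a refactor of the provider layer: the per-provider `makeBaseUrl`/`makeBody`/`makeHeaders`/`makeUrl` config objects are replaced by `TaskProviderHelper` classes (new file `src/providers/providerHelper.ts`), and the new `src/lib/getProviderHelper.ts` holds a `PROVIDERS` registry mapping each provider and task to a helper instance. The sketch below is illustrative only, not code from the package: it condenses the pattern visible in the diff, and the simplified `RequestParams` shape is an assumption made for the example.

```ts
// Minimal sketch of the TaskProviderHelper pattern introduced in 3.7.1.
// `RequestParams` is a simplified stand-in (assumption); the real params
// carry more fields. Class/method names mirror the diff.
const HF_ROUTER_URL = "https://router.huggingface.co";

interface RequestParams {
  accessToken: string;
  authMethod: "provider-key" | "hf-token";
  model: string;
  args: Record<string, unknown>;
}

abstract class TaskProviderHelper {
  constructor(readonly provider: string, readonly baseUrl: string) {}

  // Route and payload are task-specific; URL assembly is shared.
  abstract makeRoute(params: RequestParams): string;
  abstract preparePayload(params: RequestParams): unknown;

  makeBaseUrl(params: RequestParams): string {
    // Requests signed with an HF token go through the HF router.
    return params.authMethod !== "provider-key" ? `${HF_ROUTER_URL}/${this.provider}` : this.baseUrl;
  }

  makeUrl(params: RequestParams): string {
    const route = this.makeRoute(params).replace(/^\/+/, "");
    return `${this.makeBaseUrl(params)}/${route}`;
  }
}

class ConversationalTask extends TaskProviderHelper {
  makeRoute(): string {
    return "v1/chat/completions";
  }
  preparePayload(params: RequestParams): unknown {
    return { ...params.args, model: params.model };
  }
}

// getProviderHelper-style lookup: one helper instance per (provider, task).
const PROVIDERS: Record<string, Record<string, TaskProviderHelper>> = {
  cerebras: { conversational: new ConversationalTask("cerebras", "https://api.cerebras.ai") },
};

const helper = PROVIDERS["cerebras"]["conversational"];
console.log(helper.makeUrl({ accessToken: "hf_xxx", authMethod: "hf-token", model: "some-model", args: {} }));
// -> https://router.huggingface.co/cerebras/v1/chat/completions
```

Because base classes like `BaseConversationalTask` carry the shared route, payload, and response validation, the per-provider files in this diff shrink to a constructor call plus the occasional `makeRoute`/`getResponse` override.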
package/dist/index.js CHANGED
@@ -41,442 +41,1088 @@ __export(tasks_exports, {
41
41
  zeroShotImageClassification: () => zeroShotImageClassification
42
42
  });
43
43
 
44
+ // package.json
45
+ var name = "@huggingface/inference";
46
+ var version = "3.7.1";
47
+
44
48
  // src/config.ts
45
49
  var HF_HUB_URL = "https://huggingface.co";
46
50
  var HF_ROUTER_URL = "https://router.huggingface.co";
47
51
  var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
48
52
 
49
- // src/providers/black-forest-labs.ts
50
- var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
51
- var makeBaseUrl = () => {
52
- return BLACK_FOREST_LABS_AI_API_BASE_URL;
53
+ // src/lib/InferenceOutputError.ts
54
+ var InferenceOutputError = class extends TypeError {
55
+ constructor(message) {
56
+ super(
57
+ `Invalid inference output: ${message}. Use the 'request' method with the same parameters to do a custom call with no type checking.`
58
+ );
59
+ this.name = "InferenceOutputError";
60
+ }
53
61
  };
54
- var makeBody = (params) => {
55
- return params.args;
62
+
63
+ // src/utils/delay.ts
64
+ function delay(ms) {
65
+ return new Promise((resolve) => {
66
+ setTimeout(() => resolve(), ms);
67
+ });
68
+ }
69
+
70
+ // src/utils/pick.ts
71
+ function pick(o, props) {
72
+ return Object.assign(
73
+ {},
74
+ ...props.map((prop) => {
75
+ if (o[prop] !== void 0) {
76
+ return { [prop]: o[prop] };
77
+ }
78
+ })
79
+ );
80
+ }
81
+
82
+ // src/utils/typedInclude.ts
83
+ function typedInclude(arr, v) {
84
+ return arr.includes(v);
85
+ }
86
+
87
+ // src/utils/omit.ts
88
+ function omit(o, props) {
89
+ const propsArr = Array.isArray(props) ? props : [props];
90
+ const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
91
+ return pick(o, letsKeep);
92
+ }
93
+
94
+ // src/utils/toArray.ts
95
+ function toArray(obj) {
96
+ if (Array.isArray(obj)) {
97
+ return obj;
98
+ }
99
+ return [obj];
100
+ }
101
+
102
+ // src/providers/providerHelper.ts
103
+ var TaskProviderHelper = class {
104
+ constructor(provider, baseUrl, clientSideRoutingOnly = false) {
105
+ this.provider = provider;
106
+ this.baseUrl = baseUrl;
107
+ this.clientSideRoutingOnly = clientSideRoutingOnly;
108
+ }
109
+ /**
110
+ * Prepare the base URL for the request
111
+ */
112
+ makeBaseUrl(params) {
113
+ return params.authMethod !== "provider-key" ? `${HF_ROUTER_URL}/${this.provider}` : this.baseUrl;
114
+ }
115
+ /**
116
+ * Prepare the body for the request
117
+ */
118
+ makeBody(params) {
119
+ if ("data" in params.args && !!params.args.data) {
120
+ return params.args.data;
121
+ }
122
+ return JSON.stringify(this.preparePayload(params));
123
+ }
124
+ /**
125
+ * Prepare the URL for the request
126
+ */
127
+ makeUrl(params) {
128
+ const baseUrl = this.makeBaseUrl(params);
129
+ const route = this.makeRoute(params).replace(/^\/+/, "");
130
+ return `${baseUrl}/${route}`;
131
+ }
132
+ /**
133
+ * Prepare the headers for the request
134
+ */
135
+ prepareHeaders(params, isBinary) {
136
+ const headers = { Authorization: `Bearer ${params.accessToken}` };
137
+ if (!isBinary) {
138
+ headers["Content-Type"] = "application/json";
139
+ }
140
+ return headers;
141
+ }
56
142
  };
57
- var makeHeaders = (params) => {
58
- if (params.authMethod === "provider-key") {
59
- return { "X-Key": `${params.accessToken}` };
60
- } else {
61
- return { Authorization: `Bearer ${params.accessToken}` };
143
+ var BaseConversationalTask = class extends TaskProviderHelper {
144
+ constructor(provider, baseUrl, clientSideRoutingOnly = false) {
145
+ super(provider, baseUrl, clientSideRoutingOnly);
146
+ }
147
+ makeRoute() {
148
+ return "v1/chat/completions";
149
+ }
150
+ preparePayload(params) {
151
+ return {
152
+ ...params.args,
153
+ model: params.model
154
+ };
155
+ }
156
+ async getResponse(response) {
157
+ if (typeof response === "object" && Array.isArray(response?.choices) && typeof response?.created === "number" && typeof response?.id === "string" && typeof response?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
158
+ (response.system_fingerprint === void 0 || response.system_fingerprint === null || typeof response.system_fingerprint === "string") && typeof response?.usage === "object") {
159
+ return response;
160
+ }
161
+ throw new InferenceOutputError("Expected ChatCompletionOutput");
62
162
  }
63
163
  };
64
- var makeUrl = (params) => {
65
- return `${params.baseUrl}/v1/${params.model}`;
164
+ var BaseTextGenerationTask = class extends TaskProviderHelper {
165
+ constructor(provider, baseUrl, clientSideRoutingOnly = false) {
166
+ super(provider, baseUrl, clientSideRoutingOnly);
167
+ }
168
+ preparePayload(params) {
169
+ return {
170
+ ...params.args,
171
+ model: params.model
172
+ };
173
+ }
174
+ makeRoute() {
175
+ return "v1/completions";
176
+ }
177
+ async getResponse(response) {
178
+ const res = toArray(response);
179
+ if (Array.isArray(res) && res.length > 0 && res.every(
180
+ (x) => typeof x === "object" && !!x && "generated_text" in x && typeof x.generated_text === "string"
181
+ )) {
182
+ return res[0];
183
+ }
184
+ throw new InferenceOutputError("Expected Array<{generated_text: string}>");
185
+ }
66
186
  };
67
- var BLACK_FOREST_LABS_CONFIG = {
68
- makeBaseUrl,
69
- makeBody,
70
- makeHeaders,
71
- makeUrl
187
+
188
+ // src/providers/black-forest-labs.ts
189
+ var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
190
+ var BlackForestLabsTextToImageTask = class extends TaskProviderHelper {
191
+ constructor() {
192
+ super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
193
+ }
194
+ preparePayload(params) {
195
+ return {
196
+ ...omit(params.args, ["inputs", "parameters"]),
197
+ ...params.args.parameters,
198
+ prompt: params.args.inputs
199
+ };
200
+ }
201
+ prepareHeaders(params, binary) {
202
+ const headers = {
203
+ Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`
204
+ };
205
+ if (!binary) {
206
+ headers["Content-Type"] = "application/json";
207
+ }
208
+ return headers;
209
+ }
210
+ makeRoute(params) {
211
+ if (!params) {
212
+ throw new Error("Params are required");
213
+ }
214
+ return `/v1/${params.model}`;
215
+ }
216
+ async getResponse(response, url, headers, outputType) {
217
+ const urlObj = new URL(response.polling_url);
218
+ for (let step = 0; step < 5; step++) {
219
+ await delay(1e3);
220
+ console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
221
+ urlObj.searchParams.set("attempt", step.toString(10));
222
+ const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
223
+ if (!resp.ok) {
224
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
225
+ }
226
+ const payload = await resp.json();
227
+ if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
228
+ if (outputType === "url") {
229
+ return payload.result.sample;
230
+ }
231
+ const image = await fetch(payload.result.sample);
232
+ return await image.blob();
233
+ }
234
+ }
235
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
236
+ }
72
237
  };
73
238
 
74
239
  // src/providers/cerebras.ts
75
- var CEREBRAS_API_BASE_URL = "https://api.cerebras.ai";
76
- var makeBaseUrl2 = () => {
77
- return CEREBRAS_API_BASE_URL;
240
+ var CerebrasConversationalTask = class extends BaseConversationalTask {
241
+ constructor() {
242
+ super("cerebras", "https://api.cerebras.ai");
243
+ }
78
244
  };
79
- var makeBody2 = (params) => {
80
- return {
81
- ...params.args,
82
- model: params.model
83
- };
245
+
246
+ // src/providers/cohere.ts
247
+ var CohereConversationalTask = class extends BaseConversationalTask {
248
+ constructor() {
249
+ super("cohere", "https://api.cohere.com");
250
+ }
251
+ makeRoute() {
252
+ return "/compatibility/v1/chat/completions";
253
+ }
254
+ };
255
+
256
+ // src/lib/isUrl.ts
257
+ function isUrl(modelOrUrl) {
258
+ return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
259
+ }
260
+
261
+ // src/providers/fal-ai.ts
262
+ var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
263
+ var FalAITask = class extends TaskProviderHelper {
264
+ constructor(url) {
265
+ super("fal-ai", url || "https://fal.run");
266
+ }
267
+ preparePayload(params) {
268
+ return params.args;
269
+ }
270
+ makeRoute(params) {
271
+ return `/${params.model}`;
272
+ }
273
+ prepareHeaders(params, binary) {
274
+ const headers = {
275
+ Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}`
276
+ };
277
+ if (!binary) {
278
+ headers["Content-Type"] = "application/json";
279
+ }
280
+ return headers;
281
+ }
282
+ };
283
+ var FalAITextToImageTask = class extends FalAITask {
284
+ preparePayload(params) {
285
+ return {
286
+ ...omit(params.args, ["inputs", "parameters"]),
287
+ ...params.args.parameters,
288
+ sync_mode: true,
289
+ prompt: params.args.inputs
290
+ };
291
+ }
292
+ async getResponse(response, outputType) {
293
+ if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images.length > 0 && "url" in response.images[0] && typeof response.images[0].url === "string") {
294
+ if (outputType === "url") {
295
+ return response.images[0].url;
296
+ }
297
+ const urlResponse = await fetch(response.images[0].url);
298
+ return await urlResponse.blob();
299
+ }
300
+ throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
301
+ }
302
+ };
303
+ var FalAITextToVideoTask = class extends FalAITask {
304
+ constructor() {
305
+ super("https://queue.fal.run");
306
+ }
307
+ makeRoute(params) {
308
+ if (params.authMethod !== "provider-key") {
309
+ return `/${params.model}?_subdomain=queue`;
310
+ }
311
+ return `/${params.model}`;
312
+ }
313
+ preparePayload(params) {
314
+ return {
315
+ ...omit(params.args, ["inputs", "parameters"]),
316
+ ...params.args.parameters,
317
+ prompt: params.args.inputs
318
+ };
319
+ }
320
+ async getResponse(response, url, headers) {
321
+ if (!url || !headers) {
322
+ throw new InferenceOutputError("URL and headers are required for text-to-video task");
323
+ }
324
+ const requestId = response.request_id;
325
+ if (!requestId) {
326
+ throw new InferenceOutputError("No request ID found in the response");
327
+ }
328
+ let status = response.status;
329
+ const parsedUrl = new URL(url);
330
+ const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
331
+ const modelId = new URL(response.response_url).pathname;
332
+ const queryParams = parsedUrl.search;
333
+ const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
334
+ const resultUrl = `${baseUrl}${modelId}${queryParams}`;
335
+ while (status !== "COMPLETED") {
336
+ await delay(500);
337
+ const statusResponse = await fetch(statusUrl, { headers });
338
+ if (!statusResponse.ok) {
339
+ throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
340
+ }
341
+ try {
342
+ status = (await statusResponse.json()).status;
343
+ } catch (error) {
344
+ throw new InferenceOutputError("Failed to parse status response from fal-ai API");
345
+ }
346
+ }
347
+ const resultResponse = await fetch(resultUrl, { headers });
348
+ let result;
349
+ try {
350
+ result = await resultResponse.json();
351
+ } catch (error) {
352
+ throw new InferenceOutputError("Failed to parse result response from fal-ai API");
353
+ }
354
+ if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
355
+ const urlResponse = await fetch(result.video.url);
356
+ return await urlResponse.blob();
357
+ } else {
358
+ throw new InferenceOutputError(
359
+ "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
360
+ );
361
+ }
362
+ }
84
363
  };
85
- var makeHeaders2 = (params) => {
86
- return { Authorization: `Bearer ${params.accessToken}` };
364
+ var FalAIAutomaticSpeechRecognitionTask = class extends FalAITask {
365
+ prepareHeaders(params, binary) {
366
+ const headers = super.prepareHeaders(params, binary);
367
+ headers["Content-Type"] = "application/json";
368
+ return headers;
369
+ }
370
+ async getResponse(response) {
371
+ const res = response;
372
+ if (typeof res?.text !== "string") {
373
+ throw new InferenceOutputError(
374
+ `Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
375
+ );
376
+ }
377
+ return { text: res.text };
378
+ }
87
379
  };
88
- var makeUrl2 = (params) => {
89
- return `${params.baseUrl}/v1/chat/completions`;
380
+ var FalAITextToSpeechTask = class extends FalAITask {
381
+ preparePayload(params) {
382
+ return {
383
+ ...omit(params.args, ["inputs", "parameters"]),
384
+ ...params.args.parameters,
385
+ lyrics: params.args.inputs
386
+ };
387
+ }
388
+ async getResponse(response) {
389
+ const res = response;
390
+ if (typeof res?.audio?.url !== "string") {
391
+ throw new InferenceOutputError(
392
+ `Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
393
+ );
394
+ }
395
+ try {
396
+ const urlResponse = await fetch(res.audio.url);
397
+ if (!urlResponse.ok) {
398
+ throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`);
399
+ }
400
+ return await urlResponse.blob();
401
+ } catch (error) {
402
+ throw new InferenceOutputError(
403
+ `Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${error instanceof Error ? error.message : String(error)}`
404
+ );
405
+ }
406
+ }
90
407
  };
91
- var CEREBRAS_CONFIG = {
92
- makeBaseUrl: makeBaseUrl2,
93
- makeBody: makeBody2,
94
- makeHeaders: makeHeaders2,
95
- makeUrl: makeUrl2
408
+
409
+ // src/providers/fireworks-ai.ts
410
+ var FireworksConversationalTask = class extends BaseConversationalTask {
411
+ constructor() {
412
+ super("fireworks-ai", "https://api.fireworks.ai");
413
+ }
414
+ makeRoute() {
415
+ return "/inference/v1/chat/completions";
416
+ }
96
417
  };
97
418
 
98
- // src/providers/cohere.ts
99
- var COHERE_API_BASE_URL = "https://api.cohere.com";
100
- var makeBaseUrl3 = () => {
101
- return COHERE_API_BASE_URL;
419
+ // src/providers/hf-inference.ts
420
+ var HFInferenceTask = class extends TaskProviderHelper {
421
+ constructor() {
422
+ super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
423
+ }
424
+ preparePayload(params) {
425
+ return params.args;
426
+ }
427
+ makeUrl(params) {
428
+ if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
429
+ return params.model;
430
+ }
431
+ return super.makeUrl(params);
432
+ }
433
+ makeRoute(params) {
434
+ if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
435
+ return `pipeline/${params.task}/${params.model}`;
436
+ }
437
+ return `models/${params.model}`;
438
+ }
439
+ async getResponse(response) {
440
+ return response;
441
+ }
102
442
  };
103
- var makeBody3 = (params) => {
104
- return {
105
- ...params.args,
106
- model: params.model
107
- };
443
+ var HFInferenceTextToImageTask = class extends HFInferenceTask {
444
+ async getResponse(response, url, headers, outputType) {
445
+ if (!response) {
446
+ throw new InferenceOutputError("response is undefined");
447
+ }
448
+ if (typeof response == "object") {
449
+ if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
450
+ const base64Data = response.data[0].b64_json;
451
+ if (outputType === "url") {
452
+ return `data:image/jpeg;base64,${base64Data}`;
453
+ }
454
+ const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
455
+ return await base64Response.blob();
456
+ }
457
+ if ("output" in response && Array.isArray(response.output)) {
458
+ if (outputType === "url") {
459
+ return response.output[0];
460
+ }
461
+ const urlResponse = await fetch(response.output[0]);
462
+ const blob = await urlResponse.blob();
463
+ return blob;
464
+ }
465
+ }
466
+ if (response instanceof Blob) {
467
+ if (outputType === "url") {
468
+ const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
469
+ return `data:image/jpeg;base64,${b64}`;
470
+ }
471
+ return response;
472
+ }
473
+ throw new InferenceOutputError("Expected a Blob ");
474
+ }
475
+ };
476
+ var HFInferenceConversationalTask = class extends HFInferenceTask {
477
+ makeUrl(params) {
478
+ let url;
479
+ if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
480
+ url = params.model.trim();
481
+ } else {
482
+ url = `${this.makeBaseUrl(params)}/models/${params.model}`;
483
+ }
484
+ url = url.replace(/\/+$/, "");
485
+ if (url.endsWith("/v1")) {
486
+ url += "/chat/completions";
487
+ } else if (!url.endsWith("/chat/completions")) {
488
+ url += "/v1/chat/completions";
489
+ }
490
+ return url;
491
+ }
492
+ preparePayload(params) {
493
+ return {
494
+ ...params.args,
495
+ model: params.model
496
+ };
497
+ }
498
+ async getResponse(response) {
499
+ return response;
500
+ }
108
501
  };
109
- var makeHeaders3 = (params) => {
110
- return { Authorization: `Bearer ${params.accessToken}` };
502
+ var HFInferenceTextGenerationTask = class extends HFInferenceTask {
503
+ async getResponse(response) {
504
+ const res = toArray(response);
505
+ if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) {
506
+ return res?.[0];
507
+ }
508
+ throw new InferenceOutputError("Expected Array<{generated_text: string}>");
509
+ }
111
510
  };
112
- var makeUrl3 = (params) => {
113
- return `${params.baseUrl}/compatibility/v1/chat/completions`;
511
+ var HFInferenceAudioClassificationTask = class extends HFInferenceTask {
512
+ async getResponse(response) {
513
+ if (Array.isArray(response) && response.every(
514
+ (x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
515
+ )) {
516
+ return response;
517
+ }
518
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
519
+ }
114
520
  };
115
- var COHERE_CONFIG = {
116
- makeBaseUrl: makeBaseUrl3,
117
- makeBody: makeBody3,
118
- makeHeaders: makeHeaders3,
119
- makeUrl: makeUrl3
521
+ var HFInferenceAutomaticSpeechRecognitionTask = class extends HFInferenceTask {
522
+ async getResponse(response) {
523
+ return response;
524
+ }
120
525
  };
121
-
122
- // src/lib/InferenceOutputError.ts
123
- var InferenceOutputError = class extends TypeError {
124
- constructor(message) {
125
- super(
126
- `Invalid inference output: ${message}. Use the 'request' method with the same parameters to do a custom call with no type checking.`
526
+ var HFInferenceAudioToAudioTask = class extends HFInferenceTask {
527
+ async getResponse(response) {
528
+ if (!Array.isArray(response)) {
529
+ throw new InferenceOutputError("Expected Array");
530
+ }
531
+ if (!response.every((elem) => {
532
+ return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
533
+ })) {
534
+ throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
535
+ }
536
+ return response;
537
+ }
538
+ };
539
+ var HFInferenceDocumentQuestionAnsweringTask = class extends HFInferenceTask {
540
+ async getResponse(response) {
541
+ if (Array.isArray(response) && response.every(
542
+ (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
543
+ )) {
544
+ return response[0];
545
+ }
546
+ throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
547
+ }
548
+ };
549
+ var HFInferenceFeatureExtractionTask = class extends HFInferenceTask {
550
+ async getResponse(response) {
551
+ const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
552
+ if (curDepth > maxDepth)
553
+ return false;
554
+ if (arr.every((x) => Array.isArray(x))) {
555
+ return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
556
+ } else {
557
+ return arr.every((x) => typeof x === "number");
558
+ }
559
+ };
560
+ if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) {
561
+ return response;
562
+ }
563
+ throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
564
+ }
565
+ };
566
+ var HFInferenceImageClassificationTask = class extends HFInferenceTask {
567
+ async getResponse(response) {
568
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
569
+ return response;
570
+ }
571
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
572
+ }
573
+ };
574
+ var HFInferenceImageSegmentationTask = class extends HFInferenceTask {
575
+ async getResponse(response) {
576
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number")) {
577
+ return response;
578
+ }
579
+ throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
580
+ }
581
+ };
582
+ var HFInferenceImageToTextTask = class extends HFInferenceTask {
583
+ async getResponse(response) {
584
+ if (typeof response?.generated_text !== "string") {
585
+ throw new InferenceOutputError("Expected {generated_text: string}");
586
+ }
587
+ return response;
588
+ }
589
+ };
590
+ var HFInferenceImageToImageTask = class extends HFInferenceTask {
591
+ async getResponse(response) {
592
+ if (response instanceof Blob) {
593
+ return response;
594
+ }
595
+ throw new InferenceOutputError("Expected Blob");
596
+ }
597
+ };
598
+ var HFInferenceObjectDetectionTask = class extends HFInferenceTask {
599
+ async getResponse(response) {
600
+ if (Array.isArray(response) && response.every(
601
+ (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
602
+ )) {
603
+ return response;
604
+ }
605
+ throw new InferenceOutputError(
606
+ "Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
127
607
  );
128
- this.name = "InferenceOutputError";
129
608
  }
130
609
  };
131
-
132
- // src/lib/isUrl.ts
133
- function isUrl(modelOrUrl) {
134
- return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
135
- }
136
-
137
- // src/utils/delay.ts
138
- function delay(ms) {
139
- return new Promise((resolve) => {
140
- setTimeout(() => resolve(), ms);
141
- });
142
- }
143
-
144
- // src/providers/fal-ai.ts
145
- var FAL_AI_API_BASE_URL = "https://fal.run";
146
- var FAL_AI_API_BASE_URL_QUEUE = "https://queue.fal.run";
147
- var makeBaseUrl4 = (task) => {
148
- return task === "text-to-video" ? FAL_AI_API_BASE_URL_QUEUE : FAL_AI_API_BASE_URL;
610
+ var HFInferenceZeroShotImageClassificationTask = class extends HFInferenceTask {
611
+ async getResponse(response) {
612
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
613
+ return response;
614
+ }
615
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
616
+ }
617
+ };
618
+ var HFInferenceTextClassificationTask = class extends HFInferenceTask {
619
+ async getResponse(response) {
620
+ const output = response?.[0];
621
+ if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
622
+ return output;
623
+ }
624
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
625
+ }
149
626
  };
150
- var makeBody4 = (params) => {
151
- return params.args;
627
+ var HFInferenceQuestionAnsweringTask = class extends HFInferenceTask {
628
+ async getResponse(response) {
629
+ if (Array.isArray(response) ? response.every(
630
+ (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
631
+ ) : typeof response === "object" && !!response && typeof response.answer === "string" && typeof response.end === "number" && typeof response.score === "number" && typeof response.start === "number") {
632
+ return Array.isArray(response) ? response[0] : response;
633
+ }
634
+ throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
635
+ }
152
636
  };
153
- var makeHeaders4 = (params) => {
154
- return {
155
- Authorization: params.authMethod === "provider-key" ? `Key ${params.accessToken}` : `Bearer ${params.accessToken}`
156
- };
637
+ var HFInferenceFillMaskTask = class extends HFInferenceTask {
638
+ async getResponse(response) {
639
+ if (Array.isArray(response) && response.every(
640
+ (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
641
+ )) {
642
+ return response;
643
+ }
644
+ throw new InferenceOutputError(
645
+ "Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
646
+ );
647
+ }
157
648
  };
158
- var makeUrl4 = (params) => {
159
- const baseUrl = `${params.baseUrl}/${params.model}`;
160
- if (params.authMethod !== "provider-key" && params.task === "text-to-video") {
161
- return `${baseUrl}?_subdomain=queue`;
162
- }
163
- return baseUrl;
164
- };
165
- var FAL_AI_CONFIG = {
166
- makeBaseUrl: makeBaseUrl4,
167
- makeBody: makeBody4,
168
- makeHeaders: makeHeaders4,
169
- makeUrl: makeUrl4
170
- };
171
- async function pollFalResponse(res, url, headers) {
172
- const requestId = res.request_id;
173
- if (!requestId) {
174
- throw new InferenceOutputError("No request ID found in the response");
175
- }
176
- let status = res.status;
177
- const parsedUrl = new URL(url);
178
- const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
179
- const modelId = new URL(res.response_url).pathname;
180
- const queryParams = parsedUrl.search;
181
- const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
182
- const resultUrl = `${baseUrl}${modelId}${queryParams}`;
183
- while (status !== "COMPLETED") {
184
- await delay(500);
185
- const statusResponse = await fetch(statusUrl, { headers });
186
- if (!statusResponse.ok) {
187
- throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
649
+ var HFInferenceZeroShotClassificationTask = class extends HFInferenceTask {
650
+ async getResponse(response) {
651
+ if (Array.isArray(response) && response.every(
652
+ (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
653
+ )) {
654
+ return response;
188
655
  }
189
- try {
190
- status = (await statusResponse.json()).status;
191
- } catch (error) {
192
- throw new InferenceOutputError("Failed to parse status response from fal-ai API");
656
+ throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
657
+ }
658
+ };
659
+ var HFInferenceSentenceSimilarityTask = class extends HFInferenceTask {
660
+ async getResponse(response) {
661
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
662
+ return response;
193
663
  }
664
+ throw new InferenceOutputError("Expected Array<number>");
194
665
  }
195
- const resultResponse = await fetch(resultUrl, { headers });
196
- let result;
197
- try {
198
- result = await resultResponse.json();
199
- } catch (error) {
200
- throw new InferenceOutputError("Failed to parse result response from fal-ai API");
666
+ };
667
+ var HFInferenceTableQuestionAnsweringTask = class extends HFInferenceTask {
668
+ static validate(elem) {
669
+ return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
670
+ (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
671
+ );
201
672
  }
202
- if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
203
- const urlResponse = await fetch(result.video.url);
204
- return await urlResponse.blob();
205
- } else {
673
+ async getResponse(response) {
674
+ if (Array.isArray(response) && Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response)) {
675
+ return Array.isArray(response) ? response[0] : response;
676
+ }
206
677
  throw new InferenceOutputError(
207
- "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
678
+ "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
208
679
  );
209
680
  }
210
- }
211
-
212
- // src/providers/fireworks-ai.ts
213
- var FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai";
214
- var makeBaseUrl5 = () => {
215
- return FIREWORKS_AI_API_BASE_URL;
216
- };
217
- var makeBody5 = (params) => {
218
- return {
219
- ...params.args,
220
- ...params.chatCompletion ? { model: params.model } : void 0
221
- };
222
681
  };
223
- var makeHeaders5 = (params) => {
224
- return { Authorization: `Bearer ${params.accessToken}` };
225
- };
226
- var makeUrl5 = (params) => {
227
- if (params.chatCompletion) {
228
- return `${params.baseUrl}/inference/v1/chat/completions`;
682
+ var HFInferenceTokenClassificationTask = class extends HFInferenceTask {
683
+ async getResponse(response) {
684
+ if (Array.isArray(response) && response.every(
685
+ (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
686
+ )) {
687
+ return response;
688
+ }
689
+ throw new InferenceOutputError(
690
+ "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
691
+ );
229
692
  }
230
- return `${params.baseUrl}/inference`;
231
693
  };
232
- var FIREWORKS_AI_CONFIG = {
233
- makeBaseUrl: makeBaseUrl5,
234
- makeBody: makeBody5,
235
- makeHeaders: makeHeaders5,
236
- makeUrl: makeUrl5
694
+ var HFInferenceTranslationTask = class extends HFInferenceTask {
695
+ async getResponse(response) {
696
+ if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) {
697
+ return response?.length === 1 ? response?.[0] : response;
698
+ }
699
+ throw new InferenceOutputError("Expected Array<{translation_text: string}>");
700
+ }
237
701
  };
238
-
239
- // src/providers/hf-inference.ts
240
- var makeBaseUrl6 = () => {
241
- return `${HF_ROUTER_URL}/hf-inference`;
702
+ var HFInferenceSummarizationTask = class extends HFInferenceTask {
703
+ async getResponse(response) {
704
+ if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) {
705
+ return response?.[0];
706
+ }
707
+ throw new InferenceOutputError("Expected Array<{summary_text: string}>");
708
+ }
242
709
  };
243
- var makeBody6 = (params) => {
244
- return {
245
- ...params.args,
246
- ...params.chatCompletion ? { model: params.model } : void 0
247
- };
710
+ var HFInferenceTextToSpeechTask = class extends HFInferenceTask {
711
+ async getResponse(response) {
712
+ return response;
713
+ }
248
714
  };
249
- var makeHeaders6 = (params) => {
250
- return { Authorization: `Bearer ${params.accessToken}` };
715
+ var HFInferenceTabularClassificationTask = class extends HFInferenceTask {
716
+ async getResponse(response) {
717
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
718
+ return response;
719
+ }
720
+ throw new InferenceOutputError("Expected Array<number>");
721
+ }
251
722
  };
252
- var makeUrl6 = (params) => {
253
- if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
254
- return `${params.baseUrl}/pipeline/${params.task}/${params.model}`;
723
+ var HFInferenceVisualQuestionAnsweringTask = class extends HFInferenceTask {
724
+ async getResponse(response) {
725
+ if (Array.isArray(response) && response.every(
726
+ (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
727
+ )) {
728
+ return response[0];
729
+ }
730
+ throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
255
731
  }
256
- if (params.chatCompletion) {
257
- return `${params.baseUrl}/models/${params.model}/v1/chat/completions`;
732
+ };
733
+ var HFInferenceTabularRegressionTask = class extends HFInferenceTask {
734
+ async getResponse(response) {
735
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
736
+ return response;
737
+ }
738
+ throw new InferenceOutputError("Expected Array<number>");
258
739
  }
259
- return `${params.baseUrl}/models/${params.model}`;
260
740
  };
261
- var HF_INFERENCE_CONFIG = {
262
- makeBaseUrl: makeBaseUrl6,
263
- makeBody: makeBody6,
264
- makeHeaders: makeHeaders6,
265
- makeUrl: makeUrl6
741
+ var HFInferenceTextToAudioTask = class extends HFInferenceTask {
742
+ async getResponse(response) {
743
+ return response;
744
+ }
266
745
  };
267
746
 
268
747
  // src/providers/hyperbolic.ts
269
748
  var HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
270
- var makeBaseUrl7 = () => {
271
- return HYPERBOLIC_API_BASE_URL;
272
- };
273
- var makeBody7 = (params) => {
274
- return {
275
- ...params.args,
276
- ...params.task === "text-to-image" ? { model_name: params.model } : { model: params.model }
277
- };
278
- };
279
- var makeHeaders7 = (params) => {
280
- return { Authorization: `Bearer ${params.accessToken}` };
749
+ var HyperbolicConversationalTask = class extends BaseConversationalTask {
750
+ constructor() {
751
+ super("hyperbolic", HYPERBOLIC_API_BASE_URL);
752
+ }
281
753
  };
282
- var makeUrl7 = (params) => {
283
- if (params.task === "text-to-image") {
284
- return `${params.baseUrl}/v1/images/generations`;
754
+ var HyperbolicTextGenerationTask = class extends BaseTextGenerationTask {
755
+ constructor() {
756
+ super("hyperbolic", HYPERBOLIC_API_BASE_URL);
757
+ }
758
+ makeRoute() {
759
+ return "v1/chat/completions";
760
+ }
761
+ preparePayload(params) {
762
+ return {
763
+ messages: [{ content: params.args.inputs, role: "user" }],
764
+ ...params.args.parameters ? {
765
+ max_tokens: params.args.parameters.max_new_tokens,
766
+ ...omit(params.args.parameters, "max_new_tokens")
767
+ } : void 0,
768
+ ...omit(params.args, ["inputs", "parameters"]),
769
+ model: params.model
770
+ };
771
+ }
772
+ async getResponse(response) {
773
+ if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
774
+ const completion = response.choices[0];
775
+ return {
776
+ generated_text: completion.message.content
777
+ };
778
+ }
779
+ throw new InferenceOutputError("Expected Hyperbolic text generation response format");
285
780
  }
286
- return `${params.baseUrl}/v1/chat/completions`;
287
781
  };
288
- var HYPERBOLIC_CONFIG = {
289
- makeBaseUrl: makeBaseUrl7,
290
- makeBody: makeBody7,
291
- makeHeaders: makeHeaders7,
292
- makeUrl: makeUrl7
782
+ var HyperbolicTextToImageTask = class extends TaskProviderHelper {
783
+ constructor() {
784
+ super("hyperbolic", HYPERBOLIC_API_BASE_URL);
785
+ }
786
+ makeRoute(params) {
787
+ return `/v1/images/generations`;
788
+ }
789
+ preparePayload(params) {
790
+ return {
791
+ ...omit(params.args, ["inputs", "parameters"]),
792
+ ...params.args.parameters,
793
+ prompt: params.args.inputs,
794
+ model_name: params.model
795
+ };
796
+ }
797
+ async getResponse(response, url, headers, outputType) {
798
+ if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images[0] && typeof response.images[0].image === "string") {
799
+ if (outputType === "url") {
800
+ return `data:image/jpeg;base64,${response.images[0].image}`;
801
+ }
802
+ return fetch(`data:image/jpeg;base64,${response.images[0].image}`).then((res) => res.blob());
803
+ }
804
+ throw new InferenceOutputError("Expected Hyperbolic text-to-image response format");
805
+ }
293
806
  };
294
807
 
295
808
  // src/providers/nebius.ts
296
809
  var NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
297
- var makeBaseUrl8 = () => {
298
- return NEBIUS_API_BASE_URL;
299
- };
300
- var makeBody8 = (params) => {
301
- return {
302
- ...params.args,
303
- model: params.model
304
- };
810
+ var NebiusConversationalTask = class extends BaseConversationalTask {
811
+ constructor() {
812
+ super("nebius", NEBIUS_API_BASE_URL);
813
+ }
305
814
  };
306
- var makeHeaders8 = (params) => {
307
- return { Authorization: `Bearer ${params.accessToken}` };
815
+ var NebiusTextGenerationTask = class extends BaseTextGenerationTask {
816
+ constructor() {
817
+ super("nebius", NEBIUS_API_BASE_URL);
818
+ }
308
819
  };
309
- var makeUrl8 = (params) => {
310
- if (params.task === "text-to-image") {
311
- return `${params.baseUrl}/v1/images/generations`;
820
+ var NebiusTextToImageTask = class extends TaskProviderHelper {
821
+ constructor() {
822
+ super("nebius", NEBIUS_API_BASE_URL);
823
+ }
824
+ preparePayload(params) {
825
+ return {
826
+ ...omit(params.args, ["inputs", "parameters"]),
827
+ ...params.args.parameters,
828
+ response_format: "b64_json",
829
+ prompt: params.args.inputs,
830
+ model: params.model
831
+ };
312
832
  }
313
- if (params.chatCompletion) {
314
- return `${params.baseUrl}/v1/chat/completions`;
833
+ makeRoute(params) {
834
+ return "v1/images/generations";
315
835
  }
316
- if (params.task === "text-generation") {
317
- return `${params.baseUrl}/v1/completions`;
836
+ async getResponse(response, url, headers, outputType) {
837
+ if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
838
+ const base64Data = response.data[0].b64_json;
839
+ if (outputType === "url") {
840
+ return `data:image/jpeg;base64,${base64Data}`;
841
+ }
842
+ return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
843
+ }
844
+ throw new InferenceOutputError("Expected Nebius text-to-image response format");
318
845
  }
319
- return params.baseUrl;
320
- };
321
- var NEBIUS_CONFIG = {
322
- makeBaseUrl: makeBaseUrl8,
323
- makeBody: makeBody8,
324
- makeHeaders: makeHeaders8,
325
- makeUrl: makeUrl8
326
846
  };
327
847
 
328
848
  // src/providers/novita.ts
329
849
  var NOVITA_API_BASE_URL = "https://api.novita.ai";
330
- var makeBaseUrl9 = () => {
331
- return NOVITA_API_BASE_URL;
332
- };
333
- var makeBody9 = (params) => {
334
- return {
335
- ...params.args,
336
- ...params.chatCompletion ? { model: params.model } : void 0
337
- };
338
- };
339
- var makeHeaders9 = (params) => {
340
- return { Authorization: `Bearer ${params.accessToken}` };
850
+ var NovitaTextGenerationTask = class extends BaseTextGenerationTask {
851
+ constructor() {
852
+ super("novita", NOVITA_API_BASE_URL);
853
+ }
854
+ makeRoute() {
855
+ return "/v3/openai/chat/completions";
856
+ }
341
857
  };
342
- var makeUrl9 = (params) => {
343
- if (params.chatCompletion) {
344
- return `${params.baseUrl}/v3/openai/chat/completions`;
345
- } else if (params.task === "text-generation") {
346
- return `${params.baseUrl}/v3/openai/completions`;
347
- } else if (params.task === "text-to-video") {
348
- return `${params.baseUrl}/v3/hf/${params.model}`;
858
+ var NovitaConversationalTask = class extends BaseConversationalTask {
859
+ constructor() {
860
+ super("novita", NOVITA_API_BASE_URL);
861
+ }
862
+ makeRoute() {
863
+ return "/v3/openai/chat/completions";
349
864
  }
350
- return params.baseUrl;
351
865
  };
352
- var NOVITA_CONFIG = {
353
- makeBaseUrl: makeBaseUrl9,
354
- makeBody: makeBody9,
355
- makeHeaders: makeHeaders9,
356
- makeUrl: makeUrl9
866
+
867
+ // src/providers/openai.ts
868
+ var OPENAI_API_BASE_URL = "https://api.openai.com";
869
+ var OpenAIConversationalTask = class extends BaseConversationalTask {
870
+ constructor() {
871
+ super("openai", OPENAI_API_BASE_URL, true);
872
+ }
357
873
  };
358
874
 
359
875
  // src/providers/replicate.ts
360
- var REPLICATE_API_BASE_URL = "https://api.replicate.com";
361
- var makeBaseUrl10 = () => {
362
- return REPLICATE_API_BASE_URL;
363
- };
364
- var makeBody10 = (params) => {
365
- return {
366
- input: params.args,
367
- version: params.model.includes(":") ? params.model.split(":")[1] : void 0
368
- };
876
+ var ReplicateTask = class extends TaskProviderHelper {
877
+ constructor(url) {
878
+ super("replicate", url || "https://api.replicate.com");
879
+ }
880
+ makeRoute(params) {
881
+ if (params.model.includes(":")) {
882
+ return "v1/predictions";
883
+ }
884
+ return `v1/models/${params.model}/predictions`;
885
+ }
886
+ preparePayload(params) {
887
+ return {
888
+ input: {
889
+ ...omit(params.args, ["inputs", "parameters"]),
890
+ ...params.args.parameters,
891
+ prompt: params.args.inputs
892
+ },
893
+ version: params.model.includes(":") ? params.model.split(":")[1] : void 0
894
+ };
895
+ }
896
+ prepareHeaders(params, binary) {
897
+ const headers = { Authorization: `Bearer ${params.accessToken}`, Prefer: "wait" };
898
+ if (!binary) {
899
+ headers["Content-Type"] = "application/json";
900
+ }
901
+ return headers;
902
+ }
903
+ makeUrl(params) {
904
+ const baseUrl = this.makeBaseUrl(params);
905
+ if (params.model.includes(":")) {
906
+ return `${baseUrl}/v1/predictions`;
907
+ }
908
+ return `${baseUrl}/v1/models/${params.model}/predictions`;
909
+ }
369
910
  };
370
- var makeHeaders10 = (params) => {
371
- return { Authorization: `Bearer ${params.accessToken}`, Prefer: "wait" };
911
+ var ReplicateTextToImageTask = class extends ReplicateTask {
912
+ async getResponse(res, url, headers, outputType) {
913
+ if (typeof res === "object" && "output" in res && Array.isArray(res.output) && res.output.length > 0 && typeof res.output[0] === "string") {
914
+ if (outputType === "url") {
915
+ return res.output[0];
916
+ }
917
+ const urlResponse = await fetch(res.output[0]);
918
+ return await urlResponse.blob();
919
+ }
920
+ throw new InferenceOutputError("Expected Replicate text-to-image response format");
921
+ }
372
922
  };
373
- var makeUrl10 = (params) => {
374
- if (params.model.includes(":")) {
375
- return `${params.baseUrl}/v1/predictions`;
923
+ var ReplicateTextToSpeechTask = class extends ReplicateTask {
924
+ preparePayload(params) {
925
+ const payload = super.preparePayload(params);
926
+ const input = payload["input"];
927
+ if (typeof input === "object" && input !== null && "prompt" in input) {
928
+ const inputObj = input;
929
+ inputObj["text"] = inputObj["prompt"];
930
+ delete inputObj["prompt"];
931
+ }
932
+ return payload;
933
+ }
934
+ async getResponse(response) {
935
+ if (response instanceof Blob) {
936
+ return response;
937
+ }
938
+ if (response && typeof response === "object") {
939
+ if ("output" in response) {
940
+ if (typeof response.output === "string") {
941
+ const urlResponse = await fetch(response.output);
942
+ return await urlResponse.blob();
943
+ } else if (Array.isArray(response.output)) {
944
+ const urlResponse = await fetch(response.output[0]);
945
+ return await urlResponse.blob();
946
+ }
947
+ }
948
+ }
949
+ throw new InferenceOutputError("Expected Blob or object with output");
376
950
  }
377
- return `${params.baseUrl}/v1/models/${params.model}/predictions`;
378
951
  };
379
- var REPLICATE_CONFIG = {
380
- makeBaseUrl: makeBaseUrl10,
381
- makeBody: makeBody10,
382
- makeHeaders: makeHeaders10,
383
- makeUrl: makeUrl10
952
+ var ReplicateTextToVideoTask = class extends ReplicateTask {
953
+ async getResponse(response) {
954
+ if (typeof response === "object" && !!response && "output" in response && typeof response.output === "string" && isUrl(response.output)) {
955
+ const urlResponse = await fetch(response.output);
956
+ return await urlResponse.blob();
957
+ }
958
+ throw new InferenceOutputError("Expected { output: string }");
959
+ }
384
960
  };
385
961
 
386
962
  // src/providers/sambanova.ts
387
- var SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
388
- var makeBaseUrl11 = () => {
389
- return SAMBANOVA_API_BASE_URL;
390
- };
391
- var makeBody11 = (params) => {
392
- return {
393
- ...params.args,
394
- ...params.chatCompletion ? { model: params.model } : void 0
- };
- };
- var makeHeaders11 = (params) => {
- return { Authorization: `Bearer ${params.accessToken}` };
- };
- var makeUrl11 = (params) => {
- if (params.chatCompletion) {
- return `${params.baseUrl}/v1/chat/completions`;
+ var SambanovaConversationalTask = class extends BaseConversationalTask {
+ constructor() {
+ super("sambanova", "https://api.sambanova.ai");
  }
- return params.baseUrl;
- };
- var SAMBANOVA_CONFIG = {
- makeBaseUrl: makeBaseUrl11,
- makeBody: makeBody11,
- makeHeaders: makeHeaders11,
- makeUrl: makeUrl11
  };

  // src/providers/together.ts
  var TOGETHER_API_BASE_URL = "https://api.together.xyz";
- var makeBaseUrl12 = () => {
- return TOGETHER_API_BASE_URL;
- };
- var makeBody12 = (params) => {
- return {
- ...params.args,
- model: params.model
- };
- };
- var makeHeaders12 = (params) => {
- return { Authorization: `Bearer ${params.accessToken}` };
+ var TogetherConversationalTask = class extends BaseConversationalTask {
+ constructor() {
+ super("together", TOGETHER_API_BASE_URL);
+ }
  };
- var makeUrl12 = (params) => {
- if (params.task === "text-to-image") {
- return `${params.baseUrl}/v1/images/generations`;
+ var TogetherTextGenerationTask = class extends BaseTextGenerationTask {
+ constructor() {
+ super("together", TOGETHER_API_BASE_URL);
  }
- if (params.chatCompletion) {
- return `${params.baseUrl}/v1/chat/completions`;
+ preparePayload(params) {
+ return {
+ model: params.model,
+ ...params.args,
+ prompt: params.args.inputs
+ };
  }
- if (params.task === "text-generation") {
- return `${params.baseUrl}/v1/completions`;
+ async getResponse(response) {
+ if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
+ const completion = response.choices[0];
+ return {
+ generated_text: completion.text
+ };
+ }
+ throw new InferenceOutputError("Expected Together text generation response format");
  }
- return params.baseUrl;
  };
- var TOGETHER_CONFIG = {
- makeBaseUrl: makeBaseUrl12,
- makeBody: makeBody12,
- makeHeaders: makeHeaders12,
- makeUrl: makeUrl12
+ var TogetherTextToImageTask = class extends TaskProviderHelper {
+ constructor() {
+ super("together", TOGETHER_API_BASE_URL);
+ }
+ makeRoute() {
+ return "v1/images/generations";
+ }
+ preparePayload(params) {
+ return {
+ ...omit(params.args, ["inputs", "parameters"]),
+ ...params.args.parameters,
+ prompt: params.args.inputs,
+ response_format: "base64",
+ model: params.model
+ };
+ }
+ async getResponse(response, outputType) {
+ if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
+ const base64Data = response.data[0].b64_json;
+ if (outputType === "url") {
+ return `data:image/jpeg;base64,${base64Data}`;
+ }
+ return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
+ }
+ throw new InferenceOutputError("Expected Together text-to-image response format");
+ }
  };

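The hunk above is representative of the whole release: the per-provider makeBaseUrl/makeBody/makeHeaders/makeUrl function quartets are replaced by task-specific helper classes. As a hedged illustration only (not part of the diff; the model ID is a hypothetical placeholder), the new Together text-generation helper builds its payload like this:

// Illustrative sketch, not part of the package diff.
const helper = new TogetherTextGenerationTask();
const payload = helper.preparePayload({
  model: "meta-llama/Llama-3.3-70B-Instruct-Turbo", // hypothetical provider model ID
  args: { inputs: "Once upon a time", parameters: { max_new_tokens: 64 } },
});
// payload: { model, inputs, parameters, prompt: "Once upon a time" } -- the helper
// mirrors `inputs` into the `prompt` field that Together's completions API expects.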
- // src/providers/openai.ts
- var OPENAI_API_BASE_URL = "https://api.openai.com";
- var makeBaseUrl13 = () => {
- return OPENAI_API_BASE_URL;
- };
- var makeBody13 = (params) => {
- if (!params.chatCompletion) {
- throw new Error("OpenAI only supports chat completions.");
+ // src/lib/getProviderHelper.ts
+ var PROVIDERS = {
+ "black-forest-labs": {
+ "text-to-image": new BlackForestLabsTextToImageTask()
+ },
+ cerebras: {
+ conversational: new CerebrasConversationalTask()
+ },
+ cohere: {
+ conversational: new CohereConversationalTask()
+ },
+ "fal-ai": {
+ "text-to-image": new FalAITextToImageTask(),
+ "text-to-speech": new FalAITextToSpeechTask(),
+ "text-to-video": new FalAITextToVideoTask(),
+ "automatic-speech-recognition": new FalAIAutomaticSpeechRecognitionTask()
+ },
+ "hf-inference": {
+ "text-to-image": new HFInferenceTextToImageTask(),
+ conversational: new HFInferenceConversationalTask(),
+ "text-generation": new HFInferenceTextGenerationTask(),
+ "text-classification": new HFInferenceTextClassificationTask(),
+ "question-answering": new HFInferenceQuestionAnsweringTask(),
+ "audio-classification": new HFInferenceAudioClassificationTask(),
+ "automatic-speech-recognition": new HFInferenceAutomaticSpeechRecognitionTask(),
+ "fill-mask": new HFInferenceFillMaskTask(),
+ "feature-extraction": new HFInferenceFeatureExtractionTask(),
+ "image-classification": new HFInferenceImageClassificationTask(),
+ "image-segmentation": new HFInferenceImageSegmentationTask(),
+ "document-question-answering": new HFInferenceDocumentQuestionAnsweringTask(),
+ "image-to-text": new HFInferenceImageToTextTask(),
+ "object-detection": new HFInferenceObjectDetectionTask(),
+ "audio-to-audio": new HFInferenceAudioToAudioTask(),
+ "zero-shot-image-classification": new HFInferenceZeroShotImageClassificationTask(),
+ "zero-shot-classification": new HFInferenceZeroShotClassificationTask(),
+ "image-to-image": new HFInferenceImageToImageTask(),
+ "sentence-similarity": new HFInferenceSentenceSimilarityTask(),
+ "table-question-answering": new HFInferenceTableQuestionAnsweringTask(),
+ "tabular-classification": new HFInferenceTabularClassificationTask(),
+ "text-to-speech": new HFInferenceTextToSpeechTask(),
+ "token-classification": new HFInferenceTokenClassificationTask(),
+ translation: new HFInferenceTranslationTask(),
+ summarization: new HFInferenceSummarizationTask(),
+ "visual-question-answering": new HFInferenceVisualQuestionAnsweringTask(),
+ "tabular-regression": new HFInferenceTabularRegressionTask(),
+ "text-to-audio": new HFInferenceTextToAudioTask()
+ },
+ "fireworks-ai": {
+ conversational: new FireworksConversationalTask()
+ },
+ hyperbolic: {
+ "text-to-image": new HyperbolicTextToImageTask(),
+ conversational: new HyperbolicConversationalTask(),
+ "text-generation": new HyperbolicTextGenerationTask()
+ },
+ nebius: {
+ "text-to-image": new NebiusTextToImageTask(),
+ conversational: new NebiusConversationalTask(),
+ "text-generation": new NebiusTextGenerationTask()
+ },
+ novita: {
+ conversational: new NovitaConversationalTask(),
+ "text-generation": new NovitaTextGenerationTask()
+ },
+ openai: {
+ conversational: new OpenAIConversationalTask()
+ },
+ replicate: {
+ "text-to-image": new ReplicateTextToImageTask(),
+ "text-to-speech": new ReplicateTextToSpeechTask(),
+ "text-to-video": new ReplicateTextToVideoTask()
+ },
+ sambanova: {
+ conversational: new SambanovaConversationalTask()
+ },
+ together: {
+ "text-to-image": new TogetherTextToImageTask(),
+ conversational: new TogetherConversationalTask(),
+ "text-generation": new TogetherTextGenerationTask()
  }
- return {
- ...params.args,
- model: params.model
- };
  };
- var makeHeaders13 = (params) => {
- return { Authorization: `Bearer ${params.accessToken}` };
- };
- var makeUrl13 = (params) => {
- if (!params.chatCompletion) {
- throw new Error("OpenAI only supports chat completions.");
+ function getProviderHelper(provider, task) {
+ if (provider === "hf-inference") {
+ if (!task) {
+ return new HFInferenceTask();
+ }
  }
- return `${params.baseUrl}/v1/chat/completions`;
- };
- var OPENAI_CONFIG = {
- makeBaseUrl: makeBaseUrl13,
- makeBody: makeBody13,
- makeHeaders: makeHeaders13,
- makeUrl: makeUrl13,
- clientSideRoutingOnly: true
- };
-
- // package.json
- var name = "@huggingface/inference";
- var version = "3.7.0";
+ if (!task) {
+ throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'");
+ }
+ if (!(provider in PROVIDERS)) {
+ throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`);
+ }
+ const providerTasks = PROVIDERS[provider];
+ if (!providerTasks || !(task in providerTasks)) {
+ throw new Error(
+ `Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(providerTasks ?? {})}`
+ );
+ }
+ return providerTasks[task];
+ }

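getProviderHelper is the new single entry point from a (provider, task) pair to a helper instance. A minimal sketch of its behavior, read directly off the code above (illustrative, not part of the diff):

// Illustrative sketch, not part of the package diff.
getProviderHelper("together", "text-to-image");        // -> the shared TogetherTextToImageTask instance
getProviderHelper("hf-inference", undefined);          // -> a generic HFInferenceTask (no task required)
getProviderHelper("cerebras", "text-to-image");        // throws: task not supported for this provider
getProviderHelper("not-a-provider", "conversational"); // throws: provider not supported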
  // src/providers/consts.ts
  var HARDCODED_MODEL_ID_MAPPING = {
@@ -546,28 +1192,11 @@ async function getProviderModelId(params, args, options = {}) {
  }

  // src/lib/makeRequestOptions.ts
- var HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
  var tasks = null;
- var providerConfigs = {
- "black-forest-labs": BLACK_FOREST_LABS_CONFIG,
- cerebras: CEREBRAS_CONFIG,
- cohere: COHERE_CONFIG,
- "fal-ai": FAL_AI_CONFIG,
- "fireworks-ai": FIREWORKS_AI_CONFIG,
- "hf-inference": HF_INFERENCE_CONFIG,
- hyperbolic: HYPERBOLIC_CONFIG,
- openai: OPENAI_CONFIG,
- nebius: NEBIUS_CONFIG,
- novita: NOVITA_CONFIG,
- replicate: REPLICATE_CONFIG,
- sambanova: SAMBANOVA_CONFIG,
- together: TOGETHER_CONFIG
- };
  async function makeRequestOptions(args, options) {
  const { provider: maybeProvider, model: maybeModel } = args;
  const provider = maybeProvider ?? "hf-inference";
- const providerConfig = providerConfigs[provider];
- const { task, chatCompletion: chatCompletion2 } = options ?? {};
+ const { task } = options ?? {};
  if (args.endpointUrl && provider !== "hf-inference") {
  throw new Error(`Cannot use endpointUrl with a third-party provider.`);
  }
@@ -577,19 +1206,16 @@ async function makeRequestOptions(args, options) {
  if (!maybeModel && !task) {
  throw new Error("No model provided, and no task has been specified.");
  }
- if (!providerConfig) {
- throw new Error(`No provider config found for provider ${provider}`);
- }
- if (providerConfig.clientSideRoutingOnly && !maybeModel) {
+ const hfModel = maybeModel ?? await loadDefaultModel(task);
+ const providerHelper = getProviderHelper(provider, task);
+ if (providerHelper.clientSideRoutingOnly && !maybeModel) {
  throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
  }
- const hfModel = maybeModel ?? await loadDefaultModel(task);
- const resolvedModel = providerConfig.clientSideRoutingOnly ? (
+ const resolvedModel = providerHelper.clientSideRoutingOnly ? (
  // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
  removeProviderPrefix(maybeModel, provider)
  ) : await getProviderModelId({ model: hfModel, provider }, args, {
  task,
- chatCompletion: chatCompletion2,
  fetch: options?.fetch
  });
  return makeRequestOptionsFromResolvedModel(resolvedModel, args, options);
@@ -597,10 +1223,10 @@ async function makeRequestOptions(args, options) {
  function makeRequestOptionsFromResolvedModel(resolvedModel, args, options) {
  const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
  const provider = maybeProvider ?? "hf-inference";
- const providerConfig = providerConfigs[provider];
- const { includeCredentials, task, chatCompletion: chatCompletion2, signal, billTo } = options ?? {};
+ const { includeCredentials, task, signal, billTo } = options ?? {};
+ const providerHelper = getProviderHelper(provider, task);
  const authMethod = (() => {
- if (providerConfig.clientSideRoutingOnly) {
+ if (providerHelper.clientSideRoutingOnly) {
  if (accessToken && accessToken.startsWith("hf_")) {
  throw new Error(`Provider ${provider} is closed-source and does not support HF tokens.`);
  }
@@ -614,35 +1240,30 @@ function makeRequestOptionsFromResolvedModel(resolvedModel, args, options) {
  }
  return "none";
  })();
- const url = endpointUrl ? chatCompletion2 ? endpointUrl + `/v1/chat/completions` : endpointUrl : providerConfig.makeUrl({
+ const modelId = endpointUrl ?? resolvedModel;
+ const url = providerHelper.makeUrl({
  authMethod,
- baseUrl: authMethod !== "provider-key" ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider) : providerConfig.makeBaseUrl(task),
- model: resolvedModel,
- chatCompletion: chatCompletion2,
+ model: modelId,
  task
  });
- const binary = "data" in args && !!args.data;
- const headers = providerConfig.makeHeaders({
- accessToken,
- authMethod
- });
+ const headers = providerHelper.prepareHeaders(
+ {
+ accessToken,
+ authMethod
+ },
+ "data" in args && !!args.data
+ );
  if (billTo) {
  headers[HF_HEADER_X_BILL_TO] = billTo;
  }
- if (!binary) {
- headers["Content-Type"] = "application/json";
- }
  const ownUserAgent = `${name}/${version}`;
  const userAgent = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : void 0].filter((x) => x !== void 0).join(" ");
  headers["User-Agent"] = userAgent;
- const body = binary ? args.data : JSON.stringify(
- providerConfig.makeBody({
- args: remainingArgs,
- model: resolvedModel,
- task,
- chatCompletion: chatCompletion2
- })
- );
+ const body = providerHelper.makeBody({
+ args: remainingArgs,
+ model: resolvedModel,
+ task
+ });
  let credentials;
  if (typeof includeCredentials === "string") {
  credentials = includeCredentials;
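Net effect of this hunk: makeRequestOptions no longer assembles URLs, headers, and bodies inline, and the chatCompletion flag disappears (conversational becomes a task in its own right). A hedged sketch of the delegation, with placeholder values (not part of the diff):

// Illustrative sketch, not part of the package diff.
const providerHelper = getProviderHelper("sambanova", "conversational");
const url = providerHelper.makeUrl({
  authMethod: "provider-key",
  model: "resolved-model-id", // hypothetical resolved provider model ID
  task: "conversational",
});
const headers = providerHelper.prepareHeaders(
  { accessToken: "...", authMethod: "provider-key" },
  false // second argument: whether the payload is binary
);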
@@ -904,30 +1525,6 @@ async function* streamingRequest(args, options) {
  yield* innerStreamingRequest(args, options);
  }

- // src/utils/pick.ts
- function pick(o, props) {
- return Object.assign(
- {},
- ...props.map((prop) => {
- if (o[prop] !== void 0) {
- return { [prop]: o[prop] };
- }
- })
- );
- }
-
- // src/utils/typedInclude.ts
- function typedInclude(arr, v) {
- return arr.includes(v);
- }
-
- // src/utils/omit.ts
- function omit(o, props) {
- const propsArr = Array.isArray(props) ? props : [props];
- const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
- return pick(o, letsKeep);
- }
-
  // src/tasks/audio/utils.ts
  function preparePayload(args) {
  return "data" in args ? args : {
@@ -938,16 +1535,24 @@ function preparePayload(args) {

  // src/tasks/audio/audioClassification.ts
  async function audioClassification(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
  const payload = preparePayload(args);
  const { data: res } = await innerRequest(payload, {
  ...options,
  task: "audio-classification"
  });
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
- }
- return res;
+ return providerHelper.getResponse(res);
+ }
+
+ // src/tasks/audio/audioToAudio.ts
+ async function audioToAudio(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
+ const payload = preparePayload(args);
+ const { data: res } = await innerRequest(payload, {
+ ...options,
+ task: "audio-to-audio"
+ });
+ return providerHelper.getResponse(res);
  }

  // src/utils/base64FromBytes.ts
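Each task wrapper now follows the same two-step shape: resolve the helper, then let its getResponse() validate and normalize the raw result that the deleted isValidOutput checks used to handle inline. A hedged usage sketch (model ID hypothetical, not part of the diff):

// Illustrative sketch, not part of the package diff.
const labels = await audioClassification({
  accessToken: "hf_...",
  model: "superb/hubert-large-superb-er", // hypothetical model ID
  data: audioBlob, // a Blob you already hold
});
// -> Array<{ label: string, score: number }>, validated by the hf-inference helper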
@@ -965,6 +1570,7 @@ function base64FromBytes(arr) {

  // src/tasks/audio/automaticSpeechRecognition.ts
  async function automaticSpeechRecognition(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
  const payload = await buildPayload(args);
  const { data: res } = await innerRequest(payload, {
  ...options,
@@ -974,9 +1580,8 @@ async function automaticSpeechRecognition(args, options) {
  if (!isValidOutput) {
  throw new InferenceOutputError("Expected {text: string}");
  }
- return res;
+ return providerHelper.getResponse(res);
  }
- var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
  async function buildPayload(args) {
  if (args.provider === "fal-ai") {
  const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : void 0;
@@ -995,63 +1600,23 @@ async function buildPayload(args) {
  }
  const base64audio = base64FromBytes(new Uint8Array(await blob.arrayBuffer()));
  return {
- ..."data" in args ? omit(args, "data") : omit(args, "inputs"),
- audio_url: `data:${contentType};base64,${base64audio}`
- };
- } else {
- return preparePayload(args);
- }
- }
-
- // src/tasks/audio/textToSpeech.ts
- async function textToSpeech(args, options) {
- const payload = args.provider === "replicate" ? {
- ...omit(args, ["inputs", "parameters"]),
- ...args.parameters,
- text: args.inputs
- } : args;
- const { data: res } = await innerRequest(payload, {
- ...options,
- task: "text-to-speech"
- });
- if (res instanceof Blob) {
- return res;
- }
- if (res && typeof res === "object") {
- if ("output" in res) {
- if (typeof res.output === "string") {
- const urlResponse = await fetch(res.output);
- const blob = await urlResponse.blob();
- return blob;
- } else if (Array.isArray(res.output)) {
- const urlResponse = await fetch(res.output[0]);
- const blob = await urlResponse.blob();
- return blob;
- }
- }
+ ..."data" in args ? omit(args, "data") : omit(args, "inputs"),
+ audio_url: `data:${contentType};base64,${base64audio}`
+ };
+ } else {
+ return preparePayload(args);
  }
- throw new InferenceOutputError("Expected Blob or object with output");
  }

- // src/tasks/audio/audioToAudio.ts
- async function audioToAudio(args, options) {
- const payload = preparePayload(args);
- const { data: res } = await innerRequest(payload, {
+ // src/tasks/audio/textToSpeech.ts
+ async function textToSpeech(args, options) {
+ const provider = args.provider ?? "hf-inference";
+ const providerHelper = getProviderHelper(provider, "text-to-speech");
+ const { data: res } = await innerRequest(args, {
  ...options,
- task: "audio-to-audio"
+ task: "text-to-speech"
  });
- return validateOutput(res);
- }
- function validateOutput(output) {
- if (!Array.isArray(output)) {
- throw new InferenceOutputError("Expected Array");
- }
- if (!output.every((elem) => {
- return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
- })) {
- throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
- }
- return output;
+ return providerHelper.getResponse(res);
  }

  // src/tasks/cv/utils.ts
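textToSpeech loses its inline Replicate special-casing; judging from the provider registry above, that payload mapping presumably now lives in ReplicateTextToSpeechTask. A hedged sketch of the unchanged call shape (model ID hypothetical, not part of the diff):

// Illustrative sketch, not part of the package diff.
const audio = await textToSpeech({
  provider: "replicate",
  accessToken: "...",
  model: "hexgrad/Kokoro-82M", // hypothetical model ID
  inputs: "Hello world",
});
// -> Blob; the helper's getResponse() resolves provider-specific shapes to audio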
@@ -1061,183 +1626,95 @@ function preparePayload2(args) {
  }

  // src/tasks/cv/imageClassification.ts
  async function imageClassification(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
  const payload = preparePayload2(args);
  const { data: res } = await innerRequest(payload, {
  ...options,
  task: "image-classification"
  });
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
- }
- return res;
+ return providerHelper.getResponse(res);
  }

  // src/tasks/cv/imageSegmentation.ts
  async function imageSegmentation(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
  const payload = preparePayload2(args);
  const { data: res } = await innerRequest(payload, {
  ...options,
  task: "image-segmentation"
  });
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
+ return providerHelper.getResponse(res);
+ }
+
+ // src/tasks/cv/imageToImage.ts
+ async function imageToImage(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
+ let reqArgs;
+ if (!args.parameters) {
+ reqArgs = {
+ accessToken: args.accessToken,
+ model: args.model,
+ data: args.inputs
+ };
+ } else {
+ reqArgs = {
+ ...args,
+ inputs: base64FromBytes(
+ new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await args.inputs.arrayBuffer())
+ )
+ };
  }
- return res;
+ const { data: res } = await innerRequest(reqArgs, {
+ ...options,
+ task: "image-to-image"
+ });
+ return providerHelper.getResponse(res);
  }

  // src/tasks/cv/imageToText.ts
  async function imageToText(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
  const payload = preparePayload2(args);
  const { data: res } = await innerRequest(payload, {
  ...options,
  task: "image-to-text"
  });
- if (typeof res?.[0]?.generated_text !== "string") {
- throw new InferenceOutputError("Expected {generated_text: string}");
- }
- return res?.[0];
+ return providerHelper.getResponse(res[0]);
  }

  // src/tasks/cv/objectDetection.ts
  async function objectDetection(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
  const payload = preparePayload2(args);
  const { data: res } = await innerRequest(payload, {
  ...options,
  task: "object-detection"
  });
- const isValidOutput = Array.isArray(res) && res.every(
- (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
- );
- if (!isValidOutput) {
- throw new InferenceOutputError(
- "Expected Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>"
- );
- }
- return res;
+ return providerHelper.getResponse(res);
  }

  // src/tasks/cv/textToImage.ts
- function getResponseFormatArg(provider) {
- switch (provider) {
- case "fal-ai":
- return { sync_mode: true };
- case "nebius":
- return { response_format: "b64_json" };
- case "replicate":
- return void 0;
- case "together":
- return { response_format: "base64" };
- default:
- return void 0;
- }
- }
  async function textToImage(args, options) {
- const payload = !args.provider || args.provider === "hf-inference" || args.provider === "sambanova" ? args : {
- ...omit(args, ["inputs", "parameters"]),
- ...args.parameters,
- ...getResponseFormatArg(args.provider),
- prompt: args.inputs
- };
- const { data: res } = await innerRequest(payload, {
+ const provider = args.provider ?? "hf-inference";
+ const providerHelper = getProviderHelper(provider, "text-to-image");
+ const { data: res } = await innerRequest(args, {
  ...options,
  task: "text-to-image"
  });
- if (res && typeof res === "object") {
- if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") {
- return await pollBflResponse(res.polling_url, options?.outputType);
- }
- if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
- if (options?.outputType === "url") {
- return res.images[0].url;
- } else {
- const image = await fetch(res.images[0].url);
- return await image.blob();
- }
- }
- if (args.provider === "hyperbolic" && "images" in res && Array.isArray(res.images) && res.images[0] && typeof res.images[0].image === "string") {
- if (options?.outputType === "url") {
- return `data:image/jpeg;base64,${res.images[0].image}`;
- }
- const base64Response = await fetch(`data:image/jpeg;base64,${res.images[0].image}`);
- return await base64Response.blob();
- }
- if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
- const base64Data = res.data[0].b64_json;
- if (options?.outputType === "url") {
- return `data:image/jpeg;base64,${base64Data}`;
- }
- const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
- return await base64Response.blob();
- }
- if ("output" in res && Array.isArray(res.output)) {
- if (options?.outputType === "url") {
- return res.output[0];
- }
- const urlResponse = await fetch(res.output[0]);
- const blob = await urlResponse.blob();
- return blob;
- }
- }
- const isValidOutput = res && res instanceof Blob;
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected Blob");
- }
- if (options?.outputType === "url") {
- const b64 = await res.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
- return `data:image/jpeg;base64,${b64}`;
- }
- return res;
- }
- async function pollBflResponse(url, outputType) {
- const urlObj = new URL(url);
- for (let step = 0; step < 5; step++) {
- await delay(1e3);
- console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
- urlObj.searchParams.set("attempt", step.toString(10));
- const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
- if (!resp.ok) {
- throw new InferenceOutputError("Failed to fetch result from black forest labs API");
- }
- const payload = await resp.json();
- if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
- if (outputType === "url") {
- return payload.result.sample;
- }
- const image = await fetch(payload.result.sample);
- return await image.blob();
- }
- }
- throw new InferenceOutputError("Failed to fetch result from black forest labs API");
+ const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-image" });
+ return providerHelper.getResponse(res, url, info.headers, options?.outputType);
  }

- // src/tasks/cv/imageToImage.ts
- async function imageToImage(args, options) {
- let reqArgs;
- if (!args.parameters) {
- reqArgs = {
- accessToken: args.accessToken,
- model: args.model,
- data: args.inputs
- };
- } else {
- reqArgs = {
- ...args,
- inputs: base64FromBytes(
- new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await args.inputs.arrayBuffer())
- )
- };
- }
- const { data: res } = await innerRequest(reqArgs, {
+ // src/tasks/cv/textToVideo.ts
+ async function textToVideo(args, options) {
+ const provider = args.provider ?? "hf-inference";
+ const providerHelper = getProviderHelper(provider, "text-to-video");
+ const { data: response } = await innerRequest(args, {
  ...options,
- task: "image-to-image"
+ task: "text-to-video"
  });
- const isValidOutput = res && res instanceof Blob;
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected Blob");
- }
- return res;
+ const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-video" });
+ return providerHelper.getResponse(response, url, info.headers);
  }

  // src/tasks/cv/zeroShotImageClassification.ts
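For textToImage and textToVideo, getResponse() now also receives the request URL and headers, which is the hook the polling flows deleted above (pollBflResponse here, and presumably pollFalResponse once it moves into the fal-ai helper) need. A hedged usage sketch (model ID hypothetical, not part of the diff):

// Illustrative sketch, not part of the package diff.
const video = await textToVideo({
  provider: "fal-ai",
  accessToken: "...",
  model: "genmo/mochi-1-preview", // hypothetical model ID
  inputs: "a red panda drinking tea",
});
// -> Blob, after the fal-ai helper finishes its queue polling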
@@ -1263,226 +1740,112 @@ async function preparePayload3(args) {
  }
  }
  async function zeroShotImageClassification(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
  const payload = await preparePayload3(args);
  const { data: res } = await innerRequest(payload, {
  ...options,
  task: "zero-shot-image-classification"
  });
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
- }
- return res;
+ return providerHelper.getResponse(res);
  }

- // src/tasks/cv/textToVideo.ts
- var SUPPORTED_PROVIDERS = ["fal-ai", "novita", "replicate"];
- async function textToVideo(args, options) {
- if (!args.provider || !typedInclude(SUPPORTED_PROVIDERS, args.provider)) {
- throw new Error(
- `textToVideo inference is only supported for the following providers: ${SUPPORTED_PROVIDERS.join(", ")}`
- );
- }
- const payload = args.provider === "fal-ai" || args.provider === "replicate" || args.provider === "novita" ? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs } : args;
- const { data, requestContext } = await innerRequest(payload, {
+ // src/tasks/nlp/chatCompletion.ts
+ async function chatCompletion(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
+ const { data: response } = await innerRequest(args, {
  ...options,
- task: "text-to-video"
+ task: "conversational"
+ });
+ return providerHelper.getResponse(response);
+ }
+
+ // src/tasks/nlp/chatCompletionStream.ts
+ async function* chatCompletionStream(args, options) {
+ yield* innerStreamingRequest(args, {
+ ...options,
+ task: "conversational"
  });
- if (args.provider === "fal-ai") {
- return await pollFalResponse(
- data,
- requestContext.url,
- requestContext.info.headers
- );
- } else if (args.provider === "novita") {
- const isValidOutput = typeof data === "object" && !!data && "video" in data && typeof data.video === "object" && !!data.video && "video_url" in data.video && typeof data.video.video_url === "string" && isUrl(data.video.video_url);
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected { video: { video_url: string } }");
- }
- const urlResponse = await fetch(data.video.video_url);
- return await urlResponse.blob();
- } else {
- const isValidOutput = typeof data === "object" && !!data && "output" in data && typeof data.output === "string" && isUrl(data.output);
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected { output: string }");
- }
- const urlResponse = await fetch(data.output);
- return await urlResponse.blob();
- }
  }

  // src/tasks/nlp/featureExtraction.ts
  async function featureExtraction(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "feature-extraction");
  const { data: res } = await innerRequest(args, {
  ...options,
  task: "feature-extraction"
  });
- let isValidOutput = true;
- const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
- if (curDepth > maxDepth)
- return false;
- if (arr.every((x) => Array.isArray(x))) {
- return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
- } else {
- return arr.every((x) => typeof x === "number");
- }
- };
- isValidOutput = Array.isArray(res) && isNumArrayRec(res, 3, 0);
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
- }
- return res;
+ return providerHelper.getResponse(res);
  }

  // src/tasks/nlp/fillMask.ts
  async function fillMask(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "fill-mask");
  const { data: res } = await innerRequest(args, {
  ...options,
  task: "fill-mask"
  });
- const isValidOutput = Array.isArray(res) && res.every(
- (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
- );
- if (!isValidOutput) {
- throw new InferenceOutputError(
- "Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
- );
- }
- return res;
+ return providerHelper.getResponse(res);
  }

  // src/tasks/nlp/questionAnswering.ts
  async function questionAnswering(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering");
  const { data: res } = await innerRequest(args, {
  ...options,
  task: "question-answering"
  });
- const isValidOutput = Array.isArray(res) ? res.every(
- (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
- ) : typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
- }
- return Array.isArray(res) ? res[0] : res;
+ return providerHelper.getResponse(res);
  }

  // src/tasks/nlp/sentenceSimilarity.ts
  async function sentenceSimilarity(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "sentence-similarity");
  const { data: res } = await innerRequest(args, {
  ...options,
  task: "sentence-similarity"
  });
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected number[]");
- }
- return res;
+ return providerHelper.getResponse(res);
  }

  // src/tasks/nlp/summarization.ts
  async function summarization(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization");
  const { data: res } = await innerRequest(args, {
  ...options,
  task: "summarization"
  });
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string");
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected Array<{summary_text: string}>");
- }
- return res?.[0];
+ return providerHelper.getResponse(res);
  }

  // src/tasks/nlp/tableQuestionAnswering.ts
  async function tableQuestionAnswering(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
  const { data: res } = await innerRequest(args, {
  ...options,
  task: "table-question-answering"
  });
- const isValidOutput = Array.isArray(res) ? res.every((elem) => validate(elem)) : validate(res);
- if (!isValidOutput) {
- throw new InferenceOutputError(
- "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
- );
- }
- return Array.isArray(res) ? res[0] : res;
- }
- function validate(elem) {
- return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
- (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
- );
+ return providerHelper.getResponse(res);
  }

  // src/tasks/nlp/textClassification.ts
  async function textClassification(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
  const { data: res } = await innerRequest(args, {
  ...options,
  task: "text-classification"
  });
- const output = res?.[0];
- const isValidOutput = Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number");
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
- }
- return output;
- }
-
- // src/utils/toArray.ts
- function toArray(obj) {
- if (Array.isArray(obj)) {
- return obj;
- }
- return [obj];
+ return providerHelper.getResponse(res);
  }

  // src/tasks/nlp/textGeneration.ts
  async function textGeneration(args, options) {
- if (args.provider === "together") {
- args.prompt = args.inputs;
- const { data: raw } = await innerRequest(args, {
- ...options,
- task: "text-generation"
- });
- const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected ChatCompletionOutput");
- }
- const completion = raw.choices[0];
- return {
- generated_text: completion.text
- };
- } else if (args.provider === "hyperbolic") {
- const payload = {
- messages: [{ content: args.inputs, role: "user" }],
- ...args.parameters ? {
- max_tokens: args.parameters.max_new_tokens,
- ...omit(args.parameters, "max_new_tokens")
- } : void 0,
- ...omit(args, ["inputs", "parameters"])
- };
- const raw = (await innerRequest(payload, {
- ...options,
- task: "text-generation"
- })).data;
- const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected ChatCompletionOutput");
- }
- const completion = raw.choices[0];
- return {
- generated_text: completion.message.content
- };
- } else {
- const { data: res } = await innerRequest(args, {
- ...options,
- task: "text-generation"
- });
- const output = toArray(res);
- const isValidOutput = Array.isArray(output) && output.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected Array<{generated_text: string}>");
- }
- return output?.[0];
- }
+ const provider = args.provider ?? "hf-inference";
+ const providerHelper = getProviderHelper(provider, "text-generation");
+ const { data: response } = await innerRequest(args, {
+ ...options,
+ task: "text-generation"
+ });
+ return providerHelper.getResponse(response);
  }

  // src/tasks/nlp/textGenerationStream.ts
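The three-way provider branch deleted from textGeneration is now expressed as helper overrides: TogetherTextGenerationTask (shown earlier) adds prompt in preparePayload() and maps choices[0].text back to { generated_text } in getResponse(), and HyperbolicTextGenerationTask presumably does the same for its chat-shaped API. The call site becomes provider-agnostic (sketch; model ID hypothetical, not part of the diff):

// Illustrative sketch, not part of the package diff.
const out = await textGeneration({
  provider: "together",
  accessToken: "...",
  model: "meta-llama/Llama-3.3-70B-Instruct-Turbo", // hypothetical model ID
  inputs: "Write a haiku about refactoring",
});
console.log(out.generated_text);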
@@ -1495,77 +1858,37 @@ async function* textGenerationStream(args, options) {
  }

  // src/tasks/nlp/tokenClassification.ts
  async function tokenClassification(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "token-classification");
  const { data: res } = await innerRequest(args, {
  ...options,
  task: "token-classification"
  });
- const output = toArray(res);
- const isValidOutput = Array.isArray(output) && output.every(
- (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
- );
- if (!isValidOutput) {
- throw new InferenceOutputError(
- "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
- );
- }
- return output;
+ return providerHelper.getResponse(res);
  }

  // src/tasks/nlp/translation.ts
  async function translation(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation");
  const { data: res } = await innerRequest(args, {
  ...options,
  task: "translation"
  });
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string");
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected type Array<{translation_text: string}>");
- }
- return res?.length === 1 ? res?.[0] : res;
+ return providerHelper.getResponse(res);
  }

  // src/tasks/nlp/zeroShotClassification.ts
  async function zeroShotClassification(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification");
  const { data: res } = await innerRequest(args, {
  ...options,
  task: "zero-shot-classification"
  });
- const output = toArray(res);
- const isValidOutput = Array.isArray(output) && output.every(
- (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
- );
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
- }
- return output;
- }
-
- // src/tasks/nlp/chatCompletion.ts
- async function chatCompletion(args, options) {
- const { data: res } = await innerRequest(args, {
- ...options,
- task: "text-generation",
- chatCompletion: true
- });
- const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
- (res.system_fingerprint === void 0 || res.system_fingerprint === null || typeof res.system_fingerprint === "string") && typeof res?.usage === "object";
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected ChatCompletionOutput");
- }
- return res;
- }
-
- // src/tasks/nlp/chatCompletionStream.ts
- async function* chatCompletionStream(args, options) {
- yield* innerStreamingRequest(args, {
- ...options,
- task: "text-generation",
- chatCompletion: true
- });
+ return providerHelper.getResponse(res);
  }

  // src/tasks/multimodal/documentQuestionAnswering.ts
  async function documentQuestionAnswering(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "document-question-answering");
  const reqArgs = {
  ...args,
  inputs: {
@@ -1581,18 +1904,12 @@ async function documentQuestionAnswering(args, options) {
  task: "document-question-answering"
  }
  );
- const output = toArray(res);
- const isValidOutput = Array.isArray(output) && output.every(
- (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
- );
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>");
- }
- return output[0];
+ return providerHelper.getResponse(res);
  }

  // src/tasks/multimodal/visualQuestionAnswering.ts
  async function visualQuestionAnswering(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "visual-question-answering");
  const reqArgs = {
  ...args,
  inputs: {
@@ -1605,39 +1922,27 @@ async function visualQuestionAnswering(args, options) {
  ...options,
  task: "visual-question-answering"
  });
- const isValidOutput = Array.isArray(res) && res.every(
- (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
- );
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
- }
- return res[0];
+ return providerHelper.getResponse(res);
  }

- // src/tasks/tabular/tabularRegression.ts
- async function tabularRegression(args, options) {
+ // src/tasks/tabular/tabularClassification.ts
+ async function tabularClassification(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification");
  const { data: res } = await innerRequest(args, {
  ...options,
- task: "tabular-regression"
+ task: "tabular-classification"
  });
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected number[]");
- }
- return res;
+ return providerHelper.getResponse(res);
  }

- // src/tasks/tabular/tabularClassification.ts
- async function tabularClassification(args, options) {
+ // src/tasks/tabular/tabularRegression.ts
+ async function tabularRegression(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression");
  const { data: res } = await innerRequest(args, {
  ...options,
- task: "tabular-classification"
+ task: "tabular-regression"
  });
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected number[]");
- }
- return res;
+ return providerHelper.getResponse(res);
  }

  // src/InferenceClient.ts
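chatCompletion and chatCompletionStream now route as task "conversational" instead of "text-generation" plus a chatCompletion flag, with output validation presumably moving into the conversational helpers' getResponse(). A hedged end-to-end sketch (model ID hypothetical, not part of the diff):

// Illustrative sketch, not part of the package diff.
const res = await chatCompletion({
  provider: "sambanova",
  accessToken: "hf_...",
  model: "meta-llama/Llama-3.1-8B-Instruct", // hypothetical model ID
  messages: [{ role: "user", content: "Hello!" }],
});
console.log(res.choices[0].message.content);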
@@ -1706,11 +2011,11 @@ __export(snippets_exports, {
  });

  // src/snippets/getInferenceSnippets.ts
+ import { Template } from "@huggingface/jinja";
  import {
- inferenceSnippetLanguages,
- getModelInputSnippet
+ getModelInputSnippet,
+ inferenceSnippetLanguages
  } from "@huggingface/tasks";
- import { Template } from "@huggingface/jinja";

  // src/snippets/templates.exported.ts
  var templates = {
@@ -1753,7 +2058,7 @@ const image = await client.textToVideo({
  },
  "openai": {
  "conversational": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\nconst chatCompletion = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n});\n\nconsole.log(chatCompletion.choices[0].message);',
- "conversationalStream": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\nlet out = "";\n\nconst stream = await client.chat.completions.create({\n provider: "{{ provider }}",\n model: "{{ model.id }}",\n{{ inputs.asTsString }}\n});\n\nfor await (const chunk of stream) {\n if (chunk.choices && chunk.choices.length > 0) {\n const newContent = chunk.choices[0].delta.content;\n out += newContent;\n console.log(newContent);\n } \n}'
+ "conversationalStream": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\nconst stream = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n stream: true,\n});\n\nfor await (const chunk of stream) {\n process.stdout.write(chunk.choices[0]?.delta?.content || "");\n}'
  }
  },
  "python": {
@@ -1885,15 +2190,23 @@ var HF_JS_METHODS = {
  };
  var snippetGenerator = (templateName, inputPreparationFn) => {
  return (model, accessToken, provider, providerModelId, opts) => {
+ let task = model.pipeline_tag;
  if (model.pipeline_tag && ["text-generation", "image-text-to-text"].includes(model.pipeline_tag) && model.tags.includes("conversational")) {
  templateName = opts?.streaming ? "conversationalStream" : "conversational";
  inputPreparationFn = prepareConversationalInput;
+ task = "conversational";
  }
  const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) };
  const request2 = makeRequestOptionsFromResolvedModel(
  providerModelId ?? model.id,
- { accessToken, provider, ...inputs },
- { chatCompletion: templateName.includes("conversational"), task: model.pipeline_tag }
+ {
+ accessToken,
+ provider,
+ ...inputs
+ },
+ {
+ task
+ }
  );
  let providerInputs = inputs;
  const bodyAsObj = request2.info.body;
@@ -1980,7 +2293,7 @@ var prepareConversationalInput = (model, opts) => {
  return {
  messages: opts?.messages ?? getModelInputSnippet(model),
  ...opts?.temperature ? { temperature: opts?.temperature } : void 0,
- max_tokens: opts?.max_tokens ?? 500,
+ max_tokens: opts?.max_tokens ?? 512,
  ...opts?.top_p ? { top_p: opts?.top_p } : void 0
  };
  };