@huggingface/inference 3.7.0 → 3.8.0

This diff shows the contents of publicly available package versions as published to their respective registries. It is provided for informational purposes only and reflects the changes between the two package versions.
Files changed (141)
  1. package/dist/index.cjs +1369 -941
  2. package/dist/index.js +1371 -943
  3. package/dist/src/lib/getInferenceProviderMapping.d.ts +21 -0
  4. package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -0
  5. package/dist/src/lib/getProviderHelper.d.ts +37 -0
  6. package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
  7. package/dist/src/lib/makeRequestOptions.d.ts +5 -5
  8. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  9. package/dist/src/providers/black-forest-labs.d.ts +14 -18
  10. package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
  11. package/dist/src/providers/cerebras.d.ts +4 -2
  12. package/dist/src/providers/cerebras.d.ts.map +1 -1
  13. package/dist/src/providers/cohere.d.ts +5 -2
  14. package/dist/src/providers/cohere.d.ts.map +1 -1
  15. package/dist/src/providers/consts.d.ts +2 -3
  16. package/dist/src/providers/consts.d.ts.map +1 -1
  17. package/dist/src/providers/fal-ai.d.ts +50 -3
  18. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  19. package/dist/src/providers/fireworks-ai.d.ts +5 -2
  20. package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
  21. package/dist/src/providers/hf-inference.d.ts +126 -2
  22. package/dist/src/providers/hf-inference.d.ts.map +1 -1
  23. package/dist/src/providers/hyperbolic.d.ts +31 -2
  24. package/dist/src/providers/hyperbolic.d.ts.map +1 -1
  25. package/dist/src/providers/nebius.d.ts +20 -18
  26. package/dist/src/providers/nebius.d.ts.map +1 -1
  27. package/dist/src/providers/novita.d.ts +21 -18
  28. package/dist/src/providers/novita.d.ts.map +1 -1
  29. package/dist/src/providers/openai.d.ts +4 -2
  30. package/dist/src/providers/openai.d.ts.map +1 -1
  31. package/dist/src/providers/providerHelper.d.ts +182 -0
  32. package/dist/src/providers/providerHelper.d.ts.map +1 -0
  33. package/dist/src/providers/replicate.d.ts +23 -19
  34. package/dist/src/providers/replicate.d.ts.map +1 -1
  35. package/dist/src/providers/sambanova.d.ts +4 -2
  36. package/dist/src/providers/sambanova.d.ts.map +1 -1
  37. package/dist/src/providers/together.d.ts +32 -2
  38. package/dist/src/providers/together.d.ts.map +1 -1
  39. package/dist/src/snippets/getInferenceSnippets.d.ts +2 -1
  40. package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
  41. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
  42. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  43. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  44. package/dist/src/tasks/audio/utils.d.ts +2 -1
  45. package/dist/src/tasks/audio/utils.d.ts.map +1 -1
  46. package/dist/src/tasks/custom/request.d.ts +0 -2
  47. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  48. package/dist/src/tasks/custom/streamingRequest.d.ts +0 -2
  49. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  50. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
  51. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
  52. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
  53. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
  54. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
  55. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  56. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
  57. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
  58. package/dist/src/tasks/index.d.ts +6 -6
  59. package/dist/src/tasks/index.d.ts.map +1 -1
  60. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  61. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
  62. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  63. package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
  64. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  65. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
  66. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  67. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  68. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
  69. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  70. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
  71. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  72. package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
  73. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  74. package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
  75. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  76. package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
  77. package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
  78. package/dist/src/types.d.ts +5 -13
  79. package/dist/src/types.d.ts.map +1 -1
  80. package/dist/src/utils/request.d.ts +3 -2
  81. package/dist/src/utils/request.d.ts.map +1 -1
  82. package/package.json +3 -3
  83. package/src/lib/getInferenceProviderMapping.ts +96 -0
  84. package/src/lib/getProviderHelper.ts +270 -0
  85. package/src/lib/makeRequestOptions.ts +78 -97
  86. package/src/providers/black-forest-labs.ts +73 -22
  87. package/src/providers/cerebras.ts +6 -27
  88. package/src/providers/cohere.ts +9 -28
  89. package/src/providers/consts.ts +5 -2
  90. package/src/providers/fal-ai.ts +224 -77
  91. package/src/providers/fireworks-ai.ts +8 -29
  92. package/src/providers/hf-inference.ts +557 -34
  93. package/src/providers/hyperbolic.ts +107 -29
  94. package/src/providers/nebius.ts +65 -29
  95. package/src/providers/novita.ts +68 -32
  96. package/src/providers/openai.ts +6 -32
  97. package/src/providers/providerHelper.ts +354 -0
  98. package/src/providers/replicate.ts +124 -34
  99. package/src/providers/sambanova.ts +5 -30
  100. package/src/providers/together.ts +92 -28
  101. package/src/snippets/getInferenceSnippets.ts +39 -14
  102. package/src/snippets/templates.exported.ts +25 -25
  103. package/src/tasks/audio/audioClassification.ts +5 -8
  104. package/src/tasks/audio/audioToAudio.ts +4 -27
  105. package/src/tasks/audio/automaticSpeechRecognition.ts +5 -4
  106. package/src/tasks/audio/textToSpeech.ts +5 -29
  107. package/src/tasks/audio/utils.ts +2 -1
  108. package/src/tasks/custom/request.ts +3 -3
  109. package/src/tasks/custom/streamingRequest.ts +4 -3
  110. package/src/tasks/cv/imageClassification.ts +4 -8
  111. package/src/tasks/cv/imageSegmentation.ts +4 -9
  112. package/src/tasks/cv/imageToImage.ts +4 -7
  113. package/src/tasks/cv/imageToText.ts +4 -7
  114. package/src/tasks/cv/objectDetection.ts +4 -19
  115. package/src/tasks/cv/textToImage.ts +9 -137
  116. package/src/tasks/cv/textToVideo.ts +17 -64
  117. package/src/tasks/cv/zeroShotImageClassification.ts +4 -8
  118. package/src/tasks/index.ts +6 -6
  119. package/src/tasks/multimodal/documentQuestionAnswering.ts +4 -19
  120. package/src/tasks/multimodal/visualQuestionAnswering.ts +4 -12
  121. package/src/tasks/nlp/chatCompletion.ts +5 -20
  122. package/src/tasks/nlp/chatCompletionStream.ts +4 -3
  123. package/src/tasks/nlp/featureExtraction.ts +4 -19
  124. package/src/tasks/nlp/fillMask.ts +4 -17
  125. package/src/tasks/nlp/questionAnswering.ts +11 -26
  126. package/src/tasks/nlp/sentenceSimilarity.ts +4 -8
  127. package/src/tasks/nlp/summarization.ts +4 -7
  128. package/src/tasks/nlp/tableQuestionAnswering.ts +10 -30
  129. package/src/tasks/nlp/textClassification.ts +4 -9
  130. package/src/tasks/nlp/textGeneration.ts +11 -79
  131. package/src/tasks/nlp/textGenerationStream.ts +3 -1
  132. package/src/tasks/nlp/tokenClassification.ts +11 -23
  133. package/src/tasks/nlp/translation.ts +4 -7
  134. package/src/tasks/nlp/zeroShotClassification.ts +11 -21
  135. package/src/tasks/tabular/tabularClassification.ts +4 -7
  136. package/src/tasks/tabular/tabularRegression.ts +4 -7
  137. package/src/types.ts +5 -14
  138. package/src/utils/request.ts +7 -4
  139. package/dist/src/lib/getProviderModelId.d.ts +0 -10
  140. package/dist/src/lib/getProviderModelId.d.ts.map +0 -1
  141. package/src/lib/getProviderModelId.ts +0 -74
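
The bulk of this diff is a refactor of the provider layer: the per-provider makeBaseUrl/makeBody/makeHeaders/makeUrl config objects are replaced by TaskProviderHelper classes (src/providers/providerHelper.ts) that are registered per provider and task in src/lib/getProviderHelper.ts. A minimal sketch of the new pattern, reconstructed from the compiled bundle below; the provider name is made up and the TypeScript signatures are approximations, not the published API:

import { BaseConversationalTask } from "./providerHelper"; // internal module added in this release

// Illustrative only: mirrors CerebrasConversationalTask and the PROVIDERS map
// visible in the compiled index.cjs below ("acme" is a hypothetical provider).
class AcmeConversationalTask extends BaseConversationalTask {
  constructor() {
    // provider id + API base URL; the route ("v1/chat/completions"), headers and
    // payload come from the BaseConversationalTask / TaskProviderHelper defaults
    super("acme", "https://api.acme.example");
  }
}

// getProviderHelper.ts resolves helpers from a provider -> task table:
const PROVIDERS = {
  acme: { conversational: new AcmeConversationalTask() },
};

Task entry points now resolve one of these helpers for the chosen provider and task and delegate URL, header and payload construction to it, which is why most files under src/tasks/ shrink in this release.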
package/dist/index.cjs CHANGED
@@ -98,91 +98,211 @@ __export(tasks_exports, {
98
98
  zeroShotImageClassification: () => zeroShotImageClassification
99
99
  });
100
100
 
101
+ // src/lib/InferenceOutputError.ts
102
+ var InferenceOutputError = class extends TypeError {
103
+ constructor(message) {
104
+ super(
105
+ `Invalid inference output: ${message}. Use the 'request' method with the same parameters to do a custom call with no type checking.`
106
+ );
107
+ this.name = "InferenceOutputError";
108
+ }
109
+ };
110
+
111
+ // src/utils/delay.ts
112
+ function delay(ms) {
113
+ return new Promise((resolve) => {
114
+ setTimeout(() => resolve(), ms);
115
+ });
116
+ }
117
+
118
+ // src/utils/pick.ts
119
+ function pick(o, props) {
120
+ return Object.assign(
121
+ {},
122
+ ...props.map((prop) => {
123
+ if (o[prop] !== void 0) {
124
+ return { [prop]: o[prop] };
125
+ }
126
+ })
127
+ );
128
+ }
129
+
130
+ // src/utils/typedInclude.ts
131
+ function typedInclude(arr, v) {
132
+ return arr.includes(v);
133
+ }
134
+
135
+ // src/utils/omit.ts
136
+ function omit(o, props) {
137
+ const propsArr = Array.isArray(props) ? props : [props];
138
+ const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
139
+ return pick(o, letsKeep);
140
+ }
141
+
101
142
  // src/config.ts
102
143
  var HF_HUB_URL = "https://huggingface.co";
103
144
  var HF_ROUTER_URL = "https://router.huggingface.co";
104
145
  var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
105
146
 
106
- // src/providers/black-forest-labs.ts
107
- var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
108
- var makeBaseUrl = () => {
109
- return BLACK_FOREST_LABS_AI_API_BASE_URL;
110
- };
111
- var makeBody = (params) => {
112
- return params.args;
147
+ // src/utils/toArray.ts
148
+ function toArray(obj) {
149
+ if (Array.isArray(obj)) {
150
+ return obj;
151
+ }
152
+ return [obj];
153
+ }
154
+
155
+ // src/providers/providerHelper.ts
156
+ var TaskProviderHelper = class {
157
+ constructor(provider, baseUrl, clientSideRoutingOnly = false) {
158
+ this.provider = provider;
159
+ this.baseUrl = baseUrl;
160
+ this.clientSideRoutingOnly = clientSideRoutingOnly;
161
+ }
162
+ /**
163
+ * Prepare the base URL for the request
164
+ */
165
+ makeBaseUrl(params) {
166
+ return params.authMethod !== "provider-key" ? `${HF_ROUTER_URL}/${this.provider}` : this.baseUrl;
167
+ }
168
+ /**
169
+ * Prepare the body for the request
170
+ */
171
+ makeBody(params) {
172
+ if ("data" in params.args && !!params.args.data) {
173
+ return params.args.data;
174
+ }
175
+ return JSON.stringify(this.preparePayload(params));
176
+ }
177
+ /**
178
+ * Prepare the URL for the request
179
+ */
180
+ makeUrl(params) {
181
+ const baseUrl = this.makeBaseUrl(params);
182
+ const route = this.makeRoute(params).replace(/^\/+/, "");
183
+ return `${baseUrl}/${route}`;
184
+ }
185
+ /**
186
+ * Prepare the headers for the request
187
+ */
188
+ prepareHeaders(params, isBinary) {
189
+ const headers = { Authorization: `Bearer ${params.accessToken}` };
190
+ if (!isBinary) {
191
+ headers["Content-Type"] = "application/json";
192
+ }
193
+ return headers;
194
+ }
113
195
  };
114
- var makeHeaders = (params) => {
115
- if (params.authMethod === "provider-key") {
116
- return { "X-Key": `${params.accessToken}` };
117
- } else {
118
- return { Authorization: `Bearer ${params.accessToken}` };
196
+ var BaseConversationalTask = class extends TaskProviderHelper {
197
+ constructor(provider, baseUrl, clientSideRoutingOnly = false) {
198
+ super(provider, baseUrl, clientSideRoutingOnly);
199
+ }
200
+ makeRoute() {
201
+ return "v1/chat/completions";
202
+ }
203
+ preparePayload(params) {
204
+ return {
205
+ ...params.args,
206
+ model: params.model
207
+ };
208
+ }
209
+ async getResponse(response) {
210
+ if (typeof response === "object" && Array.isArray(response?.choices) && typeof response?.created === "number" && typeof response?.id === "string" && typeof response?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
211
+ (response.system_fingerprint === void 0 || response.system_fingerprint === null || typeof response.system_fingerprint === "string") && typeof response?.usage === "object") {
212
+ return response;
213
+ }
214
+ throw new InferenceOutputError("Expected ChatCompletionOutput");
119
215
  }
120
216
  };
121
- var makeUrl = (params) => {
122
- return `${params.baseUrl}/v1/${params.model}`;
217
+ var BaseTextGenerationTask = class extends TaskProviderHelper {
218
+ constructor(provider, baseUrl, clientSideRoutingOnly = false) {
219
+ super(provider, baseUrl, clientSideRoutingOnly);
220
+ }
221
+ preparePayload(params) {
222
+ return {
223
+ ...params.args,
224
+ model: params.model
225
+ };
226
+ }
227
+ makeRoute() {
228
+ return "v1/completions";
229
+ }
230
+ async getResponse(response) {
231
+ const res = toArray(response);
232
+ if (Array.isArray(res) && res.length > 0 && res.every(
233
+ (x) => typeof x === "object" && !!x && "generated_text" in x && typeof x.generated_text === "string"
234
+ )) {
235
+ return res[0];
236
+ }
237
+ throw new InferenceOutputError("Expected Array<{generated_text: string}>");
238
+ }
123
239
  };
124
- var BLACK_FOREST_LABS_CONFIG = {
125
- makeBaseUrl,
126
- makeBody,
127
- makeHeaders,
128
- makeUrl
240
+
241
+ // src/providers/black-forest-labs.ts
242
+ var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
243
+ var BlackForestLabsTextToImageTask = class extends TaskProviderHelper {
244
+ constructor() {
245
+ super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
246
+ }
247
+ preparePayload(params) {
248
+ return {
249
+ ...omit(params.args, ["inputs", "parameters"]),
250
+ ...params.args.parameters,
251
+ prompt: params.args.inputs
252
+ };
253
+ }
254
+ prepareHeaders(params, binary) {
255
+ const headers = {
256
+ Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`
257
+ };
258
+ if (!binary) {
259
+ headers["Content-Type"] = "application/json";
260
+ }
261
+ return headers;
262
+ }
263
+ makeRoute(params) {
264
+ if (!params) {
265
+ throw new Error("Params are required");
266
+ }
267
+ return `/v1/${params.model}`;
268
+ }
269
+ async getResponse(response, url, headers, outputType) {
270
+ const urlObj = new URL(response.polling_url);
271
+ for (let step = 0; step < 5; step++) {
272
+ await delay(1e3);
273
+ console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
274
+ urlObj.searchParams.set("attempt", step.toString(10));
275
+ const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
276
+ if (!resp.ok) {
277
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
278
+ }
279
+ const payload = await resp.json();
280
+ if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
281
+ if (outputType === "url") {
282
+ return payload.result.sample;
283
+ }
284
+ const image = await fetch(payload.result.sample);
285
+ return await image.blob();
286
+ }
287
+ }
288
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
289
+ }
129
290
  };
130
291
 
131
292
  // src/providers/cerebras.ts
132
- var CEREBRAS_API_BASE_URL = "https://api.cerebras.ai";
133
- var makeBaseUrl2 = () => {
134
- return CEREBRAS_API_BASE_URL;
135
- };
136
- var makeBody2 = (params) => {
137
- return {
138
- ...params.args,
139
- model: params.model
140
- };
141
- };
142
- var makeHeaders2 = (params) => {
143
- return { Authorization: `Bearer ${params.accessToken}` };
144
- };
145
- var makeUrl2 = (params) => {
146
- return `${params.baseUrl}/v1/chat/completions`;
147
- };
148
- var CEREBRAS_CONFIG = {
149
- makeBaseUrl: makeBaseUrl2,
150
- makeBody: makeBody2,
151
- makeHeaders: makeHeaders2,
152
- makeUrl: makeUrl2
293
+ var CerebrasConversationalTask = class extends BaseConversationalTask {
294
+ constructor() {
295
+ super("cerebras", "https://api.cerebras.ai");
296
+ }
153
297
  };
154
298
 
155
299
  // src/providers/cohere.ts
156
- var COHERE_API_BASE_URL = "https://api.cohere.com";
157
- var makeBaseUrl3 = () => {
158
- return COHERE_API_BASE_URL;
159
- };
160
- var makeBody3 = (params) => {
161
- return {
162
- ...params.args,
163
- model: params.model
164
- };
165
- };
166
- var makeHeaders3 = (params) => {
167
- return { Authorization: `Bearer ${params.accessToken}` };
168
- };
169
- var makeUrl3 = (params) => {
170
- return `${params.baseUrl}/compatibility/v1/chat/completions`;
171
- };
172
- var COHERE_CONFIG = {
173
- makeBaseUrl: makeBaseUrl3,
174
- makeBody: makeBody3,
175
- makeHeaders: makeHeaders3,
176
- makeUrl: makeUrl3
177
- };
178
-
179
- // src/lib/InferenceOutputError.ts
180
- var InferenceOutputError = class extends TypeError {
181
- constructor(message) {
182
- super(
183
- `Invalid inference output: ${message}. Use the 'request' method with the same parameters to do a custom call with no type checking.`
184
- );
185
- this.name = "InferenceOutputError";
300
+ var CohereConversationalTask = class extends BaseConversationalTask {
301
+ constructor() {
302
+ super("cohere", "https://api.cohere.com");
303
+ }
304
+ makeRoute() {
305
+ return "/compatibility/v1/chat/completions";
186
306
  }
187
307
  };
188
308
 
@@ -191,352 +311,902 @@ function isUrl(modelOrUrl) {
191
311
  return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
192
312
  }
193
313
 
194
- // src/utils/delay.ts
195
- function delay(ms) {
196
- return new Promise((resolve) => {
197
- setTimeout(() => resolve(), ms);
198
- });
199
- }
200
-
201
314
  // src/providers/fal-ai.ts
202
- var FAL_AI_API_BASE_URL = "https://fal.run";
203
- var FAL_AI_API_BASE_URL_QUEUE = "https://queue.fal.run";
204
- var makeBaseUrl4 = (task) => {
205
- return task === "text-to-video" ? FAL_AI_API_BASE_URL_QUEUE : FAL_AI_API_BASE_URL;
315
+ var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
316
+ var FalAITask = class extends TaskProviderHelper {
317
+ constructor(url) {
318
+ super("fal-ai", url || "https://fal.run");
319
+ }
320
+ preparePayload(params) {
321
+ return params.args;
322
+ }
323
+ makeRoute(params) {
324
+ return `/${params.model}`;
325
+ }
326
+ prepareHeaders(params, binary) {
327
+ const headers = {
328
+ Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}`
329
+ };
330
+ if (!binary) {
331
+ headers["Content-Type"] = "application/json";
332
+ }
333
+ return headers;
334
+ }
206
335
  };
207
- var makeBody4 = (params) => {
208
- return params.args;
336
+ function buildLoraPath(modelId, adapterWeightsPath) {
337
+ return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
338
+ }
339
+ var FalAITextToImageTask = class extends FalAITask {
340
+ preparePayload(params) {
341
+ const payload = {
342
+ ...omit(params.args, ["inputs", "parameters"]),
343
+ ...params.args.parameters,
344
+ sync_mode: true,
345
+ prompt: params.args.inputs,
346
+ ...params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath ? {
347
+ loras: [
348
+ {
349
+ path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
350
+ scale: 1
351
+ }
352
+ ]
353
+ } : void 0
354
+ };
355
+ if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
356
+ payload.loras = [
357
+ {
358
+ path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
359
+ scale: 1
360
+ }
361
+ ];
362
+ if (params.mapping.providerId === "fal-ai/lora") {
363
+ payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0";
364
+ }
365
+ }
366
+ return payload;
367
+ }
368
+ async getResponse(response, outputType) {
369
+ if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images.length > 0 && "url" in response.images[0] && typeof response.images[0].url === "string") {
370
+ if (outputType === "url") {
371
+ return response.images[0].url;
372
+ }
373
+ const urlResponse = await fetch(response.images[0].url);
374
+ return await urlResponse.blob();
375
+ }
376
+ throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
377
+ }
209
378
  };
210
- var makeHeaders4 = (params) => {
211
- return {
212
- Authorization: params.authMethod === "provider-key" ? `Key ${params.accessToken}` : `Bearer ${params.accessToken}`
213
- };
379
+ var FalAITextToVideoTask = class extends FalAITask {
380
+ constructor() {
381
+ super("https://queue.fal.run");
382
+ }
383
+ makeRoute(params) {
384
+ if (params.authMethod !== "provider-key") {
385
+ return `/${params.model}?_subdomain=queue`;
386
+ }
387
+ return `/${params.model}`;
388
+ }
389
+ preparePayload(params) {
390
+ return {
391
+ ...omit(params.args, ["inputs", "parameters"]),
392
+ ...params.args.parameters,
393
+ prompt: params.args.inputs
394
+ };
395
+ }
396
+ async getResponse(response, url, headers) {
397
+ if (!url || !headers) {
398
+ throw new InferenceOutputError("URL and headers are required for text-to-video task");
399
+ }
400
+ const requestId = response.request_id;
401
+ if (!requestId) {
402
+ throw new InferenceOutputError("No request ID found in the response");
403
+ }
404
+ let status = response.status;
405
+ const parsedUrl = new URL(url);
406
+ const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
407
+ const modelId = new URL(response.response_url).pathname;
408
+ const queryParams = parsedUrl.search;
409
+ const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
410
+ const resultUrl = `${baseUrl}${modelId}${queryParams}`;
411
+ while (status !== "COMPLETED") {
412
+ await delay(500);
413
+ const statusResponse = await fetch(statusUrl, { headers });
414
+ if (!statusResponse.ok) {
415
+ throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
416
+ }
417
+ try {
418
+ status = (await statusResponse.json()).status;
419
+ } catch (error) {
420
+ throw new InferenceOutputError("Failed to parse status response from fal-ai API");
421
+ }
422
+ }
423
+ const resultResponse = await fetch(resultUrl, { headers });
424
+ let result;
425
+ try {
426
+ result = await resultResponse.json();
427
+ } catch (error) {
428
+ throw new InferenceOutputError("Failed to parse result response from fal-ai API");
429
+ }
430
+ if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
431
+ const urlResponse = await fetch(result.video.url);
432
+ return await urlResponse.blob();
433
+ } else {
434
+ throw new InferenceOutputError(
435
+ "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
436
+ );
437
+ }
438
+ }
439
+ };
440
+ var FalAIAutomaticSpeechRecognitionTask = class extends FalAITask {
441
+ prepareHeaders(params, binary) {
442
+ const headers = super.prepareHeaders(params, binary);
443
+ headers["Content-Type"] = "application/json";
444
+ return headers;
445
+ }
446
+ async getResponse(response) {
447
+ const res = response;
448
+ if (typeof res?.text !== "string") {
449
+ throw new InferenceOutputError(
450
+ `Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
451
+ );
452
+ }
453
+ return { text: res.text };
454
+ }
214
455
  };
215
- var makeUrl4 = (params) => {
216
- const baseUrl = `${params.baseUrl}/${params.model}`;
217
- if (params.authMethod !== "provider-key" && params.task === "text-to-video") {
218
- return `${baseUrl}?_subdomain=queue`;
219
- }
220
- return baseUrl;
221
- };
222
- var FAL_AI_CONFIG = {
223
- makeBaseUrl: makeBaseUrl4,
224
- makeBody: makeBody4,
225
- makeHeaders: makeHeaders4,
226
- makeUrl: makeUrl4
227
- };
228
- async function pollFalResponse(res, url, headers) {
229
- const requestId = res.request_id;
230
- if (!requestId) {
231
- throw new InferenceOutputError("No request ID found in the response");
232
- }
233
- let status = res.status;
234
- const parsedUrl = new URL(url);
235
- const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
236
- const modelId = new URL(res.response_url).pathname;
237
- const queryParams = parsedUrl.search;
238
- const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
239
- const resultUrl = `${baseUrl}${modelId}${queryParams}`;
240
- while (status !== "COMPLETED") {
241
- await delay(500);
242
- const statusResponse = await fetch(statusUrl, { headers });
243
- if (!statusResponse.ok) {
244
- throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
456
+ var FalAITextToSpeechTask = class extends FalAITask {
457
+ preparePayload(params) {
458
+ return {
459
+ ...omit(params.args, ["inputs", "parameters"]),
460
+ ...params.args.parameters,
461
+ lyrics: params.args.inputs
462
+ };
463
+ }
464
+ async getResponse(response) {
465
+ const res = response;
466
+ if (typeof res?.audio?.url !== "string") {
467
+ throw new InferenceOutputError(
468
+ `Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
469
+ );
245
470
  }
246
471
  try {
247
- status = (await statusResponse.json()).status;
472
+ const urlResponse = await fetch(res.audio.url);
473
+ if (!urlResponse.ok) {
474
+ throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`);
475
+ }
476
+ return await urlResponse.blob();
248
477
  } catch (error) {
249
- throw new InferenceOutputError("Failed to parse status response from fal-ai API");
478
+ throw new InferenceOutputError(
479
+ `Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${error instanceof Error ? error.message : String(error)}`
480
+ );
250
481
  }
251
482
  }
252
- const resultResponse = await fetch(resultUrl, { headers });
253
- let result;
254
- try {
255
- result = await resultResponse.json();
256
- } catch (error) {
257
- throw new InferenceOutputError("Failed to parse result response from fal-ai API");
483
+ };
484
+
485
+ // src/providers/fireworks-ai.ts
486
+ var FireworksConversationalTask = class extends BaseConversationalTask {
487
+ constructor() {
488
+ super("fireworks-ai", "https://api.fireworks.ai");
258
489
  }
259
- if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
260
- const urlResponse = await fetch(result.video.url);
261
- return await urlResponse.blob();
262
- } else {
490
+ makeRoute() {
491
+ return "/inference/v1/chat/completions";
492
+ }
493
+ };
494
+
495
+ // src/providers/hf-inference.ts
496
+ var EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = ["feature-extraction", "sentence-similarity"];
497
+ var HFInferenceTask = class extends TaskProviderHelper {
498
+ constructor() {
499
+ super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
500
+ }
501
+ preparePayload(params) {
502
+ return params.args;
503
+ }
504
+ makeUrl(params) {
505
+ if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
506
+ return params.model;
507
+ }
508
+ return super.makeUrl(params);
509
+ }
510
+ makeRoute(params) {
511
+ if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
512
+ return `pipeline/${params.task}/${params.model}`;
513
+ }
514
+ return `models/${params.model}`;
515
+ }
516
+ async getResponse(response) {
517
+ return response;
518
+ }
519
+ };
520
+ var HFInferenceTextToImageTask = class extends HFInferenceTask {
521
+ async getResponse(response, url, headers, outputType) {
522
+ if (!response) {
523
+ throw new InferenceOutputError("response is undefined");
524
+ }
525
+ if (typeof response == "object") {
526
+ if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
527
+ const base64Data = response.data[0].b64_json;
528
+ if (outputType === "url") {
529
+ return `data:image/jpeg;base64,${base64Data}`;
530
+ }
531
+ const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
532
+ return await base64Response.blob();
533
+ }
534
+ if ("output" in response && Array.isArray(response.output)) {
535
+ if (outputType === "url") {
536
+ return response.output[0];
537
+ }
538
+ const urlResponse = await fetch(response.output[0]);
539
+ const blob = await urlResponse.blob();
540
+ return blob;
541
+ }
542
+ }
543
+ if (response instanceof Blob) {
544
+ if (outputType === "url") {
545
+ const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
546
+ return `data:image/jpeg;base64,${b64}`;
547
+ }
548
+ return response;
549
+ }
550
+ throw new InferenceOutputError("Expected a Blob ");
551
+ }
552
+ };
553
+ var HFInferenceConversationalTask = class extends HFInferenceTask {
554
+ makeUrl(params) {
555
+ let url;
556
+ if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
557
+ url = params.model.trim();
558
+ } else {
559
+ url = `${this.makeBaseUrl(params)}/models/${params.model}`;
560
+ }
561
+ url = url.replace(/\/+$/, "");
562
+ if (url.endsWith("/v1")) {
563
+ url += "/chat/completions";
564
+ } else if (!url.endsWith("/chat/completions")) {
565
+ url += "/v1/chat/completions";
566
+ }
567
+ return url;
568
+ }
569
+ preparePayload(params) {
570
+ return {
571
+ ...params.args,
572
+ model: params.model
573
+ };
574
+ }
575
+ async getResponse(response) {
576
+ return response;
577
+ }
578
+ };
579
+ var HFInferenceTextGenerationTask = class extends HFInferenceTask {
580
+ async getResponse(response) {
581
+ const res = toArray(response);
582
+ if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) {
583
+ return res?.[0];
584
+ }
585
+ throw new InferenceOutputError("Expected Array<{generated_text: string}>");
586
+ }
587
+ };
588
+ var HFInferenceAudioClassificationTask = class extends HFInferenceTask {
589
+ async getResponse(response) {
590
+ if (Array.isArray(response) && response.every(
591
+ (x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
592
+ )) {
593
+ return response;
594
+ }
595
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
596
+ }
597
+ };
598
+ var HFInferenceAutomaticSpeechRecognitionTask = class extends HFInferenceTask {
599
+ async getResponse(response) {
600
+ return response;
601
+ }
602
+ };
603
+ var HFInferenceAudioToAudioTask = class extends HFInferenceTask {
604
+ async getResponse(response) {
605
+ if (!Array.isArray(response)) {
606
+ throw new InferenceOutputError("Expected Array");
607
+ }
608
+ if (!response.every((elem) => {
609
+ return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
610
+ })) {
611
+ throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
612
+ }
613
+ return response;
614
+ }
615
+ };
616
+ var HFInferenceDocumentQuestionAnsweringTask = class extends HFInferenceTask {
617
+ async getResponse(response) {
618
+ if (Array.isArray(response) && response.every(
619
+ (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
620
+ )) {
621
+ return response[0];
622
+ }
623
+ throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
624
+ }
625
+ };
626
+ var HFInferenceFeatureExtractionTask = class extends HFInferenceTask {
627
+ async getResponse(response) {
628
+ const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
629
+ if (curDepth > maxDepth)
630
+ return false;
631
+ if (arr.every((x) => Array.isArray(x))) {
632
+ return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
633
+ } else {
634
+ return arr.every((x) => typeof x === "number");
635
+ }
636
+ };
637
+ if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) {
638
+ return response;
639
+ }
640
+ throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
641
+ }
642
+ };
643
+ var HFInferenceImageClassificationTask = class extends HFInferenceTask {
644
+ async getResponse(response) {
645
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
646
+ return response;
647
+ }
648
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
649
+ }
650
+ };
651
+ var HFInferenceImageSegmentationTask = class extends HFInferenceTask {
652
+ async getResponse(response) {
653
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number")) {
654
+ return response;
655
+ }
656
+ throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
657
+ }
658
+ };
659
+ var HFInferenceImageToTextTask = class extends HFInferenceTask {
660
+ async getResponse(response) {
661
+ if (typeof response?.generated_text !== "string") {
662
+ throw new InferenceOutputError("Expected {generated_text: string}");
663
+ }
664
+ return response;
665
+ }
666
+ };
667
+ var HFInferenceImageToImageTask = class extends HFInferenceTask {
668
+ async getResponse(response) {
669
+ if (response instanceof Blob) {
670
+ return response;
671
+ }
672
+ throw new InferenceOutputError("Expected Blob");
673
+ }
674
+ };
675
+ var HFInferenceObjectDetectionTask = class extends HFInferenceTask {
676
+ async getResponse(response) {
677
+ if (Array.isArray(response) && response.every(
678
+ (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
679
+ )) {
680
+ return response;
681
+ }
263
682
  throw new InferenceOutputError(
264
- "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
683
+ "Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
265
684
  );
266
685
  }
267
- }
268
-
269
- // src/providers/fireworks-ai.ts
270
- var FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai";
271
- var makeBaseUrl5 = () => {
272
- return FIREWORKS_AI_API_BASE_URL;
273
686
  };
274
- var makeBody5 = (params) => {
275
- return {
276
- ...params.args,
277
- ...params.chatCompletion ? { model: params.model } : void 0
278
- };
687
+ var HFInferenceZeroShotImageClassificationTask = class extends HFInferenceTask {
688
+ async getResponse(response) {
689
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
690
+ return response;
691
+ }
692
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
693
+ }
279
694
  };
280
- var makeHeaders5 = (params) => {
281
- return { Authorization: `Bearer ${params.accessToken}` };
695
+ var HFInferenceTextClassificationTask = class extends HFInferenceTask {
696
+ async getResponse(response) {
697
+ const output = response?.[0];
698
+ if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
699
+ return output;
700
+ }
701
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
702
+ }
282
703
  };
283
- var makeUrl5 = (params) => {
284
- if (params.chatCompletion) {
285
- return `${params.baseUrl}/inference/v1/chat/completions`;
704
+ var HFInferenceQuestionAnsweringTask = class extends HFInferenceTask {
705
+ async getResponse(response) {
706
+ if (Array.isArray(response) ? response.every(
707
+ (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
708
+ ) : typeof response === "object" && !!response && typeof response.answer === "string" && typeof response.end === "number" && typeof response.score === "number" && typeof response.start === "number") {
709
+ return Array.isArray(response) ? response[0] : response;
710
+ }
711
+ throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
286
712
  }
287
- return `${params.baseUrl}/inference`;
288
713
  };
289
- var FIREWORKS_AI_CONFIG = {
290
- makeBaseUrl: makeBaseUrl5,
291
- makeBody: makeBody5,
292
- makeHeaders: makeHeaders5,
293
- makeUrl: makeUrl5
714
+ var HFInferenceFillMaskTask = class extends HFInferenceTask {
715
+ async getResponse(response) {
716
+ if (Array.isArray(response) && response.every(
717
+ (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
718
+ )) {
719
+ return response;
720
+ }
721
+ throw new InferenceOutputError(
722
+ "Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
723
+ );
724
+ }
294
725
  };
295
-
296
- // src/providers/hf-inference.ts
297
- var makeBaseUrl6 = () => {
298
- return `${HF_ROUTER_URL}/hf-inference`;
726
+ var HFInferenceZeroShotClassificationTask = class extends HFInferenceTask {
727
+ async getResponse(response) {
728
+ if (Array.isArray(response) && response.every(
729
+ (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
730
+ )) {
731
+ return response;
732
+ }
733
+ throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
734
+ }
299
735
  };
300
- var makeBody6 = (params) => {
301
- return {
302
- ...params.args,
303
- ...params.chatCompletion ? { model: params.model } : void 0
304
- };
736
+ var HFInferenceSentenceSimilarityTask = class extends HFInferenceTask {
737
+ async getResponse(response) {
738
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
739
+ return response;
740
+ }
741
+ throw new InferenceOutputError("Expected Array<number>");
742
+ }
305
743
  };
306
- var makeHeaders6 = (params) => {
307
- return { Authorization: `Bearer ${params.accessToken}` };
744
+ var HFInferenceTableQuestionAnsweringTask = class extends HFInferenceTask {
745
+ static validate(elem) {
746
+ return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
747
+ (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
748
+ );
749
+ }
750
+ async getResponse(response) {
751
+ if (Array.isArray(response) && Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response)) {
752
+ return Array.isArray(response) ? response[0] : response;
753
+ }
754
+ throw new InferenceOutputError(
755
+ "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
756
+ );
757
+ }
308
758
  };
309
- var makeUrl6 = (params) => {
310
- if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
311
- return `${params.baseUrl}/pipeline/${params.task}/${params.model}`;
759
+ var HFInferenceTokenClassificationTask = class extends HFInferenceTask {
760
+ async getResponse(response) {
761
+ if (Array.isArray(response) && response.every(
762
+ (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
763
+ )) {
764
+ return response;
765
+ }
766
+ throw new InferenceOutputError(
767
+ "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
768
+ );
312
769
  }
313
- if (params.chatCompletion) {
314
- return `${params.baseUrl}/models/${params.model}/v1/chat/completions`;
770
+ };
771
+ var HFInferenceTranslationTask = class extends HFInferenceTask {
772
+ async getResponse(response) {
773
+ if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) {
774
+ return response?.length === 1 ? response?.[0] : response;
775
+ }
776
+ throw new InferenceOutputError("Expected Array<{translation_text: string}>");
777
+ }
778
+ };
779
+ var HFInferenceSummarizationTask = class extends HFInferenceTask {
780
+ async getResponse(response) {
781
+ if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) {
782
+ return response?.[0];
783
+ }
784
+ throw new InferenceOutputError("Expected Array<{summary_text: string}>");
785
+ }
786
+ };
787
+ var HFInferenceTextToSpeechTask = class extends HFInferenceTask {
788
+ async getResponse(response) {
789
+ return response;
315
790
  }
316
- return `${params.baseUrl}/models/${params.model}`;
317
791
  };
318
- var HF_INFERENCE_CONFIG = {
319
- makeBaseUrl: makeBaseUrl6,
320
- makeBody: makeBody6,
321
- makeHeaders: makeHeaders6,
322
- makeUrl: makeUrl6
792
+ var HFInferenceTabularClassificationTask = class extends HFInferenceTask {
793
+ async getResponse(response) {
794
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
795
+ return response;
796
+ }
797
+ throw new InferenceOutputError("Expected Array<number>");
798
+ }
799
+ };
800
+ var HFInferenceVisualQuestionAnsweringTask = class extends HFInferenceTask {
801
+ async getResponse(response) {
802
+ if (Array.isArray(response) && response.every(
803
+ (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
804
+ )) {
805
+ return response[0];
806
+ }
807
+ throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
808
+ }
809
+ };
810
+ var HFInferenceTabularRegressionTask = class extends HFInferenceTask {
811
+ async getResponse(response) {
812
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
813
+ return response;
814
+ }
815
+ throw new InferenceOutputError("Expected Array<number>");
816
+ }
817
+ };
818
+ var HFInferenceTextToAudioTask = class extends HFInferenceTask {
819
+ async getResponse(response) {
820
+ return response;
821
+ }
323
822
  };
324
823
 
325
824
  // src/providers/hyperbolic.ts
326
825
  var HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
327
- var makeBaseUrl7 = () => {
328
- return HYPERBOLIC_API_BASE_URL;
329
- };
330
- var makeBody7 = (params) => {
331
- return {
332
- ...params.args,
333
- ...params.task === "text-to-image" ? { model_name: params.model } : { model: params.model }
334
- };
335
- };
336
- var makeHeaders7 = (params) => {
337
- return { Authorization: `Bearer ${params.accessToken}` };
826
+ var HyperbolicConversationalTask = class extends BaseConversationalTask {
827
+ constructor() {
828
+ super("hyperbolic", HYPERBOLIC_API_BASE_URL);
829
+ }
338
830
  };
339
- var makeUrl7 = (params) => {
340
- if (params.task === "text-to-image") {
341
- return `${params.baseUrl}/v1/images/generations`;
831
+ var HyperbolicTextGenerationTask = class extends BaseTextGenerationTask {
832
+ constructor() {
833
+ super("hyperbolic", HYPERBOLIC_API_BASE_URL);
834
+ }
835
+ makeRoute() {
836
+ return "v1/chat/completions";
837
+ }
838
+ preparePayload(params) {
839
+ return {
840
+ messages: [{ content: params.args.inputs, role: "user" }],
841
+ ...params.args.parameters ? {
842
+ max_tokens: params.args.parameters.max_new_tokens,
843
+ ...omit(params.args.parameters, "max_new_tokens")
844
+ } : void 0,
845
+ ...omit(params.args, ["inputs", "parameters"]),
846
+ model: params.model
847
+ };
848
+ }
849
+ async getResponse(response) {
850
+ if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
851
+ const completion = response.choices[0];
852
+ return {
853
+ generated_text: completion.message.content
854
+ };
855
+ }
856
+ throw new InferenceOutputError("Expected Hyperbolic text generation response format");
342
857
  }
343
- return `${params.baseUrl}/v1/chat/completions`;
344
858
  };
345
- var HYPERBOLIC_CONFIG = {
346
- makeBaseUrl: makeBaseUrl7,
347
- makeBody: makeBody7,
348
- makeHeaders: makeHeaders7,
349
- makeUrl: makeUrl7
859
+ var HyperbolicTextToImageTask = class extends TaskProviderHelper {
860
+ constructor() {
861
+ super("hyperbolic", HYPERBOLIC_API_BASE_URL);
862
+ }
863
+ makeRoute(params) {
864
+ return `/v1/images/generations`;
865
+ }
866
+ preparePayload(params) {
867
+ return {
868
+ ...omit(params.args, ["inputs", "parameters"]),
869
+ ...params.args.parameters,
870
+ prompt: params.args.inputs,
871
+ model_name: params.model
872
+ };
873
+ }
874
+ async getResponse(response, url, headers, outputType) {
875
+ if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images[0] && typeof response.images[0].image === "string") {
876
+ if (outputType === "url") {
877
+ return `data:image/jpeg;base64,${response.images[0].image}`;
878
+ }
879
+ return fetch(`data:image/jpeg;base64,${response.images[0].image}`).then((res) => res.blob());
880
+ }
881
+ throw new InferenceOutputError("Expected Hyperbolic text-to-image response format");
882
+ }
350
883
  };
351
884
 
352
885
  // src/providers/nebius.ts
353
886
  var NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
354
- var makeBaseUrl8 = () => {
355
- return NEBIUS_API_BASE_URL;
356
- };
357
- var makeBody8 = (params) => {
358
- return {
359
- ...params.args,
360
- model: params.model
361
- };
887
+ var NebiusConversationalTask = class extends BaseConversationalTask {
888
+ constructor() {
889
+ super("nebius", NEBIUS_API_BASE_URL);
890
+ }
362
891
  };
363
- var makeHeaders8 = (params) => {
364
- return { Authorization: `Bearer ${params.accessToken}` };
892
+ var NebiusTextGenerationTask = class extends BaseTextGenerationTask {
893
+ constructor() {
894
+ super("nebius", NEBIUS_API_BASE_URL);
895
+ }
365
896
  };
366
- var makeUrl8 = (params) => {
367
- if (params.task === "text-to-image") {
368
- return `${params.baseUrl}/v1/images/generations`;
897
+ var NebiusTextToImageTask = class extends TaskProviderHelper {
898
+ constructor() {
899
+ super("nebius", NEBIUS_API_BASE_URL);
369
900
  }
370
- if (params.chatCompletion) {
371
- return `${params.baseUrl}/v1/chat/completions`;
901
+ preparePayload(params) {
902
+ return {
903
+ ...omit(params.args, ["inputs", "parameters"]),
904
+ ...params.args.parameters,
905
+ response_format: "b64_json",
906
+ prompt: params.args.inputs,
907
+ model: params.model
908
+ };
372
909
  }
373
- if (params.task === "text-generation") {
374
- return `${params.baseUrl}/v1/completions`;
910
+ makeRoute(params) {
911
+ return "v1/images/generations";
912
+ }
913
+ async getResponse(response, url, headers, outputType) {
914
+ if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
915
+ const base64Data = response.data[0].b64_json;
916
+ if (outputType === "url") {
917
+ return `data:image/jpeg;base64,${base64Data}`;
918
+ }
919
+ return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
920
+ }
921
+ throw new InferenceOutputError("Expected Nebius text-to-image response format");
375
922
  }
376
- return params.baseUrl;
377
- };
378
- var NEBIUS_CONFIG = {
379
- makeBaseUrl: makeBaseUrl8,
380
- makeBody: makeBody8,
381
- makeHeaders: makeHeaders8,
382
- makeUrl: makeUrl8
383
923
  };
384
924
 
385
925
  // src/providers/novita.ts
386
926
  var NOVITA_API_BASE_URL = "https://api.novita.ai";
387
- var makeBaseUrl9 = () => {
388
- return NOVITA_API_BASE_URL;
389
- };
390
- var makeBody9 = (params) => {
391
- return {
392
- ...params.args,
393
- ...params.chatCompletion ? { model: params.model } : void 0
394
- };
395
- };
396
- var makeHeaders9 = (params) => {
397
- return { Authorization: `Bearer ${params.accessToken}` };
927
+ var NovitaTextGenerationTask = class extends BaseTextGenerationTask {
928
+ constructor() {
929
+ super("novita", NOVITA_API_BASE_URL);
930
+ }
931
+ makeRoute() {
932
+ return "/v3/openai/chat/completions";
933
+ }
398
934
  };
399
- var makeUrl9 = (params) => {
400
- if (params.chatCompletion) {
401
- return `${params.baseUrl}/v3/openai/chat/completions`;
402
- } else if (params.task === "text-generation") {
403
- return `${params.baseUrl}/v3/openai/completions`;
404
- } else if (params.task === "text-to-video") {
405
- return `${params.baseUrl}/v3/hf/${params.model}`;
935
+ var NovitaConversationalTask = class extends BaseConversationalTask {
936
+ constructor() {
937
+ super("novita", NOVITA_API_BASE_URL);
938
+ }
939
+ makeRoute() {
940
+ return "/v3/openai/chat/completions";
406
941
  }
407
- return params.baseUrl;
408
942
  };
409
- var NOVITA_CONFIG = {
410
- makeBaseUrl: makeBaseUrl9,
411
- makeBody: makeBody9,
412
- makeHeaders: makeHeaders9,
413
- makeUrl: makeUrl9
943
+
944
+ // src/providers/openai.ts
945
+ var OPENAI_API_BASE_URL = "https://api.openai.com";
946
+ var OpenAIConversationalTask = class extends BaseConversationalTask {
947
+ constructor() {
948
+ super("openai", OPENAI_API_BASE_URL, true);
949
+ }
414
950
  };
415
951
 
416
952
  // src/providers/replicate.ts
417
- var REPLICATE_API_BASE_URL = "https://api.replicate.com";
418
- var makeBaseUrl10 = () => {
419
- return REPLICATE_API_BASE_URL;
420
- };
421
- var makeBody10 = (params) => {
422
- return {
423
- input: params.args,
424
- version: params.model.includes(":") ? params.model.split(":")[1] : void 0
425
- };
953
+ var ReplicateTask = class extends TaskProviderHelper {
954
+ constructor(url) {
955
+ super("replicate", url || "https://api.replicate.com");
956
+ }
957
+ makeRoute(params) {
958
+ if (params.model.includes(":")) {
959
+ return "v1/predictions";
960
+ }
961
+ return `v1/models/${params.model}/predictions`;
962
+ }
963
+ preparePayload(params) {
964
+ return {
965
+ input: {
966
+ ...omit(params.args, ["inputs", "parameters"]),
967
+ ...params.args.parameters,
968
+ prompt: params.args.inputs
969
+ },
970
+ version: params.model.includes(":") ? params.model.split(":")[1] : void 0
971
+ };
972
+ }
973
+ prepareHeaders(params, binary) {
974
+ const headers = { Authorization: `Bearer ${params.accessToken}`, Prefer: "wait" };
975
+ if (!binary) {
976
+ headers["Content-Type"] = "application/json";
977
+ }
978
+ return headers;
979
+ }
980
+ makeUrl(params) {
981
+ const baseUrl = this.makeBaseUrl(params);
982
+ if (params.model.includes(":")) {
983
+ return `${baseUrl}/v1/predictions`;
984
+ }
985
+ return `${baseUrl}/v1/models/${params.model}/predictions`;
986
+ }
426
987
  };
427
- var makeHeaders10 = (params) => {
428
- return { Authorization: `Bearer ${params.accessToken}`, Prefer: "wait" };
988
+ var ReplicateTextToImageTask = class extends ReplicateTask {
989
+ async getResponse(res, url, headers, outputType) {
990
+ if (typeof res === "object" && "output" in res && Array.isArray(res.output) && res.output.length > 0 && typeof res.output[0] === "string") {
991
+ if (outputType === "url") {
992
+ return res.output[0];
993
+ }
994
+ const urlResponse = await fetch(res.output[0]);
995
+ return await urlResponse.blob();
996
+ }
997
+ throw new InferenceOutputError("Expected Replicate text-to-image response format");
998
+ }
429
999
  };
430
- var makeUrl10 = (params) => {
431
- if (params.model.includes(":")) {
432
- return `${params.baseUrl}/v1/predictions`;
1000
+ var ReplicateTextToSpeechTask = class extends ReplicateTask {
1001
+ preparePayload(params) {
1002
+ const payload = super.preparePayload(params);
1003
+ const input = payload["input"];
1004
+ if (typeof input === "object" && input !== null && "prompt" in input) {
1005
+ const inputObj = input;
1006
+ inputObj["text"] = inputObj["prompt"];
1007
+ delete inputObj["prompt"];
1008
+ }
1009
+ return payload;
1010
+ }
1011
+ async getResponse(response) {
1012
+ if (response instanceof Blob) {
1013
+ return response;
1014
+ }
1015
+ if (response && typeof response === "object") {
1016
+ if ("output" in response) {
1017
+ if (typeof response.output === "string") {
1018
+ const urlResponse = await fetch(response.output);
1019
+ return await urlResponse.blob();
1020
+ } else if (Array.isArray(response.output)) {
1021
+ const urlResponse = await fetch(response.output[0]);
1022
+ return await urlResponse.blob();
1023
+ }
1024
+ }
1025
+ }
1026
+ throw new InferenceOutputError("Expected Blob or object with output");
433
1027
  }
434
- return `${params.baseUrl}/v1/models/${params.model}/predictions`;
435
1028
  };
436
- var REPLICATE_CONFIG = {
437
- makeBaseUrl: makeBaseUrl10,
438
- makeBody: makeBody10,
439
- makeHeaders: makeHeaders10,
440
- makeUrl: makeUrl10
1029
+ var ReplicateTextToVideoTask = class extends ReplicateTask {
1030
+ async getResponse(response) {
1031
+ if (typeof response === "object" && !!response && "output" in response && typeof response.output === "string" && isUrl(response.output)) {
1032
+ const urlResponse = await fetch(response.output);
1033
+ return await urlResponse.blob();
1034
+ }
1035
+ throw new InferenceOutputError("Expected { output: string }");
1036
+ }
441
1037
  };
442
1038
 
443
1039
  // src/providers/sambanova.ts
444
- var SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
445
- var makeBaseUrl11 = () => {
446
- return SAMBANOVA_API_BASE_URL;
447
- };
448
- var makeBody11 = (params) => {
449
- return {
450
- ...params.args,
451
- ...params.chatCompletion ? { model: params.model } : void 0
452
- };
453
- };
454
- var makeHeaders11 = (params) => {
455
- return { Authorization: `Bearer ${params.accessToken}` };
456
- };
457
- var makeUrl11 = (params) => {
458
- if (params.chatCompletion) {
459
- return `${params.baseUrl}/v1/chat/completions`;
1040
+ var SambanovaConversationalTask = class extends BaseConversationalTask {
1041
+ constructor() {
1042
+ super("sambanova", "https://api.sambanova.ai");
460
1043
  }
461
- return params.baseUrl;
462
- };
463
- var SAMBANOVA_CONFIG = {
464
- makeBaseUrl: makeBaseUrl11,
465
- makeBody: makeBody11,
466
- makeHeaders: makeHeaders11,
467
- makeUrl: makeUrl11
468
1044
  };
469
1045
 
470
1046
  // src/providers/together.ts
471
1047
  var TOGETHER_API_BASE_URL = "https://api.together.xyz";
472
- var makeBaseUrl12 = () => {
473
- return TOGETHER_API_BASE_URL;
474
- };
475
- var makeBody12 = (params) => {
476
- return {
477
- ...params.args,
478
- model: params.model
479
- };
480
- };
481
- var makeHeaders12 = (params) => {
482
- return { Authorization: `Bearer ${params.accessToken}` };
1048
+ var TogetherConversationalTask = class extends BaseConversationalTask {
1049
+ constructor() {
1050
+ super("together", TOGETHER_API_BASE_URL);
1051
+ }
483
1052
  };
484
- var makeUrl12 = (params) => {
485
- if (params.task === "text-to-image") {
486
- return `${params.baseUrl}/v1/images/generations`;
1053
+ var TogetherTextGenerationTask = class extends BaseTextGenerationTask {
1054
+ constructor() {
1055
+ super("together", TOGETHER_API_BASE_URL);
487
1056
  }
488
- if (params.chatCompletion) {
489
- return `${params.baseUrl}/v1/chat/completions`;
1057
+ preparePayload(params) {
1058
+ return {
1059
+ model: params.model,
1060
+ ...params.args,
1061
+ prompt: params.args.inputs
1062
+ };
490
1063
  }
491
- if (params.task === "text-generation") {
492
- return `${params.baseUrl}/v1/completions`;
1064
+ async getResponse(response) {
1065
+ if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
1066
+ const completion = response.choices[0];
1067
+ return {
1068
+ generated_text: completion.text
1069
+ };
1070
+ }
1071
+ throw new InferenceOutputError("Expected Together text generation response format");
493
1072
  }
494
- return params.baseUrl;
495
1073
  };
496
- var TOGETHER_CONFIG = {
497
- makeBaseUrl: makeBaseUrl12,
498
- makeBody: makeBody12,
499
- makeHeaders: makeHeaders12,
500
- makeUrl: makeUrl12
1074
+ var TogetherTextToImageTask = class extends TaskProviderHelper {
1075
+ constructor() {
1076
+ super("together", TOGETHER_API_BASE_URL);
1077
+ }
1078
+ makeRoute() {
1079
+ return "v1/images/generations";
1080
+ }
1081
+ preparePayload(params) {
1082
+ return {
1083
+ ...omit(params.args, ["inputs", "parameters"]),
1084
+ ...params.args.parameters,
1085
+ prompt: params.args.inputs,
1086
+ response_format: "base64",
1087
+ model: params.model
1088
+ };
1089
+ }
1090
+ async getResponse(response, outputType) {
1091
+ if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
1092
+ const base64Data = response.data[0].b64_json;
1093
+ if (outputType === "url") {
1094
+ return `data:image/jpeg;base64,${base64Data}`;
1095
+ }
1096
+ return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
1097
+ }
1098
+ throw new InferenceOutputError("Expected Together text-to-image response format");
1099
+ }
501
1100
  };
502
1101
 
503
- // src/providers/openai.ts
504
- var OPENAI_API_BASE_URL = "https://api.openai.com";
505
- var makeBaseUrl13 = () => {
506
- return OPENAI_API_BASE_URL;
507
- };
508
- var makeBody13 = (params) => {
509
- if (!params.chatCompletion) {
510
- throw new Error("OpenAI only supports chat completions.");
1102
+ // src/lib/getProviderHelper.ts
1103
+ var PROVIDERS = {
1104
+ "black-forest-labs": {
1105
+ "text-to-image": new BlackForestLabsTextToImageTask()
1106
+ },
1107
+ cerebras: {
1108
+ conversational: new CerebrasConversationalTask()
1109
+ },
1110
+ cohere: {
1111
+ conversational: new CohereConversationalTask()
1112
+ },
1113
+ "fal-ai": {
1114
+ "text-to-image": new FalAITextToImageTask(),
1115
+ "text-to-speech": new FalAITextToSpeechTask(),
1116
+ "text-to-video": new FalAITextToVideoTask(),
1117
+ "automatic-speech-recognition": new FalAIAutomaticSpeechRecognitionTask()
1118
+ },
1119
+ "hf-inference": {
1120
+ "text-to-image": new HFInferenceTextToImageTask(),
1121
+ conversational: new HFInferenceConversationalTask(),
1122
+ "text-generation": new HFInferenceTextGenerationTask(),
1123
+ "text-classification": new HFInferenceTextClassificationTask(),
1124
+ "question-answering": new HFInferenceQuestionAnsweringTask(),
1125
+ "audio-classification": new HFInferenceAudioClassificationTask(),
1126
+ "automatic-speech-recognition": new HFInferenceAutomaticSpeechRecognitionTask(),
1127
+ "fill-mask": new HFInferenceFillMaskTask(),
1128
+ "feature-extraction": new HFInferenceFeatureExtractionTask(),
1129
+ "image-classification": new HFInferenceImageClassificationTask(),
1130
+ "image-segmentation": new HFInferenceImageSegmentationTask(),
1131
+ "document-question-answering": new HFInferenceDocumentQuestionAnsweringTask(),
1132
+ "image-to-text": new HFInferenceImageToTextTask(),
1133
+ "object-detection": new HFInferenceObjectDetectionTask(),
1134
+ "audio-to-audio": new HFInferenceAudioToAudioTask(),
1135
+ "zero-shot-image-classification": new HFInferenceZeroShotImageClassificationTask(),
1136
+ "zero-shot-classification": new HFInferenceZeroShotClassificationTask(),
1137
+ "image-to-image": new HFInferenceImageToImageTask(),
1138
+ "sentence-similarity": new HFInferenceSentenceSimilarityTask(),
1139
+ "table-question-answering": new HFInferenceTableQuestionAnsweringTask(),
1140
+ "tabular-classification": new HFInferenceTabularClassificationTask(),
1141
+ "text-to-speech": new HFInferenceTextToSpeechTask(),
1142
+ "token-classification": new HFInferenceTokenClassificationTask(),
1143
+ translation: new HFInferenceTranslationTask(),
1144
+ summarization: new HFInferenceSummarizationTask(),
1145
+ "visual-question-answering": new HFInferenceVisualQuestionAnsweringTask(),
1146
+ "tabular-regression": new HFInferenceTabularRegressionTask(),
1147
+ "text-to-audio": new HFInferenceTextToAudioTask()
1148
+ },
1149
+ "fireworks-ai": {
1150
+ conversational: new FireworksConversationalTask()
1151
+ },
1152
+ hyperbolic: {
1153
+ "text-to-image": new HyperbolicTextToImageTask(),
1154
+ conversational: new HyperbolicConversationalTask(),
1155
+ "text-generation": new HyperbolicTextGenerationTask()
1156
+ },
1157
+ nebius: {
1158
+ "text-to-image": new NebiusTextToImageTask(),
1159
+ conversational: new NebiusConversationalTask(),
1160
+ "text-generation": new NebiusTextGenerationTask()
1161
+ },
1162
+ novita: {
1163
+ conversational: new NovitaConversationalTask(),
1164
+ "text-generation": new NovitaTextGenerationTask()
1165
+ },
1166
+ openai: {
1167
+ conversational: new OpenAIConversationalTask()
1168
+ },
1169
+ replicate: {
1170
+ "text-to-image": new ReplicateTextToImageTask(),
1171
+ "text-to-speech": new ReplicateTextToSpeechTask(),
1172
+ "text-to-video": new ReplicateTextToVideoTask()
1173
+ },
1174
+ sambanova: {
1175
+ conversational: new SambanovaConversationalTask()
1176
+ },
1177
+ together: {
1178
+ "text-to-image": new TogetherTextToImageTask(),
1179
+ conversational: new TogetherConversationalTask(),
1180
+ "text-generation": new TogetherTextGenerationTask()
511
1181
  }
512
- return {
513
- ...params.args,
514
- model: params.model
515
- };
516
1182
  };
517
- var makeHeaders13 = (params) => {
518
- return { Authorization: `Bearer ${params.accessToken}` };
519
- };
520
- var makeUrl13 = (params) => {
521
- if (!params.chatCompletion) {
522
- throw new Error("OpenAI only supports chat completions.");
1183
+ function getProviderHelper(provider, task) {
1184
+ if (provider === "hf-inference") {
1185
+ if (!task) {
1186
+ return new HFInferenceTask();
1187
+ }
523
1188
  }
524
- return `${params.baseUrl}/v1/chat/completions`;
525
- };
526
- var OPENAI_CONFIG = {
527
- makeBaseUrl: makeBaseUrl13,
528
- makeBody: makeBody13,
529
- makeHeaders: makeHeaders13,
530
- makeUrl: makeUrl13,
531
- clientSideRoutingOnly: true
532
- };
1189
+ if (!task) {
1190
+ throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'");
1191
+ }
1192
+ if (!(provider in PROVIDERS)) {
1193
+ throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`);
1194
+ }
1195
+ const providerTasks = PROVIDERS[provider];
1196
+ if (!providerTasks || !(task in providerTasks)) {
1197
+ throw new Error(
1198
+ `Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(providerTasks ?? {})}`
1199
+ );
1200
+ }
1201
+ return providerTasks[task];
1202
+ }
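Editor's note: a minimal sketch of how the new dispatch behaves. The deep-import path is an assumption (the function ships in `dist/src/lib/getProviderHelper` per this diff, but the package may not expose that subpath); the error messages paraphrase the throws in the function body above.

```ts
// Assumption: deep import from the compiled output; adjust the path to your setup.
import { getProviderHelper } from "@huggingface/inference/dist/src/lib/getProviderHelper";

// Task helpers are looked up in the PROVIDERS registry defined above.
const replicateVideo = getProviderHelper("replicate", "text-to-video"); // ReplicateTextToVideoTask

// Error paths, per the function body above:
//   getProviderHelper("together", undefined)           -> "you need to provide a task name..."
//   getProviderHelper("cerebras", "text-to-image")      -> task not supported for provider
//   getProviderHelper("my-provider", "conversational")  -> provider not supported
```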
533
1203
 
534
1204
  // package.json
535
1205
  var name = "@huggingface/inference";
536
- var version = "3.7.0";
1206
+ var version = "3.8.0";
537
1207
 
538
1208
  // src/providers/consts.ts
539
- var HARDCODED_MODEL_ID_MAPPING = {
1209
+ var HARDCODED_MODEL_INFERENCE_MAPPING = {
540
1210
  /**
541
1211
  * "HF model ID" => "Model ID on Inference Provider's side"
542
1212
  *
@@ -558,106 +1228,127 @@ var HARDCODED_MODEL_ID_MAPPING = {
558
1228
  together: {}
559
1229
  };
560
1230
 
561
- // src/lib/getProviderModelId.ts
1231
+ // src/lib/getInferenceProviderMapping.ts
562
1232
  var inferenceProviderMappingCache = /* @__PURE__ */ new Map();
563
- async function getProviderModelId(params, args, options = {}) {
564
- if (params.provider === "hf-inference") {
565
- return params.model;
566
- }
567
- if (!options.task) {
568
- throw new Error("task must be specified when using a third-party provider");
569
- }
570
- const task = options.task === "text-generation" && options.chatCompletion ? "conversational" : options.task;
571
- if (HARDCODED_MODEL_ID_MAPPING[params.provider]?.[params.model]) {
572
- return HARDCODED_MODEL_ID_MAPPING[params.provider][params.model];
1233
+ async function getInferenceProviderMapping(params, options) {
1234
+ if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
1235
+ return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
573
1236
  }
574
1237
  let inferenceProviderMapping;
575
- if (inferenceProviderMappingCache.has(params.model)) {
576
- inferenceProviderMapping = inferenceProviderMappingCache.get(params.model);
1238
+ if (inferenceProviderMappingCache.has(params.modelId)) {
1239
+ inferenceProviderMapping = inferenceProviderMappingCache.get(params.modelId);
577
1240
  } else {
578
- inferenceProviderMapping = await (options?.fetch ?? fetch)(
579
- `${HF_HUB_URL}/api/models/${params.model}?expand[]=inferenceProviderMapping`,
1241
+ const resp = await (options?.fetch ?? fetch)(
1242
+ `${HF_HUB_URL}/api/models/${params.modelId}?expand[]=inferenceProviderMapping`,
580
1243
  {
581
- headers: args.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${args.accessToken}` } : {}
1244
+ headers: params.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${params.accessToken}` } : {}
582
1245
  }
583
- ).then((resp) => resp.json()).then((json) => json.inferenceProviderMapping).catch(() => null);
1246
+ );
1247
+ if (resp.status === 404) {
1248
+ throw new Error(`Model ${params.modelId} does not exist`);
1249
+ }
1250
+ inferenceProviderMapping = await resp.json().then((json) => json.inferenceProviderMapping).catch(() => null);
584
1251
  }
585
1252
  if (!inferenceProviderMapping) {
586
- throw new Error(`We have not been able to find inference provider information for model ${params.model}.`);
1253
+ throw new Error(`We have not been able to find inference provider information for model ${params.modelId}.`);
587
1254
  }
588
1255
  const providerMapping = inferenceProviderMapping[params.provider];
589
1256
  if (providerMapping) {
590
- if (providerMapping.task !== task) {
1257
+ const equivalentTasks = params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task) ? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS : [params.task];
1258
+ if (!typedInclude(equivalentTasks, providerMapping.task)) {
591
1259
  throw new Error(
592
- `Model ${params.model} is not supported for task ${task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
1260
+ `Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
593
1261
  );
594
1262
  }
595
1263
  if (providerMapping.status === "staging") {
596
1264
  console.warn(
597
- `Model ${params.model} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
1265
+ `Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
598
1266
  );
599
1267
  }
600
- return providerMapping.providerId;
1268
+ if (providerMapping.adapter === "lora") {
1269
+ const treeResp = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${params.modelId}/tree/main`);
1270
+ if (!treeResp.ok) {
1271
+ throw new Error(`Unable to fetch the model tree for ${params.modelId}.`);
1272
+ }
1273
+ const tree = await treeResp.json();
1274
+ const adapterWeightsPath = tree.find(({ type, path }) => type === "file" && path.endsWith(".safetensors"))?.path;
1275
+ if (!adapterWeightsPath) {
1276
+ throw new Error(`No .safetensors file found in the model tree for ${params.modelId}.`);
1277
+ }
1278
+ return {
1279
+ ...providerMapping,
1280
+ hfModelId: params.modelId,
1281
+ adapterWeightsPath
1282
+ };
1283
+ }
1284
+ return { ...providerMapping, hfModelId: params.modelId };
601
1285
  }
602
- throw new Error(`Model ${params.model} is not supported provider ${params.provider}.`);
1286
+ return null;
603
1287
  }
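Editor's note: a hedged sketch of what the new mapping resolution returns. The deep-import path and the model id are assumptions for illustration; the returned shape is taken from the function body above.

```ts
// Assumption: deep import from the compiled output.
import { getInferenceProviderMapping } from "@huggingface/inference/dist/src/lib/getInferenceProviderMapping";

const mapping = await getInferenceProviderMapping(
  {
    modelId: "black-forest-labs/FLUX.1-dev", // example model id, not part of this diff
    provider: "together",
    task: "text-to-image",
    accessToken: process.env.HF_TOKEN,       // only forwarded when it starts with "hf_"
  },
  {}
);

// null when the provider does not serve this model; otherwise:
//   { providerId, hfModelId, status: "live" | "staging", task, adapterWeightsPath? }
// adapterWeightsPath is only filled in for LoRA adapters, by listing the repo tree.
console.log(mapping?.providerId);
```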
604
1288
 
605
1289
  // src/lib/makeRequestOptions.ts
606
- var HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
607
1290
  var tasks = null;
608
- var providerConfigs = {
609
- "black-forest-labs": BLACK_FOREST_LABS_CONFIG,
610
- cerebras: CEREBRAS_CONFIG,
611
- cohere: COHERE_CONFIG,
612
- "fal-ai": FAL_AI_CONFIG,
613
- "fireworks-ai": FIREWORKS_AI_CONFIG,
614
- "hf-inference": HF_INFERENCE_CONFIG,
615
- hyperbolic: HYPERBOLIC_CONFIG,
616
- openai: OPENAI_CONFIG,
617
- nebius: NEBIUS_CONFIG,
618
- novita: NOVITA_CONFIG,
619
- replicate: REPLICATE_CONFIG,
620
- sambanova: SAMBANOVA_CONFIG,
621
- together: TOGETHER_CONFIG
622
- };
623
- async function makeRequestOptions(args, options) {
1291
+ async function makeRequestOptions(args, providerHelper, options) {
624
1292
  const { provider: maybeProvider, model: maybeModel } = args;
625
1293
  const provider = maybeProvider ?? "hf-inference";
626
- const providerConfig = providerConfigs[provider];
627
- const { task, chatCompletion: chatCompletion2 } = options ?? {};
1294
+ const { task } = options ?? {};
628
1295
  if (args.endpointUrl && provider !== "hf-inference") {
629
1296
  throw new Error(`Cannot use endpointUrl with a third-party provider.`);
630
1297
  }
631
1298
  if (maybeModel && isUrl(maybeModel)) {
632
1299
  throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
633
1300
  }
1301
+ if (args.endpointUrl) {
1302
+ return makeRequestOptionsFromResolvedModel(
1303
+ maybeModel ?? args.endpointUrl,
1304
+ providerHelper,
1305
+ args,
1306
+ void 0,
1307
+ options
1308
+ );
1309
+ }
634
1310
  if (!maybeModel && !task) {
635
1311
  throw new Error("No model provided, and no task has been specified.");
636
1312
  }
637
- if (!providerConfig) {
638
- throw new Error(`No provider config found for provider ${provider}`);
639
- }
640
- if (providerConfig.clientSideRoutingOnly && !maybeModel) {
1313
+ const hfModel = maybeModel ?? await loadDefaultModel(task);
1314
+ if (providerHelper.clientSideRoutingOnly && !maybeModel) {
641
1315
  throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
642
1316
  }
643
- const hfModel = maybeModel ?? await loadDefaultModel(task);
644
- const resolvedModel = providerConfig.clientSideRoutingOnly ? (
1317
+ const inferenceProviderMapping = providerHelper.clientSideRoutingOnly ? {
645
1318
  // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
646
- removeProviderPrefix(maybeModel, provider)
647
- ) : await getProviderModelId({ model: hfModel, provider }, args, {
648
- task,
649
- chatCompletion: chatCompletion2,
650
- fetch: options?.fetch
651
- });
652
- return makeRequestOptionsFromResolvedModel(resolvedModel, args, options);
1319
+ providerId: removeProviderPrefix(maybeModel, provider),
1320
+ // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
1321
+ hfModelId: maybeModel,
1322
+ status: "live",
1323
+ // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
1324
+ task
1325
+ } : await getInferenceProviderMapping(
1326
+ {
1327
+ modelId: hfModel,
1328
+ // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
1329
+ task,
1330
+ provider,
1331
+ accessToken: args.accessToken
1332
+ },
1333
+ { fetch: options?.fetch }
1334
+ );
1335
+ if (!inferenceProviderMapping) {
1336
+ throw new Error(`We have not been able to find inference provider information for model ${hfModel}.`);
1337
+ }
1338
+ return makeRequestOptionsFromResolvedModel(
1339
+ inferenceProviderMapping.providerId,
1340
+ providerHelper,
1341
+ args,
1342
+ inferenceProviderMapping,
1343
+ options
1344
+ );
653
1345
  }
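Editor's note: a sketch of the new call shape, with the provider helper passed explicitly and the old `chatCompletion` flag replaced by the "conversational" task. Both deep-import paths are assumptions; the token and model id are placeholders.

```ts
// Assumption: deep imports from the compiled output.
import { getProviderHelper } from "@huggingface/inference/dist/src/lib/getProviderHelper";
import { makeRequestOptions } from "@huggingface/inference/dist/src/lib/makeRequestOptions";

const providerHelper = getProviderHelper("together", "conversational");
const { url, info } = await makeRequestOptions(
  {
    accessToken: "hf_xxx",                          // placeholder token
    provider: "together",
    model: "meta-llama/Llama-3.3-70B-Instruct",     // example model id
    messages: [{ role: "user", content: "Hello" }],
  },
  providerHelper,
  { task: "conversational" }
);
// 3.7.0 equivalent: makeRequestOptions(args, { task: "text-generation", chatCompletion: true })
const response = await fetch(url, info);
```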
654
- function makeRequestOptionsFromResolvedModel(resolvedModel, args, options) {
1346
+ function makeRequestOptionsFromResolvedModel(resolvedModel, providerHelper, args, mapping, options) {
655
1347
  const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
656
1348
  const provider = maybeProvider ?? "hf-inference";
657
- const providerConfig = providerConfigs[provider];
658
- const { includeCredentials, task, chatCompletion: chatCompletion2, signal, billTo } = options ?? {};
1349
+ const { includeCredentials, task, signal, billTo } = options ?? {};
659
1350
  const authMethod = (() => {
660
- if (providerConfig.clientSideRoutingOnly) {
1351
+ if (providerHelper.clientSideRoutingOnly) {
661
1352
  if (accessToken && accessToken.startsWith("hf_")) {
662
1353
  throw new Error(`Provider ${provider} is closed-source and does not support HF tokens.`);
663
1354
  }
@@ -671,35 +1362,31 @@ function makeRequestOptionsFromResolvedModel(resolvedModel, args, options) {
671
1362
  }
672
1363
  return "none";
673
1364
  })();
674
- const url = endpointUrl ? chatCompletion2 ? endpointUrl + `/v1/chat/completions` : endpointUrl : providerConfig.makeUrl({
1365
+ const modelId = endpointUrl ?? resolvedModel;
1366
+ const url = providerHelper.makeUrl({
675
1367
  authMethod,
676
- baseUrl: authMethod !== "provider-key" ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider) : providerConfig.makeBaseUrl(task),
677
- model: resolvedModel,
678
- chatCompletion: chatCompletion2,
1368
+ model: modelId,
679
1369
  task
680
1370
  });
681
- const binary = "data" in args && !!args.data;
682
- const headers = providerConfig.makeHeaders({
683
- accessToken,
684
- authMethod
685
- });
1371
+ const headers = providerHelper.prepareHeaders(
1372
+ {
1373
+ accessToken,
1374
+ authMethod
1375
+ },
1376
+ "data" in args && !!args.data
1377
+ );
686
1378
  if (billTo) {
687
1379
  headers[HF_HEADER_X_BILL_TO] = billTo;
688
1380
  }
689
- if (!binary) {
690
- headers["Content-Type"] = "application/json";
691
- }
692
1381
  const ownUserAgent = `${name}/${version}`;
693
1382
  const userAgent = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : void 0].filter((x) => x !== void 0).join(" ");
694
1383
  headers["User-Agent"] = userAgent;
695
- const body = binary ? args.data : JSON.stringify(
696
- providerConfig.makeBody({
697
- args: remainingArgs,
698
- model: resolvedModel,
699
- task,
700
- chatCompletion: chatCompletion2
701
- })
702
- );
1384
+ const body = providerHelper.makeBody({
1385
+ args: remainingArgs,
1386
+ model: resolvedModel,
1387
+ task,
1388
+ mapping
1389
+ });
703
1390
  let credentials;
704
1391
  if (typeof includeCredentials === "string") {
705
1392
  credentials = includeCredentials;
@@ -839,12 +1526,12 @@ function newMessage() {
839
1526
  }
840
1527
 
841
1528
  // src/utils/request.ts
842
- async function innerRequest(args, options) {
843
- const { url, info } = await makeRequestOptions(args, options);
1529
+ async function innerRequest(args, providerHelper, options) {
1530
+ const { url, info } = await makeRequestOptions(args, providerHelper, options);
844
1531
  const response = await (options?.fetch ?? fetch)(url, info);
845
1532
  const requestContext = { url, info };
846
1533
  if (options?.retry_on_error !== false && response.status === 503) {
847
- return innerRequest(args, options);
1534
+ return innerRequest(args, providerHelper, options);
848
1535
  }
849
1536
  if (!response.ok) {
850
1537
  const contentType = response.headers.get("Content-Type");
@@ -871,11 +1558,11 @@ async function innerRequest(args, options) {
871
1558
  const blob = await response.blob();
872
1559
  return { data: blob, requestContext };
873
1560
  }
874
- async function* innerStreamingRequest(args, options) {
875
- const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
1561
+ async function* innerStreamingRequest(args, providerHelper, options) {
1562
+ const { url, info } = await makeRequestOptions({ ...args, stream: true }, providerHelper, options);
876
1563
  const response = await (options?.fetch ?? fetch)(url, info);
877
1564
  if (options?.retry_on_error !== false && response.status === 503) {
878
- return yield* innerStreamingRequest(args, options);
1565
+ return yield* innerStreamingRequest(args, providerHelper, options);
879
1566
  }
880
1567
  if (!response.ok) {
881
1568
  if (response.headers.get("Content-Type")?.startsWith("application/json")) {
@@ -949,7 +1636,8 @@ async function request(args, options) {
949
1636
  console.warn(
950
1637
  "The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
951
1638
  );
952
- const result = await innerRequest(args, options);
1639
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
1640
+ const result = await innerRequest(args, providerHelper, options);
953
1641
  return result.data;
954
1642
  }
955
1643
 
@@ -958,31 +1646,8 @@ async function* streamingRequest(args, options) {
958
1646
  console.warn(
959
1647
  "The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
960
1648
  );
961
- yield* innerStreamingRequest(args, options);
962
- }
963
-
964
- // src/utils/pick.ts
965
- function pick(o, props) {
966
- return Object.assign(
967
- {},
968
- ...props.map((prop) => {
969
- if (o[prop] !== void 0) {
970
- return { [prop]: o[prop] };
971
- }
972
- })
973
- );
974
- }
975
-
976
- // src/utils/typedInclude.ts
977
- function typedInclude(arr, v) {
978
- return arr.includes(v);
979
- }
980
-
981
- // src/utils/omit.ts
982
- function omit(o, props) {
983
- const propsArr = Array.isArray(props) ? props : [props];
984
- const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
985
- return pick(o, letsKeep);
1649
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
1650
+ yield* innerStreamingRequest(args, providerHelper, options);
986
1651
  }
987
1652
 
988
1653
  // src/tasks/audio/utils.ts
@@ -995,16 +1660,24 @@ function preparePayload(args) {
995
1660
 
996
1661
  // src/tasks/audio/audioClassification.ts
997
1662
  async function audioClassification(args, options) {
1663
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
998
1664
  const payload = preparePayload(args);
999
- const { data: res } = await innerRequest(payload, {
1665
+ const { data: res } = await innerRequest(payload, providerHelper, {
1000
1666
  ...options,
1001
1667
  task: "audio-classification"
1002
1668
  });
1003
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
1004
- if (!isValidOutput) {
1005
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
1006
- }
1007
- return res;
1669
+ return providerHelper.getResponse(res);
1670
+ }
1671
+
1672
+ // src/tasks/audio/audioToAudio.ts
1673
+ async function audioToAudio(args, options) {
1674
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
1675
+ const payload = preparePayload(args);
1676
+ const { data: res } = await innerRequest(payload, providerHelper, {
1677
+ ...options,
1678
+ task: "audio-to-audio"
1679
+ });
1680
+ return providerHelper.getResponse(res);
1008
1681
  }
1009
1682
 
1010
1683
  // src/utils/base64FromBytes.ts
@@ -1022,8 +1695,9 @@ function base64FromBytes(arr) {
1022
1695
 
1023
1696
  // src/tasks/audio/automaticSpeechRecognition.ts
1024
1697
  async function automaticSpeechRecognition(args, options) {
1698
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
1025
1699
  const payload = await buildPayload(args);
1026
- const { data: res } = await innerRequest(payload, {
1700
+ const { data: res } = await innerRequest(payload, providerHelper, {
1027
1701
  ...options,
1028
1702
  task: "automatic-speech-recognition"
1029
1703
  });
@@ -1031,9 +1705,8 @@ async function automaticSpeechRecognition(args, options) {
1031
1705
  if (!isValidOutput) {
1032
1706
  throw new InferenceOutputError("Expected {text: string}");
1033
1707
  }
1034
- return res;
1708
+ return providerHelper.getResponse(res);
1035
1709
  }
1036
- var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
1037
1710
  async function buildPayload(args) {
1038
1711
  if (args.provider === "fal-ai") {
1039
1712
  const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : void 0;
@@ -1062,215 +1735,45 @@ async function buildPayload(args) {
1062
1735
 
1063
1736
  // src/tasks/audio/textToSpeech.ts
1064
1737
  async function textToSpeech(args, options) {
1065
- const payload = args.provider === "replicate" ? {
1066
- ...omit(args, ["inputs", "parameters"]),
1067
- ...args.parameters,
1068
- text: args.inputs
1069
- } : args;
1070
- const { data: res } = await innerRequest(payload, {
1738
+ const provider = args.provider ?? "hf-inference";
1739
+ const providerHelper = getProviderHelper(provider, "text-to-speech");
1740
+ const { data: res } = await innerRequest(args, providerHelper, {
1071
1741
  ...options,
1072
1742
  task: "text-to-speech"
1073
1743
  });
1074
- if (res instanceof Blob) {
1075
- return res;
1076
- }
1077
- if (res && typeof res === "object") {
1078
- if ("output" in res) {
1079
- if (typeof res.output === "string") {
1080
- const urlResponse = await fetch(res.output);
1081
- const blob = await urlResponse.blob();
1082
- return blob;
1083
- } else if (Array.isArray(res.output)) {
1084
- const urlResponse = await fetch(res.output[0]);
1085
- const blob = await urlResponse.blob();
1086
- return blob;
1087
- }
1088
- }
1089
- }
1090
- throw new InferenceOutputError("Expected Blob or object with output");
1091
- }
1092
-
1093
- // src/tasks/audio/audioToAudio.ts
1094
- async function audioToAudio(args, options) {
1095
- const payload = preparePayload(args);
1096
- const { data: res } = await innerRequest(payload, {
1097
- ...options,
1098
- task: "audio-to-audio"
1099
- });
1100
- return validateOutput(res);
1101
- }
1102
- function validateOutput(output) {
1103
- if (!Array.isArray(output)) {
1104
- throw new InferenceOutputError("Expected Array");
1105
- }
1106
- if (!output.every((elem) => {
1107
- return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
1108
- })) {
1109
- throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
1110
- }
1111
- return output;
1112
- }
1113
-
1114
- // src/tasks/cv/utils.ts
1115
- function preparePayload2(args) {
1116
- return "data" in args ? args : { ...omit(args, "inputs"), data: args.inputs };
1117
- }
1118
-
1119
- // src/tasks/cv/imageClassification.ts
1120
- async function imageClassification(args, options) {
1121
- const payload = preparePayload2(args);
1122
- const { data: res } = await innerRequest(payload, {
1123
- ...options,
1124
- task: "image-classification"
1125
- });
1126
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
1127
- if (!isValidOutput) {
1128
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
1129
- }
1130
- return res;
1131
- }
1132
-
1133
- // src/tasks/cv/imageSegmentation.ts
1134
- async function imageSegmentation(args, options) {
1135
- const payload = preparePayload2(args);
1136
- const { data: res } = await innerRequest(payload, {
1137
- ...options,
1138
- task: "image-segmentation"
1139
- });
1140
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
1141
- if (!isValidOutput) {
1142
- throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
1143
- }
1144
- return res;
1744
+ return providerHelper.getResponse(res);
1145
1745
  }
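Editor's note: the Replicate-specific payload rewrite (`text: args.inputs`) and URL-to-Blob conversion removed just below now live in the provider helpers. A hedged usage sketch; token and model id are placeholders.

```ts
import { textToSpeech } from "@huggingface/inference";

const speech: Blob = await textToSpeech({
  accessToken: "hf_xxx",          // placeholder token
  provider: "replicate",
  model: "hexgrad/Kokoro-82M",    // example model id
  inputs: "Provider-specific payload shaping now happens in the task helper.",
});
```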
1146
-
1147
- // src/tasks/cv/imageToText.ts
1148
- async function imageToText(args, options) {
1149
- const payload = preparePayload2(args);
1150
- const { data: res } = await innerRequest(payload, {
1151
- ...options,
1152
- task: "image-to-text"
1153
- });
1154
- if (typeof res?.[0]?.generated_text !== "string") {
1155
- throw new InferenceOutputError("Expected {generated_text: string}");
1156
- }
1157
- return res?.[0];
1746
+
1747
+ // src/tasks/cv/utils.ts
1748
+ function preparePayload2(args) {
1749
+ return "data" in args ? args : { ...omit(args, "inputs"), data: args.inputs };
1158
1750
  }
1159
1751
 
1160
- // src/tasks/cv/objectDetection.ts
1161
- async function objectDetection(args, options) {
1752
+ // src/tasks/cv/imageClassification.ts
1753
+ async function imageClassification(args, options) {
1754
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
1162
1755
  const payload = preparePayload2(args);
1163
- const { data: res } = await innerRequest(payload, {
1756
+ const { data: res } = await innerRequest(payload, providerHelper, {
1164
1757
  ...options,
1165
- task: "object-detection"
1758
+ task: "image-classification"
1166
1759
  });
1167
- const isValidOutput = Array.isArray(res) && res.every(
1168
- (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
1169
- );
1170
- if (!isValidOutput) {
1171
- throw new InferenceOutputError(
1172
- "Expected Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>"
1173
- );
1174
- }
1175
- return res;
1760
+ return providerHelper.getResponse(res);
1176
1761
  }
1177
1762
 
1178
- // src/tasks/cv/textToImage.ts
1179
- function getResponseFormatArg(provider) {
1180
- switch (provider) {
1181
- case "fal-ai":
1182
- return { sync_mode: true };
1183
- case "nebius":
1184
- return { response_format: "b64_json" };
1185
- case "replicate":
1186
- return void 0;
1187
- case "together":
1188
- return { response_format: "base64" };
1189
- default:
1190
- return void 0;
1191
- }
1192
- }
1193
- async function textToImage(args, options) {
1194
- const payload = !args.provider || args.provider === "hf-inference" || args.provider === "sambanova" ? args : {
1195
- ...omit(args, ["inputs", "parameters"]),
1196
- ...args.parameters,
1197
- ...getResponseFormatArg(args.provider),
1198
- prompt: args.inputs
1199
- };
1200
- const { data: res } = await innerRequest(payload, {
1763
+ // src/tasks/cv/imageSegmentation.ts
1764
+ async function imageSegmentation(args, options) {
1765
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
1766
+ const payload = preparePayload2(args);
1767
+ const { data: res } = await innerRequest(payload, providerHelper, {
1201
1768
  ...options,
1202
- task: "text-to-image"
1769
+ task: "image-segmentation"
1203
1770
  });
1204
- if (res && typeof res === "object") {
1205
- if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") {
1206
- return await pollBflResponse(res.polling_url, options?.outputType);
1207
- }
1208
- if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
1209
- if (options?.outputType === "url") {
1210
- return res.images[0].url;
1211
- } else {
1212
- const image = await fetch(res.images[0].url);
1213
- return await image.blob();
1214
- }
1215
- }
1216
- if (args.provider === "hyperbolic" && "images" in res && Array.isArray(res.images) && res.images[0] && typeof res.images[0].image === "string") {
1217
- if (options?.outputType === "url") {
1218
- return `data:image/jpeg;base64,${res.images[0].image}`;
1219
- }
1220
- const base64Response = await fetch(`data:image/jpeg;base64,${res.images[0].image}`);
1221
- return await base64Response.blob();
1222
- }
1223
- if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
1224
- const base64Data = res.data[0].b64_json;
1225
- if (options?.outputType === "url") {
1226
- return `data:image/jpeg;base64,${base64Data}`;
1227
- }
1228
- const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
1229
- return await base64Response.blob();
1230
- }
1231
- if ("output" in res && Array.isArray(res.output)) {
1232
- if (options?.outputType === "url") {
1233
- return res.output[0];
1234
- }
1235
- const urlResponse = await fetch(res.output[0]);
1236
- const blob = await urlResponse.blob();
1237
- return blob;
1238
- }
1239
- }
1240
- const isValidOutput = res && res instanceof Blob;
1241
- if (!isValidOutput) {
1242
- throw new InferenceOutputError("Expected Blob");
1243
- }
1244
- if (options?.outputType === "url") {
1245
- const b64 = await res.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
1246
- return `data:image/jpeg;base64,${b64}`;
1247
- }
1248
- return res;
1249
- }
1250
- async function pollBflResponse(url, outputType) {
1251
- const urlObj = new URL(url);
1252
- for (let step = 0; step < 5; step++) {
1253
- await delay(1e3);
1254
- console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
1255
- urlObj.searchParams.set("attempt", step.toString(10));
1256
- const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
1257
- if (!resp.ok) {
1258
- throw new InferenceOutputError("Failed to fetch result from black forest labs API");
1259
- }
1260
- const payload = await resp.json();
1261
- if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
1262
- if (outputType === "url") {
1263
- return payload.result.sample;
1264
- }
1265
- const image = await fetch(payload.result.sample);
1266
- return await image.blob();
1267
- }
1268
- }
1269
- throw new InferenceOutputError("Failed to fetch result from black forest labs API");
1771
+ return providerHelper.getResponse(res);
1270
1772
  }
1271
1773
 
1272
1774
  // src/tasks/cv/imageToImage.ts
1273
1775
  async function imageToImage(args, options) {
1776
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
1274
1777
  let reqArgs;
1275
1778
  if (!args.parameters) {
1276
1779
  reqArgs = {
@@ -1286,15 +1789,61 @@ async function imageToImage(args, options) {
1286
1789
  )
1287
1790
  };
1288
1791
  }
1289
- const { data: res } = await innerRequest(reqArgs, {
1792
+ const { data: res } = await innerRequest(reqArgs, providerHelper, {
1290
1793
  ...options,
1291
1794
  task: "image-to-image"
1292
1795
  });
1293
- const isValidOutput = res && res instanceof Blob;
1294
- if (!isValidOutput) {
1295
- throw new InferenceOutputError("Expected Blob");
1296
- }
1297
- return res;
1796
+ return providerHelper.getResponse(res);
1797
+ }
1798
+
1799
+ // src/tasks/cv/imageToText.ts
1800
+ async function imageToText(args, options) {
1801
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
1802
+ const payload = preparePayload2(args);
1803
+ const { data: res } = await innerRequest(payload, providerHelper, {
1804
+ ...options,
1805
+ task: "image-to-text"
1806
+ });
1807
+ return providerHelper.getResponse(res[0]);
1808
+ }
1809
+
1810
+ // src/tasks/cv/objectDetection.ts
1811
+ async function objectDetection(args, options) {
1812
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
1813
+ const payload = preparePayload2(args);
1814
+ const { data: res } = await innerRequest(payload, providerHelper, {
1815
+ ...options,
1816
+ task: "object-detection"
1817
+ });
1818
+ return providerHelper.getResponse(res);
1819
+ }
1820
+
1821
+ // src/tasks/cv/textToImage.ts
1822
+ async function textToImage(args, options) {
1823
+ const provider = args.provider ?? "hf-inference";
1824
+ const providerHelper = getProviderHelper(provider, "text-to-image");
1825
+ const { data: res } = await innerRequest(args, providerHelper, {
1826
+ ...options,
1827
+ task: "text-to-image"
1828
+ });
1829
+ const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "text-to-image" });
1830
+ return providerHelper.getResponse(res, url, info.headers, options?.outputType);
1831
+ }
1832
+
1833
+ // src/tasks/cv/textToVideo.ts
1834
+ async function textToVideo(args, options) {
1835
+ const provider = args.provider ?? "hf-inference";
1836
+ const providerHelper = getProviderHelper(provider, "text-to-video");
1837
+ const { data: response } = await innerRequest(
1838
+ args,
1839
+ providerHelper,
1840
+ {
1841
+ ...options,
1842
+ task: "text-to-video"
1843
+ }
1844
+ );
1845
+ const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "text-to-video" });
1846
+ return providerHelper.getResponse(response, url, info.headers);
1298
1847
  }
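Editor's note: the 3.7.0 provider allow-list for text-to-video (removed further down in this diff) is replaced by the `getProviderHelper` lookup. A hedged usage sketch; token and model id are placeholders.

```ts
import { textToVideo } from "@huggingface/inference";

const video: Blob = await textToVideo({
  accessToken: "hf_xxx",            // placeholder token
  provider: "replicate",
  model: "genmo/mochi-1-preview",   // example model id
  inputs: "a red panda rolling down a grassy hill",
});
// A Replicate response of shape { output: string } is fetched and returned as a Blob
// by ReplicateTextToVideoTask.getResponse, defined earlier in this diff.
```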
1299
1848
 
1300
1849
  // src/tasks/cv/zeroShotImageClassification.ts
@@ -1320,231 +1869,126 @@ async function preparePayload3(args) {
1320
1869
  }
1321
1870
  }
1322
1871
  async function zeroShotImageClassification(args, options) {
1872
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
1323
1873
  const payload = await preparePayload3(args);
1324
- const { data: res } = await innerRequest(payload, {
1874
+ const { data: res } = await innerRequest(payload, providerHelper, {
1325
1875
  ...options,
1326
1876
  task: "zero-shot-image-classification"
1327
1877
  });
1328
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
1329
- if (!isValidOutput) {
1330
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
1331
- }
1332
- return res;
1878
+ return providerHelper.getResponse(res);
1333
1879
  }
1334
1880
 
1335
- // src/tasks/cv/textToVideo.ts
1336
- var SUPPORTED_PROVIDERS = ["fal-ai", "novita", "replicate"];
1337
- async function textToVideo(args, options) {
1338
- if (!args.provider || !typedInclude(SUPPORTED_PROVIDERS, args.provider)) {
1339
- throw new Error(
1340
- `textToVideo inference is only supported for the following providers: ${SUPPORTED_PROVIDERS.join(", ")}`
1341
- );
1342
- }
1343
- const payload = args.provider === "fal-ai" || args.provider === "replicate" || args.provider === "novita" ? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs } : args;
1344
- const { data, requestContext } = await innerRequest(payload, {
1881
+ // src/tasks/nlp/chatCompletion.ts
1882
+ async function chatCompletion(args, options) {
1883
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
1884
+ const { data: response } = await innerRequest(args, providerHelper, {
1345
1885
  ...options,
1346
- task: "text-to-video"
1886
+ task: "conversational"
1887
+ });
1888
+ return providerHelper.getResponse(response);
1889
+ }
1890
+
1891
+ // src/tasks/nlp/chatCompletionStream.ts
1892
+ async function* chatCompletionStream(args, options) {
1893
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
1894
+ yield* innerStreamingRequest(args, providerHelper, {
1895
+ ...options,
1896
+ task: "conversational"
1347
1897
  });
1348
- if (args.provider === "fal-ai") {
1349
- return await pollFalResponse(
1350
- data,
1351
- requestContext.url,
1352
- requestContext.info.headers
1353
- );
1354
- } else if (args.provider === "novita") {
1355
- const isValidOutput = typeof data === "object" && !!data && "video" in data && typeof data.video === "object" && !!data.video && "video_url" in data.video && typeof data.video.video_url === "string" && isUrl(data.video.video_url);
1356
- if (!isValidOutput) {
1357
- throw new InferenceOutputError("Expected { video: { video_url: string } }");
1358
- }
1359
- const urlResponse = await fetch(data.video.video_url);
1360
- return await urlResponse.blob();
1361
- } else {
1362
- const isValidOutput = typeof data === "object" && !!data && "output" in data && typeof data.output === "string" && isUrl(data.output);
1363
- if (!isValidOutput) {
1364
- throw new InferenceOutputError("Expected { output: string }");
1365
- }
1366
- const urlResponse = await fetch(data.output);
1367
- return await urlResponse.blob();
1368
- }
1369
1898
  }
1370
1899
 
1371
1900
  // src/tasks/nlp/featureExtraction.ts
1372
1901
  async function featureExtraction(args, options) {
1373
- const { data: res } = await innerRequest(args, {
1902
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "feature-extraction");
1903
+ const { data: res } = await innerRequest(args, providerHelper, {
1374
1904
  ...options,
1375
1905
  task: "feature-extraction"
1376
1906
  });
1377
- let isValidOutput = true;
1378
- const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
1379
- if (curDepth > maxDepth)
1380
- return false;
1381
- if (arr.every((x) => Array.isArray(x))) {
1382
- return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
1383
- } else {
1384
- return arr.every((x) => typeof x === "number");
1385
- }
1386
- };
1387
- isValidOutput = Array.isArray(res) && isNumArrayRec(res, 3, 0);
1388
- if (!isValidOutput) {
1389
- throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
1390
- }
1391
- return res;
1907
+ return providerHelper.getResponse(res);
1392
1908
  }
1393
1909
 
1394
1910
  // src/tasks/nlp/fillMask.ts
1395
1911
  async function fillMask(args, options) {
1396
- const { data: res } = await innerRequest(args, {
1912
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "fill-mask");
1913
+ const { data: res } = await innerRequest(args, providerHelper, {
1397
1914
  ...options,
1398
1915
  task: "fill-mask"
1399
1916
  });
1400
- const isValidOutput = Array.isArray(res) && res.every(
1401
- (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
1402
- );
1403
- if (!isValidOutput) {
1404
- throw new InferenceOutputError(
1405
- "Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
1406
- );
1407
- }
1408
- return res;
1917
+ return providerHelper.getResponse(res);
1409
1918
  }
1410
1919
 
1411
1920
  // src/tasks/nlp/questionAnswering.ts
1412
1921
  async function questionAnswering(args, options) {
1413
- const { data: res } = await innerRequest(args, {
1414
- ...options,
1415
- task: "question-answering"
1416
- });
1417
- const isValidOutput = Array.isArray(res) ? res.every(
1418
- (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
1419
- ) : typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";
1420
- if (!isValidOutput) {
1421
- throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
1422
- }
1423
- return Array.isArray(res) ? res[0] : res;
1922
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering");
1923
+ const { data: res } = await innerRequest(
1924
+ args,
1925
+ providerHelper,
1926
+ {
1927
+ ...options,
1928
+ task: "question-answering"
1929
+ }
1930
+ );
1931
+ return providerHelper.getResponse(res);
1424
1932
  }
1425
1933
 
1426
1934
  // src/tasks/nlp/sentenceSimilarity.ts
1427
1935
  async function sentenceSimilarity(args, options) {
1428
- const { data: res } = await innerRequest(args, {
1936
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "sentence-similarity");
1937
+ const { data: res } = await innerRequest(args, providerHelper, {
1429
1938
  ...options,
1430
1939
  task: "sentence-similarity"
1431
1940
  });
1432
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
1433
- if (!isValidOutput) {
1434
- throw new InferenceOutputError("Expected number[]");
1435
- }
1436
- return res;
1941
+ return providerHelper.getResponse(res);
1437
1942
  }
1438
1943
 
1439
1944
  // src/tasks/nlp/summarization.ts
1440
1945
  async function summarization(args, options) {
1441
- const { data: res } = await innerRequest(args, {
1946
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization");
1947
+ const { data: res } = await innerRequest(args, providerHelper, {
1442
1948
  ...options,
1443
1949
  task: "summarization"
1444
1950
  });
1445
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string");
1446
- if (!isValidOutput) {
1447
- throw new InferenceOutputError("Expected Array<{summary_text: string}>");
1448
- }
1449
- return res?.[0];
1951
+ return providerHelper.getResponse(res);
1450
1952
  }
1451
1953
 
1452
1954
  // src/tasks/nlp/tableQuestionAnswering.ts
1453
1955
  async function tableQuestionAnswering(args, options) {
1454
- const { data: res } = await innerRequest(args, {
1455
- ...options,
1456
- task: "table-question-answering"
1457
- });
1458
- const isValidOutput = Array.isArray(res) ? res.every((elem) => validate(elem)) : validate(res);
1459
- if (!isValidOutput) {
1460
- throw new InferenceOutputError(
1461
- "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
1462
- );
1463
- }
1464
- return Array.isArray(res) ? res[0] : res;
1465
- }
1466
- function validate(elem) {
1467
- return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
1468
- (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
1956
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
1957
+ const { data: res } = await innerRequest(
1958
+ args,
1959
+ providerHelper,
1960
+ {
1961
+ ...options,
1962
+ task: "table-question-answering"
1963
+ }
1469
1964
  );
1965
+ return providerHelper.getResponse(res);
1470
1966
  }
1471
1967
 
1472
1968
  // src/tasks/nlp/textClassification.ts
1473
1969
  async function textClassification(args, options) {
1474
- const { data: res } = await innerRequest(args, {
1970
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
1971
+ const { data: res } = await innerRequest(args, providerHelper, {
1475
1972
  ...options,
1476
1973
  task: "text-classification"
1477
1974
  });
1478
- const output = res?.[0];
1479
- const isValidOutput = Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number");
1480
- if (!isValidOutput) {
1481
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
1482
- }
1483
- return output;
1484
- }
1485
-
1486
- // src/utils/toArray.ts
1487
- function toArray(obj) {
1488
- if (Array.isArray(obj)) {
1489
- return obj;
1490
- }
1491
- return [obj];
1975
+ return providerHelper.getResponse(res);
1492
1976
  }
1493
1977
 
1494
1978
  // src/tasks/nlp/textGeneration.ts
1495
1979
  async function textGeneration(args, options) {
1496
- if (args.provider === "together") {
1497
- args.prompt = args.inputs;
1498
- const { data: raw } = await innerRequest(args, {
1499
- ...options,
1500
- task: "text-generation"
1501
- });
1502
- const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
1503
- if (!isValidOutput) {
1504
- throw new InferenceOutputError("Expected ChatCompletionOutput");
1505
- }
1506
- const completion = raw.choices[0];
1507
- return {
1508
- generated_text: completion.text
1509
- };
1510
- } else if (args.provider === "hyperbolic") {
1511
- const payload = {
1512
- messages: [{ content: args.inputs, role: "user" }],
1513
- ...args.parameters ? {
1514
- max_tokens: args.parameters.max_new_tokens,
1515
- ...omit(args.parameters, "max_new_tokens")
1516
- } : void 0,
1517
- ...omit(args, ["inputs", "parameters"])
1518
- };
1519
- const raw = (await innerRequest(payload, {
1520
- ...options,
1521
- task: "text-generation"
1522
- })).data;
1523
- const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
1524
- if (!isValidOutput) {
1525
- throw new InferenceOutputError("Expected ChatCompletionOutput");
1526
- }
1527
- const completion = raw.choices[0];
1528
- return {
1529
- generated_text: completion.message.content
1530
- };
1531
- } else {
1532
- const { data: res } = await innerRequest(args, {
1533
- ...options,
1534
- task: "text-generation"
1535
- });
1536
- const output = toArray(res);
1537
- const isValidOutput = Array.isArray(output) && output.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
1538
- if (!isValidOutput) {
1539
- throw new InferenceOutputError("Expected Array<{generated_text: string}>");
1540
- }
1541
- return output?.[0];
1542
- }
1980
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
1981
+ const { data: response } = await innerRequest(args, providerHelper, {
1982
+ ...options,
1983
+ task: "text-generation"
1984
+ });
1985
+ return providerHelper.getResponse(response);
1543
1986
  }
1544
1987
 
1545
1988
  // src/tasks/nlp/textGenerationStream.ts
1546
1989
  async function* textGenerationStream(args, options) {
1547
- yield* innerStreamingRequest(args, {
1990
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
1991
+ yield* innerStreamingRequest(args, providerHelper, {
1548
1992
  ...options,
1549
1993
  task: "text-generation"
1550
1994
  });
@@ -1552,77 +1996,45 @@ async function* textGenerationStream(args, options) {
1552
1996
 
1553
1997
  // src/tasks/nlp/tokenClassification.ts
1554
1998
  async function tokenClassification(args, options) {
1555
- const { data: res } = await innerRequest(args, {
1556
- ...options,
1557
- task: "token-classification"
1558
- });
1559
- const output = toArray(res);
1560
- const isValidOutput = Array.isArray(output) && output.every(
1561
- (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
1999
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "token-classification");
2000
+ const { data: res } = await innerRequest(
2001
+ args,
2002
+ providerHelper,
2003
+ {
2004
+ ...options,
2005
+ task: "token-classification"
2006
+ }
1562
2007
  );
1563
- if (!isValidOutput) {
1564
- throw new InferenceOutputError(
1565
- "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
1566
- );
1567
- }
1568
- return output;
2008
+ return providerHelper.getResponse(res);
1569
2009
  }
1570
2010
 
1571
2011
  // src/tasks/nlp/translation.ts
1572
2012
  async function translation(args, options) {
1573
- const { data: res } = await innerRequest(args, {
2013
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation");
2014
+ const { data: res } = await innerRequest(args, providerHelper, {
1574
2015
  ...options,
1575
2016
  task: "translation"
1576
2017
  });
1577
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string");
1578
- if (!isValidOutput) {
1579
- throw new InferenceOutputError("Expected type Array<{translation_text: string}>");
1580
- }
1581
- return res?.length === 1 ? res?.[0] : res;
2018
+ return providerHelper.getResponse(res);
1582
2019
  }
1583
2020
 
1584
2021
  // src/tasks/nlp/zeroShotClassification.ts
1585
2022
  async function zeroShotClassification(args, options) {
1586
- const { data: res } = await innerRequest(args, {
1587
- ...options,
1588
- task: "zero-shot-classification"
1589
- });
1590
- const output = toArray(res);
1591
- const isValidOutput = Array.isArray(output) && output.every(
1592
- (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
2023
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification");
2024
+ const { data: res } = await innerRequest(
2025
+ args,
2026
+ providerHelper,
2027
+ {
2028
+ ...options,
2029
+ task: "zero-shot-classification"
2030
+ }
1593
2031
  );
1594
- if (!isValidOutput) {
1595
- throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
1596
- }
1597
- return output;
1598
- }
1599
-
1600
- // src/tasks/nlp/chatCompletion.ts
1601
- async function chatCompletion(args, options) {
1602
- const { data: res } = await innerRequest(args, {
1603
- ...options,
1604
- task: "text-generation",
1605
- chatCompletion: true
1606
- });
1607
- const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
1608
- (res.system_fingerprint === void 0 || res.system_fingerprint === null || typeof res.system_fingerprint === "string") && typeof res?.usage === "object";
1609
- if (!isValidOutput) {
1610
- throw new InferenceOutputError("Expected ChatCompletionOutput");
1611
- }
1612
- return res;
1613
- }
1614
-
1615
- // src/tasks/nlp/chatCompletionStream.ts
1616
- async function* chatCompletionStream(args, options) {
1617
- yield* innerStreamingRequest(args, {
1618
- ...options,
1619
- task: "text-generation",
1620
- chatCompletion: true
1621
- });
2032
+ return providerHelper.getResponse(res);
1622
2033
  }
1623
2034
 
1624
2035
  // src/tasks/multimodal/documentQuestionAnswering.ts
1625
2036
  async function documentQuestionAnswering(args, options) {
2037
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "document-question-answering");
1626
2038
  const reqArgs = {
1627
2039
  ...args,
1628
2040
  inputs: {
@@ -1633,23 +2045,18 @@ async function documentQuestionAnswering(args, options) {
1633
2045
  };
1634
2046
  const { data: res } = await innerRequest(
1635
2047
  reqArgs,
2048
+ providerHelper,
1636
2049
  {
1637
2050
  ...options,
1638
2051
  task: "document-question-answering"
1639
2052
  }
1640
2053
  );
1641
- const output = toArray(res);
1642
- const isValidOutput = Array.isArray(output) && output.every(
1643
- (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
1644
- );
1645
- if (!isValidOutput) {
1646
- throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>");
1647
- }
1648
- return output[0];
2054
+ return providerHelper.getResponse(res);
1649
2055
  }
1650
2056
 
1651
2057
  // src/tasks/multimodal/visualQuestionAnswering.ts
1652
2058
  async function visualQuestionAnswering(args, options) {
2059
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "visual-question-answering");
1653
2060
  const reqArgs = {
1654
2061
  ...args,
1655
2062
  inputs: {
@@ -1658,43 +2065,31 @@ async function visualQuestionAnswering(args, options) {
1658
2065
  image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer()))
1659
2066
  }
1660
2067
  };
1661
- const { data: res } = await innerRequest(reqArgs, {
2068
+ const { data: res } = await innerRequest(reqArgs, providerHelper, {
1662
2069
  ...options,
1663
2070
  task: "visual-question-answering"
1664
2071
  });
1665
- const isValidOutput = Array.isArray(res) && res.every(
1666
- (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
1667
- );
1668
- if (!isValidOutput) {
1669
- throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
1670
- }
1671
- return res[0];
2072
+ return providerHelper.getResponse(res);
1672
2073
  }
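Editor's note: a hedged usage sketch of this task; the image Blob is base64-encoded before the request, as the code above shows. The exported function, token, and model id are assumptions/placeholders; the answer shape follows the inline validation removed just below.

```ts
import { visualQuestionAnswering } from "@huggingface/inference";

const image = await (await fetch("https://example.com/cats.jpg")).blob(); // any image Blob
const answer = await visualQuestionAnswering({
  accessToken: "hf_xxx",                     // placeholder token
  model: "dandelin/vilt-b32-finetuned-vqa",  // example model id
  inputs: { question: "How many cats are there?", image },
});
// answer: { answer: string; score: number }, extracted by the hf-inference helper
```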
1673
2074
 
- // src/tasks/tabular/tabularRegression.ts
- async function tabularRegression(args, options) {
- const { data: res } = await innerRequest(args, {
+ // src/tasks/tabular/tabularClassification.ts
+ async function tabularClassification(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification");
+ const { data: res } = await innerRequest(args, providerHelper, {
  ...options,
- task: "tabular-regression"
+ task: "tabular-classification"
  });
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected number[]");
- }
- return res;
+ return providerHelper.getResponse(res);
  }

- // src/tasks/tabular/tabularClassification.ts
- async function tabularClassification(args, options) {
- const { data: res } = await innerRequest(args, {
+ // src/tasks/tabular/tabularRegression.ts
+ async function tabularRegression(args, options) {
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression");
+ const { data: res } = await innerRequest(args, providerHelper, {
  ...options,
- task: "tabular-classification"
+ task: "tabular-regression"
  });
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
- if (!isValidOutput) {
- throw new InferenceOutputError("Expected number[]");
- }
- return res;
+ return providerHelper.getResponse(res);
  }

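A rough usage sketch for the tabular tasks above; the inputs.data layout (column name mapped to an array of stringified values) is an assumption based on the package's typings, and the model id and token are placeholders:

import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient("hf_xxx"); // placeholder token

// One predicted label (or value, for tabularRegression) is expected per row.
const predictions = await client.tabularClassification({
  model: "some-org/tabular-model", // placeholder model id
  inputs: {
    data: {
      age: ["39", "52"],
      fnlwgt: ["77516", "83311"],
    },
  },
});
console.log(predictions);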
  // src/InferenceClient.ts
@@ -1763,26 +2158,26 @@ __export(snippets_exports, {
  });

  // src/snippets/getInferenceSnippets.ts
- var import_tasks = require("@huggingface/tasks");
  var import_jinja = require("@huggingface/jinja");
+ var import_tasks = require("@huggingface/tasks");

  // src/snippets/templates.exported.ts
  var templates = {
  "js": {
  "fetch": {
- "basic": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
- "basicAudio": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "audio/flac"\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
- "basicImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "image/jpeg"\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
- "textToAudio": '{% if model.library_name == "transformers" %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n // Returns a byte object of the Audio wavform. Use it directly!\n});\n{% else %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});\n{% endif %} ',
- "textToImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\n\nquery({ {{ providerInputs.asTsString }} }).then((response) => {\n // Use image\n});',
- "zeroShotClassification": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({\n inputs: {{ providerInputs.asObj.inputs }},\n parameters: { candidate_labels: ["refund", "legal", "faq"] }\n}).then((response) => {\n console.log(JSON.stringify(response));\n});'
+ "basic": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
+ "basicAudio": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "audio/flac",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
+ "basicImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "image/jpeg",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
+ "textToAudio": '{% if model.library_name == "transformers" %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n // Returns a byte object of the Audio wavform. Use it directly!\n});\n{% else %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});\n{% endif %} ',
+ "textToImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\n\nquery({ {{ providerInputs.asTsString }} }).then((response) => {\n // Use image\n});',
+ "zeroShotClassification": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({\n inputs: {{ providerInputs.asObj.inputs }},\n parameters: { candidate_labels: ["refund", "legal", "faq"] }\n}).then((response) => {\n console.log(JSON.stringify(response));\n});'
  },
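Each fetch template above now carries an optional {% if billTo %} block. When billTo is provided, the rendered "basic" snippet gains an X-HF-Bill-To header; the rendered output would look roughly as follows (URL, token, and organization name are placeholders):

async function query(data) {
  const response = await fetch(
    "https://router.huggingface.co/hf-inference/models/example-model", // placeholder URL
    {
      headers: {
        Authorization: "Bearer hf_xxx",       // placeholder token
        "Content-Type": "application/json",
        "X-HF-Bill-To": "my-org",             // present only when billTo is set
      },
      method: "POST",
      body: JSON.stringify(data),
    }
  );
  return await response.json();
}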
  "huggingface.js": {
- "basic": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst output = await client.{{ methodName }}({\n model: "{{ model.id }}",\n inputs: {{ inputs.asObj.inputs }},\n provider: "{{ provider }}",\n});\n\nconsole.log(output);',
- "basicAudio": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst data = fs.readFileSync({{inputs.asObj.inputs}});\n\nconst output = await client.{{ methodName }}({\n data,\n model: "{{ model.id }}",\n provider: "{{ provider }}",\n});\n\nconsole.log(output);',
- "basicImage": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst data = fs.readFileSync({{inputs.asObj.inputs}});\n\nconst output = await client.{{ methodName }}({\n data,\n model: "{{ model.id }}",\n provider: "{{ provider }}",\n});\n\nconsole.log(output);',
- "conversational": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst chatCompletion = await client.chatCompletion({\n provider: "{{ provider }}",\n model: "{{ model.id }}",\n{{ inputs.asTsString }}\n});\n\nconsole.log(chatCompletion.choices[0].message);',
- "conversationalStream": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nlet out = "";\n\nconst stream = await client.chatCompletionStream({\n provider: "{{ provider }}",\n model: "{{ model.id }}",\n{{ inputs.asTsString }}\n});\n\nfor await (const chunk of stream) {\n if (chunk.choices && chunk.choices.length > 0) {\n const newContent = chunk.choices[0].delta.content;\n out += newContent;\n console.log(newContent);\n } \n}',
+ "basic": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst output = await client.{{ methodName }}({\n model: "{{ model.id }}",\n inputs: {{ inputs.asObj.inputs }},\n provider: "{{ provider }}",\n}{% if billTo %}, {\n billTo: "{{ billTo }}",\n}{% endif %});\n\nconsole.log(output);',
+ "basicAudio": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst data = fs.readFileSync({{inputs.asObj.inputs}});\n\nconst output = await client.{{ methodName }}({\n data,\n model: "{{ model.id }}",\n provider: "{{ provider }}",\n}{% if billTo %}, {\n billTo: "{{ billTo }}",\n}{% endif %});\n\nconsole.log(output);',
+ "basicImage": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst data = fs.readFileSync({{inputs.asObj.inputs}});\n\nconst output = await client.{{ methodName }}({\n data,\n model: "{{ model.id }}",\n provider: "{{ provider }}",\n}{% if billTo %}, {\n billTo: "{{ billTo }}",\n}{% endif %});\n\nconsole.log(output);',
+ "conversational": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst chatCompletion = await client.chatCompletion({\n provider: "{{ provider }}",\n model: "{{ model.id }}",\n{{ inputs.asTsString }}\n}{% if billTo %}, {\n billTo: "{{ billTo }}",\n}{% endif %});\n\nconsole.log(chatCompletion.choices[0].message);',
+ "conversationalStream": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nlet out = "";\n\nconst stream = await client.chatCompletionStream({\n provider: "{{ provider }}",\n model: "{{ model.id }}",\n{{ inputs.asTsString }}\n}{% if billTo %}, {\n billTo: "{{ billTo }}",\n}{% endif %});\n\nfor await (const chunk of stream) {\n if (chunk.choices && chunk.choices.length > 0) {\n const newContent = chunk.choices[0].delta.content;\n out += newContent;\n console.log(newContent);\n } \n}',
  "textToImage": `import { InferenceClient } from "@huggingface/inference";

  const client = new InferenceClient("{{ accessToken }}");
@@ -1792,7 +2187,9 @@ const image = await client.textToImage({
  model: "{{ model.id }}",
  inputs: {{ inputs.asObj.inputs }},
  parameters: { num_inference_steps: 5 },
- });
+ }{% if billTo %}, {
+ billTo: "{{ billTo }}",
+ }{% endif %});
  /// Use the generated image (it's a Blob)`,
  "textToVideo": `import { InferenceClient } from "@huggingface/inference";

@@ -1802,12 +2199,14 @@ const image = await client.textToVideo({
  provider: "{{ provider }}",
  model: "{{ model.id }}",
  inputs: {{ inputs.asObj.inputs }},
- });
+ }{% if billTo %}, {
+ billTo: "{{ billTo }}",
+ }{% endif %});
  // Use the generated video (it's a Blob)`
  },
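The huggingface.js templates now emit the billTo option as a second argument to the client call. In user code that corresponds to something like the following (model id and organization are placeholders):

import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient("hf_xxx"); // placeholder token

const chatCompletion = await client.chatCompletion(
  {
    provider: "hf-inference",
    model: "meta-llama/Llama-3.1-8B-Instruct", // example model
    messages: [{ role: "user", content: "Hello!" }],
  },
  { billTo: "my-org" } // billed to the organization instead of the token owner
);
console.log(chatCompletion.choices[0].message);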
  "openai": {
- "conversational": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\nconst chatCompletion = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n});\n\nconsole.log(chatCompletion.choices[0].message);',
- "conversationalStream": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\nlet out = "";\n\nconst stream = await client.chat.completions.create({\n provider: "{{ provider }}",\n model: "{{ model.id }}",\n{{ inputs.asTsString }}\n});\n\nfor await (const chunk of stream) {\n if (chunk.choices && chunk.choices.length > 0) {\n const newContent = chunk.choices[0].delta.content;\n out += newContent;\n console.log(newContent);\n } \n}'
+ "conversational": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n{% if billTo %}\n defaultHeaders: {\n "X-HF-Bill-To": "{{ billTo }}" \n }\n{% endif %}\n});\n\nconst chatCompletion = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n});\n\nconsole.log(chatCompletion.choices[0].message);',
+ "conversationalStream": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n{% if billTo %}\n defaultHeaders: {\n "X-HF-Bill-To": "{{ billTo }}" \n }\n{% endif %}\n});\n\nconst stream = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n stream: true,\n});\n\nfor await (const chunk of stream) {\n process.stdout.write(chunk.choices[0]?.delta?.content || "");\n}'
  }
  },
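The OpenAI-compatible JS templates route billTo through defaultHeaders instead, since the openai client has no billTo option; a rendered snippet would start roughly like this (base URL, token, and organization are placeholders):

import { OpenAI } from "openai";

const client = new OpenAI({
  baseURL: "https://router.huggingface.co/v1", // placeholder base URL
  apiKey: "hf_xxx",                            // placeholder token
  defaultHeaders: {
    "X-HF-Bill-To": "my-org",                  // emitted only when billTo is set
  },
});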
  "python": {
@@ -1822,13 +2221,13 @@ const image = await client.textToVideo({
  "conversationalStream": 'stream = client.chat.completions.create(\n model="{{ model.id }}",\n{{ inputs.asPythonString }}\n stream=True,\n)\n\nfor chunk in stream:\n print(chunk.choices[0].delta.content, end="") ',
  "documentQuestionAnswering": 'output = client.document_question_answering(\n "{{ inputs.asObj.image }}",\n question="{{ inputs.asObj.question }}",\n model="{{ model.id }}",\n) ',
  "imageToImage": '# output is a PIL.Image object\nimage = client.image_to_image(\n "{{ inputs.asObj.inputs }}",\n prompt="{{ inputs.asObj.parameters.prompt }}",\n model="{{ model.id }}",\n) ',
- "importInferenceClient": 'from huggingface_hub import InferenceClient\n\nclient = InferenceClient(\n provider="{{ provider }}",\n api_key="{{ accessToken }}",\n)',
+ "importInferenceClient": 'from huggingface_hub import InferenceClient\n\nclient = InferenceClient(\n provider="{{ provider }}",\n api_key="{{ accessToken }}",\n{% if billTo %}\n bill_to="{{ billTo }}",\n{% endif %}\n)',
  "textToImage": '# output is a PIL.Image object\nimage = client.text_to_image(\n {{ inputs.asObj.inputs }},\n model="{{ model.id }}",\n) ',
  "textToVideo": 'video = client.text_to_video(\n {{ inputs.asObj.inputs }},\n model="{{ model.id }}",\n) '
  },
  "openai": {
- "conversational": 'from openai import OpenAI\n\nclient = OpenAI(\n base_url="{{ baseUrl }}",\n api_key="{{ accessToken }}"\n)\n\ncompletion = client.chat.completions.create(\n model="{{ providerModelId }}",\n{{ inputs.asPythonString }}\n)\n\nprint(completion.choices[0].message) ',
- "conversationalStream": 'from openai import OpenAI\n\nclient = OpenAI(\n base_url="{{ baseUrl }}",\n api_key="{{ accessToken }}"\n)\n\nstream = client.chat.completions.create(\n model="{{ providerModelId }}",\n{{ inputs.asPythonString }}\n stream=True,\n)\n\nfor chunk in stream:\n print(chunk.choices[0].delta.content, end="")'
+ "conversational": 'from openai import OpenAI\n\nclient = OpenAI(\n base_url="{{ baseUrl }}",\n api_key="{{ accessToken }}",\n{% if billTo %}\n default_headers={\n "X-HF-Bill-To": "{{ billTo }}"\n }\n{% endif %}\n)\n\ncompletion = client.chat.completions.create(\n model="{{ providerModelId }}",\n{{ inputs.asPythonString }}\n)\n\nprint(completion.choices[0].message) ',
+ "conversationalStream": 'from openai import OpenAI\n\nclient = OpenAI(\n base_url="{{ baseUrl }}",\n api_key="{{ accessToken }}",\n{% if billTo %}\n default_headers={\n "X-HF-Bill-To": "{{ billTo }}"\n }\n{% endif %}\n)\n\nstream = client.chat.completions.create(\n model="{{ providerModelId }}",\n{{ inputs.asPythonString }}\n stream=True,\n)\n\nfor chunk in stream:\n print(chunk.choices[0].delta.content, end="")'
  },
  "requests": {
  "basic": 'def query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\noutput = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n}) ',
@@ -1838,7 +2237,7 @@ const image = await client.textToVideo({
  "conversationalStream": 'def query(payload):\n response = requests.post(API_URL, headers=headers, json=payload, stream=True)\n for line in response.iter_lines():\n if not line.startswith(b"data:"):\n continue\n if line.strip() == b"data: [DONE]":\n return\n yield json.loads(line.decode("utf-8").lstrip("data:").rstrip("/n"))\n\nchunks = query({\n{{ providerInputs.asJsonString }},\n "stream": True,\n})\n\nfor chunk in chunks:\n print(chunk["choices"][0]["delta"]["content"], end="")',
  "documentQuestionAnswering": 'def query(payload):\n with open(payload["image"], "rb") as f:\n img = f.read()\n payload["image"] = base64.b64encode(img).decode("utf-8")\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\noutput = query({\n "inputs": {\n "image": "{{ inputs.asObj.image }}",\n "question": "{{ inputs.asObj.question }}",\n },\n}) ',
  "imageToImage": 'def query(payload):\n with open(payload["inputs"], "rb") as f:\n img = f.read()\n payload["inputs"] = base64.b64encode(img).decode("utf-8")\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nimage_bytes = query({\n{{ providerInputs.asJsonString }}\n})\n\n# You can access the image with PIL.Image for example\nimport io\nfrom PIL import Image\nimage = Image.open(io.BytesIO(image_bytes)) ',
- "importRequests": '{% if importBase64 %}\nimport base64\n{% endif %}\n{% if importJson %}\nimport json\n{% endif %}\nimport requests\n\nAPI_URL = "{{ fullUrl }}"\nheaders = {"Authorization": "{{ authorizationHeader }}"}',
+ "importRequests": '{% if importBase64 %}\nimport base64\n{% endif %}\n{% if importJson %}\nimport json\n{% endif %}\nimport requests\n\nAPI_URL = "{{ fullUrl }}"\nheaders = {\n "Authorization": "{{ authorizationHeader }}",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}"\n{% endif %}\n}',
  "tabular": 'def query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nresponse = query({\n "inputs": {\n "data": {{ providerInputs.asObj.inputs }}\n },\n}) ',
  "textToAudio": '{% if model.library_name == "transformers" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\naudio_bytes = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio_bytes)\n{% else %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\naudio, sampling_rate = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio, rate=sampling_rate)\n{% endif %} ',
  "textToImage": '{% if provider == "hf-inference" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nimage_bytes = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n})\n\n# You can access the image with PIL.Image for example\nimport io\nfrom PIL import Image\nimage = Image.open(io.BytesIO(image_bytes))\n{% endif %}',
@@ -1848,12 +2247,15 @@ const image = await client.textToVideo({
  },
  "sh": {
  "curl": {
- "basic": "curl {{ fullUrl }} \\\n -X POST \\\n -H 'Authorization: {{ authorizationHeader }}' \\\n -H 'Content-Type: application/json' \\\n -d '{\n{{ providerInputs.asCurlString }}\n }'",
- "basicAudio": "curl {{ fullUrl }} \\\n -X POST \\\n -H 'Authorization: {{ authorizationHeader }}' \\\n -H 'Content-Type: audio/flac' \\\n --data-binary @{{ providerInputs.asObj.inputs }}",
- "basicImage": "curl {{ fullUrl }} \\\n -X POST \\\n -H 'Authorization: {{ authorizationHeader }}' \\\n -H 'Content-Type: image/jpeg' \\\n --data-binary @{{ providerInputs.asObj.inputs }}",
+ "basic": "curl {{ fullUrl }} \\\n -X POST \\\n -H 'Authorization: {{ authorizationHeader }}' \\\n -H 'Content-Type: application/json' \\\n{% if billTo %}\n -H 'X-HF-Bill-To: {{ billTo }}' \\\n{% endif %}\n -d '{\n{{ providerInputs.asCurlString }}\n }'",
+ "basicAudio": "curl {{ fullUrl }} \\\n -X POST \\\n -H 'Authorization: {{ authorizationHeader }}' \\\n -H 'Content-Type: audio/flac' \\\n{% if billTo %}\n -H 'X-HF-Bill-To: {{ billTo }}' \\\n{% endif %}\n --data-binary @{{ providerInputs.asObj.inputs }}",
+ "basicImage": "curl {{ fullUrl }} \\\n -X POST \\\n -H 'Authorization: {{ authorizationHeader }}' \\\n -H 'Content-Type: image/jpeg' \\\n{% if billTo %}\n -H 'X-HF-Bill-To: {{ billTo }}' \\\n{% endif %}\n --data-binary @{{ providerInputs.asObj.inputs }}",
  "conversational": `curl {{ fullUrl }} \\
  -H 'Authorization: {{ authorizationHeader }}' \\
  -H 'Content-Type: application/json' \\
+ {% if billTo %}
+ -H 'X-HF-Bill-To: {{ billTo }}' \\
+ {% endif %}
  -d '{
  {{ providerInputs.asCurlString }},
  "stream": false
@@ -1861,6 +2263,9 @@ const image = await client.textToVideo({
  "conversationalStream": `curl {{ fullUrl }} \\
  -H 'Authorization: {{ authorizationHeader }}' \\
  -H 'Content-Type: application/json' \\
+ {% if billTo %}
+ -H 'X-HF-Bill-To: {{ billTo }}' \\
+ {% endif %}
  -d '{
  {{ providerInputs.asCurlString }},
  "stream": true
@@ -1869,7 +2274,10 @@ const image = await client.textToVideo({
  -X POST \\
  -d '{"inputs": {{ providerInputs.asObj.inputs }}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\
  -H 'Content-Type: application/json' \\
- -H 'Authorization: {{ authorizationHeader }}'`
+ -H 'Authorization: {{ authorizationHeader }}'
+ {% if billTo %} \\
+ -H 'X-HF-Bill-To: {{ billTo }}'
+ {% endif %}`
  }
  }
  };
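These templates are Jinja strings rendered at snippet-generation time via @huggingface/jinja (imported above as import_jinja). A sketch of rendering one of them in isolation, with an invented context; all values below are placeholders:

import { Template } from "@huggingface/jinja";

// Render the curl "basicAudio" template with placeholder values.
const rendered = new Template(templates.sh.curl.basicAudio).render({
  fullUrl: "https://example.test/models/some-model", // placeholder
  authorizationHeader: "Bearer hf_xxx",              // placeholder
  billTo: "my-org",                                  // enables the X-HF-Bill-To header line
  providerInputs: { asObj: { inputs: "sample.flac" } },
});
console.log(rendered);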
@@ -1938,16 +2346,35 @@ var HF_JS_METHODS = {
  translation: "translation"
  };
  var snippetGenerator = (templateName, inputPreparationFn) => {
- return (model, accessToken, provider, providerModelId, opts) => {
+ return (model, accessToken, provider, inferenceProviderMapping, billTo, opts) => {
+ const providerModelId = inferenceProviderMapping?.providerId ?? model.id;
+ let task = model.pipeline_tag;
  if (model.pipeline_tag && ["text-generation", "image-text-to-text"].includes(model.pipeline_tag) && model.tags.includes("conversational")) {
  templateName = opts?.streaming ? "conversationalStream" : "conversational";
  inputPreparationFn = prepareConversationalInput;
+ task = "conversational";
+ }
+ let providerHelper;
+ try {
+ providerHelper = getProviderHelper(provider, task);
+ } catch (e) {
+ console.error(`Failed to get provider helper for ${provider} (${task})`, e);
+ return [];
  }
  const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: (0, import_tasks.getModelInputSnippet)(model) };
  const request2 = makeRequestOptionsFromResolvedModel(
- providerModelId ?? model.id,
- { accessToken, provider, ...inputs },
- { chatCompletion: templateName.includes("conversational"), task: model.pipeline_tag }
+ providerModelId,
+ providerHelper,
+ {
+ accessToken,
+ provider,
+ ...inputs
+ },
+ inferenceProviderMapping,
+ {
+ task,
+ billTo
+ }
  );
  let providerInputs = inputs;
  const bodyAsObj = request2.info.body;
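Snippet generators now receive an inferenceProviderMapping object and read its providerId to resolve the provider-side model id, falling back to model.id. The mapping presumably looks roughly like this; every field other than providerId is an assumption for illustration:

// Illustrative mapping object as consumed by the generator above.
const exampleProviderMapping = {
  providerId: "accounts/fireworks/models/llama-v3p1-8b-instruct", // provider-side id (example)
  hfModelId: "meta-llama/Llama-3.1-8B-Instruct",                  // assumed field
  status: "live",                                                 // assumed field
  task: "conversational",                                         // assumed field
};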
@@ -1979,7 +2406,8 @@ var snippetGenerator = (templateName, inputPreparationFn) => {
  },
  model,
  provider,
- providerModelId: providerModelId ?? model.id
+ providerModelId: providerModelId ?? model.id,
+ billTo
  };
  return import_tasks.inferenceSnippetLanguages.map((language) => {
  return CLIENTS[language].map((client) => {
@@ -2034,7 +2462,7 @@ var prepareConversationalInput = (model, opts) => {
  return {
  messages: opts?.messages ?? (0, import_tasks.getModelInputSnippet)(model),
  ...opts?.temperature ? { temperature: opts?.temperature } : void 0,
- max_tokens: opts?.max_tokens ?? 500,
+ max_tokens: opts?.max_tokens ?? 512,
  ...opts?.top_p ? { top_p: opts?.top_p } : void 0
  };
  };
@@ -2069,8 +2497,8 @@ var snippets = {
  "zero-shot-classification": snippetGenerator("zeroShotClassification"),
  "zero-shot-image-classification": snippetGenerator("zeroShotImageClassification")
  };
- function getInferenceSnippets(model, accessToken, provider, providerModelId, opts) {
- return model.pipeline_tag && model.pipeline_tag in snippets ? snippets[model.pipeline_tag]?.(model, accessToken, provider, providerModelId, opts) ?? [] : [];
+ function getInferenceSnippets(model, accessToken, provider, inferenceProviderMapping, billTo, opts) {
+ return model.pipeline_tag && model.pipeline_tag in snippets ? snippets[model.pipeline_tag]?.(model, accessToken, provider, inferenceProviderMapping, billTo, opts) ?? [] : [];
  }
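With the widened signature, callers pass the provider mapping and an optional billTo before the options bag. A hedged usage sketch; the model metadata object is abbreviated and all values are placeholders:

import { snippets } from "@huggingface/inference";

// Minimal model info shape; real callers pass the full model data from the Hub API.
const model = {
  id: "meta-llama/Llama-3.1-8B-Instruct",
  pipeline_tag: "text-generation",
  tags: ["conversational"],
  library_name: "transformers",
  inference: "",
};

const generated = snippets.getInferenceSnippets(
  model,
  "hf_xxx",                                                           // access token placeholder
  "fireworks-ai",                                                     // provider
  { providerId: "accounts/fireworks/models/llama-v3p1-8b-instruct" }, // provider mapping (simplified)
  "my-org"                                                            // optional billTo
);
console.log(generated.map((s) => s.language));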
  function formatBody(obj, format) {
  switch (format) {