@huggingface/inference 3.6.2 → 3.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. package/README.md +0 -25
  2. package/dist/index.cjs +1232 -898
  3. package/dist/index.js +1234 -900
  4. package/dist/src/config.d.ts +1 -0
  5. package/dist/src/config.d.ts.map +1 -1
  6. package/dist/src/lib/getProviderHelper.d.ts +37 -0
  7. package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
  8. package/dist/src/lib/makeRequestOptions.d.ts +0 -2
  9. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  10. package/dist/src/providers/black-forest-labs.d.ts +14 -18
  11. package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
  12. package/dist/src/providers/cerebras.d.ts +4 -2
  13. package/dist/src/providers/cerebras.d.ts.map +1 -1
  14. package/dist/src/providers/cohere.d.ts +5 -2
  15. package/dist/src/providers/cohere.d.ts.map +1 -1
  16. package/dist/src/providers/fal-ai.d.ts +50 -3
  17. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  18. package/dist/src/providers/fireworks-ai.d.ts +5 -2
  19. package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
  20. package/dist/src/providers/hf-inference.d.ts +125 -2
  21. package/dist/src/providers/hf-inference.d.ts.map +1 -1
  22. package/dist/src/providers/hyperbolic.d.ts +31 -2
  23. package/dist/src/providers/hyperbolic.d.ts.map +1 -1
  24. package/dist/src/providers/nebius.d.ts +20 -18
  25. package/dist/src/providers/nebius.d.ts.map +1 -1
  26. package/dist/src/providers/novita.d.ts +21 -18
  27. package/dist/src/providers/novita.d.ts.map +1 -1
  28. package/dist/src/providers/openai.d.ts +4 -2
  29. package/dist/src/providers/openai.d.ts.map +1 -1
  30. package/dist/src/providers/providerHelper.d.ts +182 -0
  31. package/dist/src/providers/providerHelper.d.ts.map +1 -0
  32. package/dist/src/providers/replicate.d.ts +23 -19
  33. package/dist/src/providers/replicate.d.ts.map +1 -1
  34. package/dist/src/providers/sambanova.d.ts +4 -2
  35. package/dist/src/providers/sambanova.d.ts.map +1 -1
  36. package/dist/src/providers/together.d.ts +32 -2
  37. package/dist/src/providers/together.d.ts.map +1 -1
  38. package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
  39. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
  40. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  41. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  42. package/dist/src/tasks/audio/utils.d.ts +2 -1
  43. package/dist/src/tasks/audio/utils.d.ts.map +1 -1
  44. package/dist/src/tasks/custom/request.d.ts +1 -2
  45. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  46. package/dist/src/tasks/custom/streamingRequest.d.ts +1 -2
  47. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  48. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
  49. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
  50. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
  51. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
  52. package/dist/src/tasks/cv/objectDetection.d.ts +1 -1
  53. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
  54. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  55. package/dist/src/tasks/cv/textToVideo.d.ts +1 -1
  56. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
  57. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts +1 -1
  58. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
  59. package/dist/src/tasks/index.d.ts +6 -6
  60. package/dist/src/tasks/index.d.ts.map +1 -1
  61. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts +1 -1
  62. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  63. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
  64. package/dist/src/tasks/nlp/chatCompletion.d.ts +1 -1
  65. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  66. package/dist/src/tasks/nlp/chatCompletionStream.d.ts +1 -1
  67. package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
  68. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  69. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
  70. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  71. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  72. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
  73. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  74. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
  75. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  76. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  77. package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
  78. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  79. package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
  80. package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
  81. package/dist/src/types.d.ts +10 -13
  82. package/dist/src/types.d.ts.map +1 -1
  83. package/dist/src/utils/request.d.ts +27 -0
  84. package/dist/src/utils/request.d.ts.map +1 -0
  85. package/package.json +3 -3
  86. package/src/config.ts +1 -0
  87. package/src/lib/getProviderHelper.ts +270 -0
  88. package/src/lib/makeRequestOptions.ts +36 -90
  89. package/src/providers/black-forest-labs.ts +73 -22
  90. package/src/providers/cerebras.ts +6 -27
  91. package/src/providers/cohere.ts +9 -28
  92. package/src/providers/fal-ai.ts +195 -77
  93. package/src/providers/fireworks-ai.ts +8 -29
  94. package/src/providers/hf-inference.ts +555 -34
  95. package/src/providers/hyperbolic.ts +107 -29
  96. package/src/providers/nebius.ts +65 -29
  97. package/src/providers/novita.ts +68 -32
  98. package/src/providers/openai.ts +6 -32
  99. package/src/providers/providerHelper.ts +354 -0
  100. package/src/providers/replicate.ts +124 -34
  101. package/src/providers/sambanova.ts +5 -30
  102. package/src/providers/together.ts +92 -28
  103. package/src/snippets/getInferenceSnippets.ts +16 -9
  104. package/src/snippets/templates.exported.ts +2 -2
  105. package/src/tasks/audio/audioClassification.ts +6 -9
  106. package/src/tasks/audio/audioToAudio.ts +5 -28
  107. package/src/tasks/audio/automaticSpeechRecognition.ts +7 -6
  108. package/src/tasks/audio/textToSpeech.ts +6 -30
  109. package/src/tasks/audio/utils.ts +2 -1
  110. package/src/tasks/custom/request.ts +7 -34
  111. package/src/tasks/custom/streamingRequest.ts +5 -87
  112. package/src/tasks/cv/imageClassification.ts +5 -9
  113. package/src/tasks/cv/imageSegmentation.ts +5 -10
  114. package/src/tasks/cv/imageToImage.ts +5 -8
  115. package/src/tasks/cv/imageToText.ts +8 -13
  116. package/src/tasks/cv/objectDetection.ts +6 -21
  117. package/src/tasks/cv/textToImage.ts +10 -138
  118. package/src/tasks/cv/textToVideo.ts +11 -59
  119. package/src/tasks/cv/zeroShotImageClassification.ts +7 -12
  120. package/src/tasks/index.ts +6 -6
  121. package/src/tasks/multimodal/documentQuestionAnswering.ts +10 -26
  122. package/src/tasks/multimodal/visualQuestionAnswering.ts +6 -12
  123. package/src/tasks/nlp/chatCompletion.ts +7 -23
  124. package/src/tasks/nlp/chatCompletionStream.ts +4 -5
  125. package/src/tasks/nlp/featureExtraction.ts +5 -20
  126. package/src/tasks/nlp/fillMask.ts +5 -18
  127. package/src/tasks/nlp/questionAnswering.ts +5 -23
  128. package/src/tasks/nlp/sentenceSimilarity.ts +5 -18
  129. package/src/tasks/nlp/summarization.ts +5 -8
  130. package/src/tasks/nlp/tableQuestionAnswering.ts +5 -29
  131. package/src/tasks/nlp/textClassification.ts +8 -14
  132. package/src/tasks/nlp/textGeneration.ts +13 -80
  133. package/src/tasks/nlp/textGenerationStream.ts +2 -2
  134. package/src/tasks/nlp/tokenClassification.ts +8 -24
  135. package/src/tasks/nlp/translation.ts +5 -8
  136. package/src/tasks/nlp/zeroShotClassification.ts +8 -22
  137. package/src/tasks/tabular/tabularClassification.ts +5 -8
  138. package/src/tasks/tabular/tabularRegression.ts +5 -8
  139. package/src/types.ts +11 -14
  140. package/src/utils/request.ts +161 -0
package/dist/index.js CHANGED
@@ -41,90 +41,215 @@ __export(tasks_exports, {
41
41
  zeroShotImageClassification: () => zeroShotImageClassification
42
42
  });
43
43
 
44
+ // package.json
45
+ var name = "@huggingface/inference";
46
+ var version = "3.7.1";
47
+
44
48
  // src/config.ts
45
49
  var HF_HUB_URL = "https://huggingface.co";
46
50
  var HF_ROUTER_URL = "https://router.huggingface.co";
51
+ var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
47
52
 
48
- // src/providers/black-forest-labs.ts
49
- var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
50
- var makeBaseUrl = () => {
51
- return BLACK_FOREST_LABS_AI_API_BASE_URL;
53
+ // src/lib/InferenceOutputError.ts
54
+ var InferenceOutputError = class extends TypeError {
55
+ constructor(message) {
56
+ super(
57
+ `Invalid inference output: ${message}. Use the 'request' method with the same parameters to do a custom call with no type checking.`
58
+ );
59
+ this.name = "InferenceOutputError";
60
+ }
52
61
  };
53
- var makeBody = (params) => {
54
- return params.args;
62
+
63
+ // src/utils/delay.ts
64
+ function delay(ms) {
65
+ return new Promise((resolve) => {
66
+ setTimeout(() => resolve(), ms);
67
+ });
68
+ }
69
+
70
+ // src/utils/pick.ts
71
+ function pick(o, props) {
72
+ return Object.assign(
73
+ {},
74
+ ...props.map((prop) => {
75
+ if (o[prop] !== void 0) {
76
+ return { [prop]: o[prop] };
77
+ }
78
+ })
79
+ );
80
+ }
81
+
82
+ // src/utils/typedInclude.ts
83
+ function typedInclude(arr, v) {
84
+ return arr.includes(v);
85
+ }
86
+
87
+ // src/utils/omit.ts
88
+ function omit(o, props) {
89
+ const propsArr = Array.isArray(props) ? props : [props];
90
+ const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
91
+ return pick(o, letsKeep);
92
+ }
93
+
94
+ // src/utils/toArray.ts
95
+ function toArray(obj) {
96
+ if (Array.isArray(obj)) {
97
+ return obj;
98
+ }
99
+ return [obj];
100
+ }
101
+
102
+ // src/providers/providerHelper.ts
103
+ var TaskProviderHelper = class {
104
+ constructor(provider, baseUrl, clientSideRoutingOnly = false) {
105
+ this.provider = provider;
106
+ this.baseUrl = baseUrl;
107
+ this.clientSideRoutingOnly = clientSideRoutingOnly;
108
+ }
109
+ /**
110
+ * Prepare the base URL for the request
111
+ */
112
+ makeBaseUrl(params) {
113
+ return params.authMethod !== "provider-key" ? `${HF_ROUTER_URL}/${this.provider}` : this.baseUrl;
114
+ }
115
+ /**
116
+ * Prepare the body for the request
117
+ */
118
+ makeBody(params) {
119
+ if ("data" in params.args && !!params.args.data) {
120
+ return params.args.data;
121
+ }
122
+ return JSON.stringify(this.preparePayload(params));
123
+ }
124
+ /**
125
+ * Prepare the URL for the request
126
+ */
127
+ makeUrl(params) {
128
+ const baseUrl = this.makeBaseUrl(params);
129
+ const route = this.makeRoute(params).replace(/^\/+/, "");
130
+ return `${baseUrl}/${route}`;
131
+ }
132
+ /**
133
+ * Prepare the headers for the request
134
+ */
135
+ prepareHeaders(params, isBinary) {
136
+ const headers = { Authorization: `Bearer ${params.accessToken}` };
137
+ if (!isBinary) {
138
+ headers["Content-Type"] = "application/json";
139
+ }
140
+ return headers;
141
+ }
55
142
  };
56
- var makeHeaders = (params) => {
57
- if (params.authMethod === "provider-key") {
58
- return { "X-Key": `${params.accessToken}` };
59
- } else {
60
- return { Authorization: `Bearer ${params.accessToken}` };
143
+ var BaseConversationalTask = class extends TaskProviderHelper {
144
+ constructor(provider, baseUrl, clientSideRoutingOnly = false) {
145
+ super(provider, baseUrl, clientSideRoutingOnly);
146
+ }
147
+ makeRoute() {
148
+ return "v1/chat/completions";
149
+ }
150
+ preparePayload(params) {
151
+ return {
152
+ ...params.args,
153
+ model: params.model
154
+ };
155
+ }
156
+ async getResponse(response) {
157
+ if (typeof response === "object" && Array.isArray(response?.choices) && typeof response?.created === "number" && typeof response?.id === "string" && typeof response?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
158
+ (response.system_fingerprint === void 0 || response.system_fingerprint === null || typeof response.system_fingerprint === "string") && typeof response?.usage === "object") {
159
+ return response;
160
+ }
161
+ throw new InferenceOutputError("Expected ChatCompletionOutput");
61
162
  }
62
163
  };
63
- var makeUrl = (params) => {
64
- return `${params.baseUrl}/v1/${params.model}`;
164
+ var BaseTextGenerationTask = class extends TaskProviderHelper {
165
+ constructor(provider, baseUrl, clientSideRoutingOnly = false) {
166
+ super(provider, baseUrl, clientSideRoutingOnly);
167
+ }
168
+ preparePayload(params) {
169
+ return {
170
+ ...params.args,
171
+ model: params.model
172
+ };
173
+ }
174
+ makeRoute() {
175
+ return "v1/completions";
176
+ }
177
+ async getResponse(response) {
178
+ const res = toArray(response);
179
+ if (Array.isArray(res) && res.length > 0 && res.every(
180
+ (x) => typeof x === "object" && !!x && "generated_text" in x && typeof x.generated_text === "string"
181
+ )) {
182
+ return res[0];
183
+ }
184
+ throw new InferenceOutputError("Expected Array<{generated_text: string}>");
185
+ }
65
186
  };
66
- var BLACK_FOREST_LABS_CONFIG = {
67
- makeBaseUrl,
68
- makeBody,
69
- makeHeaders,
70
- makeUrl
187
+
188
+ // src/providers/black-forest-labs.ts
189
+ var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
190
+ var BlackForestLabsTextToImageTask = class extends TaskProviderHelper {
191
+ constructor() {
192
+ super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
193
+ }
194
+ preparePayload(params) {
195
+ return {
196
+ ...omit(params.args, ["inputs", "parameters"]),
197
+ ...params.args.parameters,
198
+ prompt: params.args.inputs
199
+ };
200
+ }
201
+ prepareHeaders(params, binary) {
202
+ const headers = {
203
+ Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`
204
+ };
205
+ if (!binary) {
206
+ headers["Content-Type"] = "application/json";
207
+ }
208
+ return headers;
209
+ }
210
+ makeRoute(params) {
211
+ if (!params) {
212
+ throw new Error("Params are required");
213
+ }
214
+ return `/v1/${params.model}`;
215
+ }
216
+ async getResponse(response, url, headers, outputType) {
217
+ const urlObj = new URL(response.polling_url);
218
+ for (let step = 0; step < 5; step++) {
219
+ await delay(1e3);
220
+ console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
221
+ urlObj.searchParams.set("attempt", step.toString(10));
222
+ const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
223
+ if (!resp.ok) {
224
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
225
+ }
226
+ const payload = await resp.json();
227
+ if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
228
+ if (outputType === "url") {
229
+ return payload.result.sample;
230
+ }
231
+ const image = await fetch(payload.result.sample);
232
+ return await image.blob();
233
+ }
234
+ }
235
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
236
+ }
71
237
  };
72
238
 
73
239
  // src/providers/cerebras.ts
74
- var CEREBRAS_API_BASE_URL = "https://api.cerebras.ai";
75
- var makeBaseUrl2 = () => {
76
- return CEREBRAS_API_BASE_URL;
77
- };
78
- var makeBody2 = (params) => {
79
- return {
80
- ...params.args,
81
- model: params.model
82
- };
83
- };
84
- var makeHeaders2 = (params) => {
85
- return { Authorization: `Bearer ${params.accessToken}` };
86
- };
87
- var makeUrl2 = (params) => {
88
- return `${params.baseUrl}/v1/chat/completions`;
89
- };
90
- var CEREBRAS_CONFIG = {
91
- makeBaseUrl: makeBaseUrl2,
92
- makeBody: makeBody2,
93
- makeHeaders: makeHeaders2,
94
- makeUrl: makeUrl2
240
+ var CerebrasConversationalTask = class extends BaseConversationalTask {
241
+ constructor() {
242
+ super("cerebras", "https://api.cerebras.ai");
243
+ }
95
244
  };
96
245
 
97
246
  // src/providers/cohere.ts
98
- var COHERE_API_BASE_URL = "https://api.cohere.com";
99
- var makeBaseUrl3 = () => {
100
- return COHERE_API_BASE_URL;
101
- };
102
- var makeBody3 = (params) => {
103
- return {
104
- ...params.args,
105
- model: params.model
106
- };
107
- };
108
- var makeHeaders3 = (params) => {
109
- return { Authorization: `Bearer ${params.accessToken}` };
110
- };
111
- var makeUrl3 = (params) => {
112
- return `${params.baseUrl}/compatibility/v1/chat/completions`;
113
- };
114
- var COHERE_CONFIG = {
115
- makeBaseUrl: makeBaseUrl3,
116
- makeBody: makeBody3,
117
- makeHeaders: makeHeaders3,
118
- makeUrl: makeUrl3
119
- };
120
-
121
- // src/lib/InferenceOutputError.ts
122
- var InferenceOutputError = class extends TypeError {
123
- constructor(message) {
124
- super(
125
- `Invalid inference output: ${message}. Use the 'request' method with the same parameters to do a custom call with no type checking.`
126
- );
127
- this.name = "InferenceOutputError";
247
+ var CohereConversationalTask = class extends BaseConversationalTask {
248
+ constructor() {
249
+ super("cohere", "https://api.cohere.com");
250
+ }
251
+ makeRoute() {
252
+ return "/compatibility/v1/chat/completions";
128
253
  }
129
254
  };
130
255
 
@@ -133,349 +258,871 @@ function isUrl(modelOrUrl) {
133
258
  return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
134
259
  }
135
260
 
136
- // src/utils/delay.ts
137
- function delay(ms) {
138
- return new Promise((resolve) => {
139
- setTimeout(() => resolve(), ms);
140
- });
141
- }
142
-
143
261
  // src/providers/fal-ai.ts
144
- var FAL_AI_API_BASE_URL = "https://fal.run";
145
- var FAL_AI_API_BASE_URL_QUEUE = "https://queue.fal.run";
146
- var makeBaseUrl4 = (task) => {
147
- return task === "text-to-video" ? FAL_AI_API_BASE_URL_QUEUE : FAL_AI_API_BASE_URL;
262
+ var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
263
+ var FalAITask = class extends TaskProviderHelper {
264
+ constructor(url) {
265
+ super("fal-ai", url || "https://fal.run");
266
+ }
267
+ preparePayload(params) {
268
+ return params.args;
269
+ }
270
+ makeRoute(params) {
271
+ return `/${params.model}`;
272
+ }
273
+ prepareHeaders(params, binary) {
274
+ const headers = {
275
+ Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}`
276
+ };
277
+ if (!binary) {
278
+ headers["Content-Type"] = "application/json";
279
+ }
280
+ return headers;
281
+ }
148
282
  };
149
- var makeBody4 = (params) => {
150
- return params.args;
283
+ var FalAITextToImageTask = class extends FalAITask {
284
+ preparePayload(params) {
285
+ return {
286
+ ...omit(params.args, ["inputs", "parameters"]),
287
+ ...params.args.parameters,
288
+ sync_mode: true,
289
+ prompt: params.args.inputs
290
+ };
291
+ }
292
+ async getResponse(response, outputType) {
293
+ if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images.length > 0 && "url" in response.images[0] && typeof response.images[0].url === "string") {
294
+ if (outputType === "url") {
295
+ return response.images[0].url;
296
+ }
297
+ const urlResponse = await fetch(response.images[0].url);
298
+ return await urlResponse.blob();
299
+ }
300
+ throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
301
+ }
151
302
  };
152
- var makeHeaders4 = (params) => {
153
- return {
154
- Authorization: params.authMethod === "provider-key" ? `Key ${params.accessToken}` : `Bearer ${params.accessToken}`
155
- };
303
+ var FalAITextToVideoTask = class extends FalAITask {
304
+ constructor() {
305
+ super("https://queue.fal.run");
306
+ }
307
+ makeRoute(params) {
308
+ if (params.authMethod !== "provider-key") {
309
+ return `/${params.model}?_subdomain=queue`;
310
+ }
311
+ return `/${params.model}`;
312
+ }
313
+ preparePayload(params) {
314
+ return {
315
+ ...omit(params.args, ["inputs", "parameters"]),
316
+ ...params.args.parameters,
317
+ prompt: params.args.inputs
318
+ };
319
+ }
320
+ async getResponse(response, url, headers) {
321
+ if (!url || !headers) {
322
+ throw new InferenceOutputError("URL and headers are required for text-to-video task");
323
+ }
324
+ const requestId = response.request_id;
325
+ if (!requestId) {
326
+ throw new InferenceOutputError("No request ID found in the response");
327
+ }
328
+ let status = response.status;
329
+ const parsedUrl = new URL(url);
330
+ const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
331
+ const modelId = new URL(response.response_url).pathname;
332
+ const queryParams = parsedUrl.search;
333
+ const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
334
+ const resultUrl = `${baseUrl}${modelId}${queryParams}`;
335
+ while (status !== "COMPLETED") {
336
+ await delay(500);
337
+ const statusResponse = await fetch(statusUrl, { headers });
338
+ if (!statusResponse.ok) {
339
+ throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
340
+ }
341
+ try {
342
+ status = (await statusResponse.json()).status;
343
+ } catch (error) {
344
+ throw new InferenceOutputError("Failed to parse status response from fal-ai API");
345
+ }
346
+ }
347
+ const resultResponse = await fetch(resultUrl, { headers });
348
+ let result;
349
+ try {
350
+ result = await resultResponse.json();
351
+ } catch (error) {
352
+ throw new InferenceOutputError("Failed to parse result response from fal-ai API");
353
+ }
354
+ if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
355
+ const urlResponse = await fetch(result.video.url);
356
+ return await urlResponse.blob();
357
+ } else {
358
+ throw new InferenceOutputError(
359
+ "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
360
+ );
361
+ }
362
+ }
363
+ };
364
+ var FalAIAutomaticSpeechRecognitionTask = class extends FalAITask {
365
+ prepareHeaders(params, binary) {
366
+ const headers = super.prepareHeaders(params, binary);
367
+ headers["Content-Type"] = "application/json";
368
+ return headers;
369
+ }
370
+ async getResponse(response) {
371
+ const res = response;
372
+ if (typeof res?.text !== "string") {
373
+ throw new InferenceOutputError(
374
+ `Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
375
+ );
376
+ }
377
+ return { text: res.text };
378
+ }
156
379
  };
157
- var makeUrl4 = (params) => {
158
- const baseUrl = `${params.baseUrl}/${params.model}`;
159
- if (params.authMethod !== "provider-key" && params.task === "text-to-video") {
160
- return `${baseUrl}?_subdomain=queue`;
161
- }
162
- return baseUrl;
163
- };
164
- var FAL_AI_CONFIG = {
165
- makeBaseUrl: makeBaseUrl4,
166
- makeBody: makeBody4,
167
- makeHeaders: makeHeaders4,
168
- makeUrl: makeUrl4
169
- };
170
- async function pollFalResponse(res, url, headers) {
171
- const requestId = res.request_id;
172
- if (!requestId) {
173
- throw new InferenceOutputError("No request ID found in the response");
174
- }
175
- let status = res.status;
176
- const parsedUrl = new URL(url);
177
- const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
178
- const modelId = new URL(res.response_url).pathname;
179
- const queryParams = parsedUrl.search;
180
- const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
181
- const resultUrl = `${baseUrl}${modelId}${queryParams}`;
182
- while (status !== "COMPLETED") {
183
- await delay(500);
184
- const statusResponse = await fetch(statusUrl, { headers });
185
- if (!statusResponse.ok) {
186
- throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
380
+ var FalAITextToSpeechTask = class extends FalAITask {
381
+ preparePayload(params) {
382
+ return {
383
+ ...omit(params.args, ["inputs", "parameters"]),
384
+ ...params.args.parameters,
385
+ lyrics: params.args.inputs
386
+ };
387
+ }
388
+ async getResponse(response) {
389
+ const res = response;
390
+ if (typeof res?.audio?.url !== "string") {
391
+ throw new InferenceOutputError(
392
+ `Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
393
+ );
187
394
  }
188
395
  try {
189
- status = (await statusResponse.json()).status;
396
+ const urlResponse = await fetch(res.audio.url);
397
+ if (!urlResponse.ok) {
398
+ throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`);
399
+ }
400
+ return await urlResponse.blob();
190
401
  } catch (error) {
191
- throw new InferenceOutputError("Failed to parse status response from fal-ai API");
402
+ throw new InferenceOutputError(
403
+ `Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${error instanceof Error ? error.message : String(error)}`
404
+ );
192
405
  }
193
406
  }
194
- const resultResponse = await fetch(resultUrl, { headers });
195
- let result;
196
- try {
197
- result = await resultResponse.json();
198
- } catch (error) {
199
- throw new InferenceOutputError("Failed to parse result response from fal-ai API");
407
+ };
408
+
409
+ // src/providers/fireworks-ai.ts
410
+ var FireworksConversationalTask = class extends BaseConversationalTask {
411
+ constructor() {
412
+ super("fireworks-ai", "https://api.fireworks.ai");
200
413
  }
201
- if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
202
- const urlResponse = await fetch(result.video.url);
203
- return await urlResponse.blob();
204
- } else {
414
+ makeRoute() {
415
+ return "/inference/v1/chat/completions";
416
+ }
417
+ };
418
+
419
+ // src/providers/hf-inference.ts
420
+ var HFInferenceTask = class extends TaskProviderHelper {
421
+ constructor() {
422
+ super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
423
+ }
424
+ preparePayload(params) {
425
+ return params.args;
426
+ }
427
+ makeUrl(params) {
428
+ if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
429
+ return params.model;
430
+ }
431
+ return super.makeUrl(params);
432
+ }
433
+ makeRoute(params) {
434
+ if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
435
+ return `pipeline/${params.task}/${params.model}`;
436
+ }
437
+ return `models/${params.model}`;
438
+ }
439
+ async getResponse(response) {
440
+ return response;
441
+ }
442
+ };
443
+ var HFInferenceTextToImageTask = class extends HFInferenceTask {
444
+ async getResponse(response, url, headers, outputType) {
445
+ if (!response) {
446
+ throw new InferenceOutputError("response is undefined");
447
+ }
448
+ if (typeof response == "object") {
449
+ if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
450
+ const base64Data = response.data[0].b64_json;
451
+ if (outputType === "url") {
452
+ return `data:image/jpeg;base64,${base64Data}`;
453
+ }
454
+ const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
455
+ return await base64Response.blob();
456
+ }
457
+ if ("output" in response && Array.isArray(response.output)) {
458
+ if (outputType === "url") {
459
+ return response.output[0];
460
+ }
461
+ const urlResponse = await fetch(response.output[0]);
462
+ const blob = await urlResponse.blob();
463
+ return blob;
464
+ }
465
+ }
466
+ if (response instanceof Blob) {
467
+ if (outputType === "url") {
468
+ const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
469
+ return `data:image/jpeg;base64,${b64}`;
470
+ }
471
+ return response;
472
+ }
473
+ throw new InferenceOutputError("Expected a Blob ");
474
+ }
475
+ };
476
+ var HFInferenceConversationalTask = class extends HFInferenceTask {
477
+ makeUrl(params) {
478
+ let url;
479
+ if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
480
+ url = params.model.trim();
481
+ } else {
482
+ url = `${this.makeBaseUrl(params)}/models/${params.model}`;
483
+ }
484
+ url = url.replace(/\/+$/, "");
485
+ if (url.endsWith("/v1")) {
486
+ url += "/chat/completions";
487
+ } else if (!url.endsWith("/chat/completions")) {
488
+ url += "/v1/chat/completions";
489
+ }
490
+ return url;
491
+ }
492
+ preparePayload(params) {
493
+ return {
494
+ ...params.args,
495
+ model: params.model
496
+ };
497
+ }
498
+ async getResponse(response) {
499
+ return response;
500
+ }
501
+ };
502
+ var HFInferenceTextGenerationTask = class extends HFInferenceTask {
503
+ async getResponse(response) {
504
+ const res = toArray(response);
505
+ if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) {
506
+ return res?.[0];
507
+ }
508
+ throw new InferenceOutputError("Expected Array<{generated_text: string}>");
509
+ }
510
+ };
511
+ var HFInferenceAudioClassificationTask = class extends HFInferenceTask {
512
+ async getResponse(response) {
513
+ if (Array.isArray(response) && response.every(
514
+ (x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
515
+ )) {
516
+ return response;
517
+ }
518
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
519
+ }
520
+ };
521
+ var HFInferenceAutomaticSpeechRecognitionTask = class extends HFInferenceTask {
522
+ async getResponse(response) {
523
+ return response;
524
+ }
525
+ };
526
+ var HFInferenceAudioToAudioTask = class extends HFInferenceTask {
527
+ async getResponse(response) {
528
+ if (!Array.isArray(response)) {
529
+ throw new InferenceOutputError("Expected Array");
530
+ }
531
+ if (!response.every((elem) => {
532
+ return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
533
+ })) {
534
+ throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
535
+ }
536
+ return response;
537
+ }
538
+ };
539
+ var HFInferenceDocumentQuestionAnsweringTask = class extends HFInferenceTask {
540
+ async getResponse(response) {
541
+ if (Array.isArray(response) && response.every(
542
+ (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
543
+ )) {
544
+ return response[0];
545
+ }
546
+ throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
547
+ }
548
+ };
549
+ var HFInferenceFeatureExtractionTask = class extends HFInferenceTask {
550
+ async getResponse(response) {
551
+ const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
552
+ if (curDepth > maxDepth)
553
+ return false;
554
+ if (arr.every((x) => Array.isArray(x))) {
555
+ return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
556
+ } else {
557
+ return arr.every((x) => typeof x === "number");
558
+ }
559
+ };
560
+ if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) {
561
+ return response;
562
+ }
563
+ throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
564
+ }
565
+ };
566
+ var HFInferenceImageClassificationTask = class extends HFInferenceTask {
567
+ async getResponse(response) {
568
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
569
+ return response;
570
+ }
571
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
572
+ }
573
+ };
574
+ var HFInferenceImageSegmentationTask = class extends HFInferenceTask {
575
+ async getResponse(response) {
576
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number")) {
577
+ return response;
578
+ }
579
+ throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
580
+ }
581
+ };
582
+ var HFInferenceImageToTextTask = class extends HFInferenceTask {
583
+ async getResponse(response) {
584
+ if (typeof response?.generated_text !== "string") {
585
+ throw new InferenceOutputError("Expected {generated_text: string}");
586
+ }
587
+ return response;
588
+ }
589
+ };
590
+ var HFInferenceImageToImageTask = class extends HFInferenceTask {
591
+ async getResponse(response) {
592
+ if (response instanceof Blob) {
593
+ return response;
594
+ }
595
+ throw new InferenceOutputError("Expected Blob");
596
+ }
597
+ };
598
+ var HFInferenceObjectDetectionTask = class extends HFInferenceTask {
599
+ async getResponse(response) {
600
+ if (Array.isArray(response) && response.every(
601
+ (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
602
+ )) {
603
+ return response;
604
+ }
205
605
  throw new InferenceOutputError(
206
- "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
606
+ "Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
207
607
  );
208
608
  }
209
- }
210
-
211
- // src/providers/fireworks-ai.ts
212
- var FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai";
213
- var makeBaseUrl5 = () => {
214
- return FIREWORKS_AI_API_BASE_URL;
215
609
  };
216
- var makeBody5 = (params) => {
217
- return {
218
- ...params.args,
219
- ...params.chatCompletion ? { model: params.model } : void 0
220
- };
610
+ var HFInferenceZeroShotImageClassificationTask = class extends HFInferenceTask {
611
+ async getResponse(response) {
612
+ if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
613
+ return response;
614
+ }
615
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
616
+ }
617
+ };
618
+ var HFInferenceTextClassificationTask = class extends HFInferenceTask {
619
+ async getResponse(response) {
620
+ const output = response?.[0];
621
+ if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
622
+ return output;
623
+ }
624
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
625
+ }
626
+ };
627
+ var HFInferenceQuestionAnsweringTask = class extends HFInferenceTask {
628
+ async getResponse(response) {
629
+ if (Array.isArray(response) ? response.every(
630
+ (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
631
+ ) : typeof response === "object" && !!response && typeof response.answer === "string" && typeof response.end === "number" && typeof response.score === "number" && typeof response.start === "number") {
632
+ return Array.isArray(response) ? response[0] : response;
633
+ }
634
+ throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
635
+ }
636
+ };
637
+ var HFInferenceFillMaskTask = class extends HFInferenceTask {
638
+ async getResponse(response) {
639
+ if (Array.isArray(response) && response.every(
640
+ (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
641
+ )) {
642
+ return response;
643
+ }
644
+ throw new InferenceOutputError(
645
+ "Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
646
+ );
647
+ }
221
648
  };
222
- var makeHeaders5 = (params) => {
223
- return { Authorization: `Bearer ${params.accessToken}` };
649
+ var HFInferenceZeroShotClassificationTask = class extends HFInferenceTask {
650
+ async getResponse(response) {
651
+ if (Array.isArray(response) && response.every(
652
+ (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
653
+ )) {
654
+ return response;
655
+ }
656
+ throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
657
+ }
224
658
  };
225
- var makeUrl5 = (params) => {
226
- if (params.chatCompletion) {
227
- return `${params.baseUrl}/inference/v1/chat/completions`;
659
+ var HFInferenceSentenceSimilarityTask = class extends HFInferenceTask {
660
+ async getResponse(response) {
661
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
662
+ return response;
663
+ }
664
+ throw new InferenceOutputError("Expected Array<number>");
228
665
  }
229
- return `${params.baseUrl}/inference`;
230
666
  };
231
- var FIREWORKS_AI_CONFIG = {
232
- makeBaseUrl: makeBaseUrl5,
233
- makeBody: makeBody5,
234
- makeHeaders: makeHeaders5,
235
- makeUrl: makeUrl5
667
+ var HFInferenceTableQuestionAnsweringTask = class extends HFInferenceTask {
668
+ static validate(elem) {
669
+ return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
670
+ (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
671
+ );
672
+ }
673
+ async getResponse(response) {
674
+ if (Array.isArray(response) && Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response)) {
675
+ return Array.isArray(response) ? response[0] : response;
676
+ }
677
+ throw new InferenceOutputError(
678
+ "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
679
+ );
680
+ }
236
681
  };
237
-
238
- // src/providers/hf-inference.ts
239
- var makeBaseUrl6 = () => {
240
- return `${HF_ROUTER_URL}/hf-inference`;
682
+ var HFInferenceTokenClassificationTask = class extends HFInferenceTask {
683
+ async getResponse(response) {
684
+ if (Array.isArray(response) && response.every(
685
+ (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
686
+ )) {
687
+ return response;
688
+ }
689
+ throw new InferenceOutputError(
690
+ "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
691
+ );
692
+ }
241
693
  };
242
- var makeBody6 = (params) => {
243
- return {
244
- ...params.args,
245
- ...params.chatCompletion ? { model: params.model } : void 0
246
- };
694
+ var HFInferenceTranslationTask = class extends HFInferenceTask {
695
+ async getResponse(response) {
696
+ if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) {
697
+ return response?.length === 1 ? response?.[0] : response;
698
+ }
699
+ throw new InferenceOutputError("Expected Array<{translation_text: string}>");
700
+ }
247
701
  };
248
- var makeHeaders6 = (params) => {
249
- return { Authorization: `Bearer ${params.accessToken}` };
702
+ var HFInferenceSummarizationTask = class extends HFInferenceTask {
703
+ async getResponse(response) {
704
+ if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) {
705
+ return response?.[0];
706
+ }
707
+ throw new InferenceOutputError("Expected Array<{summary_text: string}>");
708
+ }
250
709
  };
251
- var makeUrl6 = (params) => {
252
- if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
253
- return `${params.baseUrl}/pipeline/${params.task}/${params.model}`;
710
+ var HFInferenceTextToSpeechTask = class extends HFInferenceTask {
711
+ async getResponse(response) {
712
+ return response;
254
713
  }
255
- if (params.chatCompletion) {
256
- return `${params.baseUrl}/models/${params.model}/v1/chat/completions`;
714
+ };
715
+ var HFInferenceTabularClassificationTask = class extends HFInferenceTask {
716
+ async getResponse(response) {
717
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
718
+ return response;
719
+ }
720
+ throw new InferenceOutputError("Expected Array<number>");
721
+ }
722
+ };
723
+ var HFInferenceVisualQuestionAnsweringTask = class extends HFInferenceTask {
724
+ async getResponse(response) {
725
+ if (Array.isArray(response) && response.every(
726
+ (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
727
+ )) {
728
+ return response[0];
729
+ }
730
+ throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
257
731
  }
258
- return `${params.baseUrl}/models/${params.model}`;
259
732
  };
260
- var HF_INFERENCE_CONFIG = {
261
- makeBaseUrl: makeBaseUrl6,
262
- makeBody: makeBody6,
263
- makeHeaders: makeHeaders6,
264
- makeUrl: makeUrl6
733
+ var HFInferenceTabularRegressionTask = class extends HFInferenceTask {
734
+ async getResponse(response) {
735
+ if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
736
+ return response;
737
+ }
738
+ throw new InferenceOutputError("Expected Array<number>");
739
+ }
740
+ };
741
+ var HFInferenceTextToAudioTask = class extends HFInferenceTask {
742
+ async getResponse(response) {
743
+ return response;
744
+ }
265
745
  };
266
746
 
267
747
  // src/providers/hyperbolic.ts
268
748
  var HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
269
- var makeBaseUrl7 = () => {
270
- return HYPERBOLIC_API_BASE_URL;
271
- };
272
- var makeBody7 = (params) => {
273
- return {
274
- ...params.args,
275
- ...params.task === "text-to-image" ? { model_name: params.model } : { model: params.model }
276
- };
277
- };
278
- var makeHeaders7 = (params) => {
279
- return { Authorization: `Bearer ${params.accessToken}` };
749
+ var HyperbolicConversationalTask = class extends BaseConversationalTask {
750
+ constructor() {
751
+ super("hyperbolic", HYPERBOLIC_API_BASE_URL);
752
+ }
280
753
  };
281
- var makeUrl7 = (params) => {
282
- if (params.task === "text-to-image") {
283
- return `${params.baseUrl}/v1/images/generations`;
754
+ var HyperbolicTextGenerationTask = class extends BaseTextGenerationTask {
755
+ constructor() {
756
+ super("hyperbolic", HYPERBOLIC_API_BASE_URL);
757
+ }
758
+ makeRoute() {
759
+ return "v1/chat/completions";
760
+ }
761
+ preparePayload(params) {
762
+ return {
763
+ messages: [{ content: params.args.inputs, role: "user" }],
764
+ ...params.args.parameters ? {
765
+ max_tokens: params.args.parameters.max_new_tokens,
766
+ ...omit(params.args.parameters, "max_new_tokens")
767
+ } : void 0,
768
+ ...omit(params.args, ["inputs", "parameters"]),
769
+ model: params.model
770
+ };
771
+ }
772
+ async getResponse(response) {
773
+ if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
774
+ const completion = response.choices[0];
775
+ return {
776
+ generated_text: completion.message.content
777
+ };
778
+ }
779
+ throw new InferenceOutputError("Expected Hyperbolic text generation response format");
284
780
  }
285
- return `${params.baseUrl}/v1/chat/completions`;
286
781
  };
287
- var HYPERBOLIC_CONFIG = {
288
- makeBaseUrl: makeBaseUrl7,
289
- makeBody: makeBody7,
290
- makeHeaders: makeHeaders7,
291
- makeUrl: makeUrl7
782
+ var HyperbolicTextToImageTask = class extends TaskProviderHelper {
783
+ constructor() {
784
+ super("hyperbolic", HYPERBOLIC_API_BASE_URL);
785
+ }
786
+ makeRoute(params) {
787
+ return `/v1/images/generations`;
788
+ }
789
+ preparePayload(params) {
790
+ return {
791
+ ...omit(params.args, ["inputs", "parameters"]),
792
+ ...params.args.parameters,
793
+ prompt: params.args.inputs,
794
+ model_name: params.model
795
+ };
796
+ }
797
+ async getResponse(response, url, headers, outputType) {
798
+ if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images[0] && typeof response.images[0].image === "string") {
799
+ if (outputType === "url") {
800
+ return `data:image/jpeg;base64,${response.images[0].image}`;
801
+ }
802
+ return fetch(`data:image/jpeg;base64,${response.images[0].image}`).then((res) => res.blob());
803
+ }
804
+ throw new InferenceOutputError("Expected Hyperbolic text-to-image response format");
805
+ }
292
806
  };
293
807
 
294
808
  // src/providers/nebius.ts
295
809
  var NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
296
- var makeBaseUrl8 = () => {
297
- return NEBIUS_API_BASE_URL;
298
- };
299
- var makeBody8 = (params) => {
300
- return {
301
- ...params.args,
302
- model: params.model
303
- };
810
+ var NebiusConversationalTask = class extends BaseConversationalTask {
811
+ constructor() {
812
+ super("nebius", NEBIUS_API_BASE_URL);
813
+ }
304
814
  };
305
- var makeHeaders8 = (params) => {
306
- return { Authorization: `Bearer ${params.accessToken}` };
815
+ var NebiusTextGenerationTask = class extends BaseTextGenerationTask {
816
+ constructor() {
817
+ super("nebius", NEBIUS_API_BASE_URL);
818
+ }
307
819
  };
308
- var makeUrl8 = (params) => {
309
- if (params.task === "text-to-image") {
310
- return `${params.baseUrl}/v1/images/generations`;
820
+ var NebiusTextToImageTask = class extends TaskProviderHelper {
821
+ constructor() {
822
+ super("nebius", NEBIUS_API_BASE_URL);
311
823
  }
312
- if (params.chatCompletion) {
313
- return `${params.baseUrl}/v1/chat/completions`;
824
+ preparePayload(params) {
825
+ return {
826
+ ...omit(params.args, ["inputs", "parameters"]),
827
+ ...params.args.parameters,
828
+ response_format: "b64_json",
829
+ prompt: params.args.inputs,
830
+ model: params.model
831
+ };
314
832
  }
315
- if (params.task === "text-generation") {
316
- return `${params.baseUrl}/v1/completions`;
833
+ makeRoute(params) {
834
+ return "v1/images/generations";
835
+ }
836
+ async getResponse(response, url, headers, outputType) {
837
+ if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
838
+ const base64Data = response.data[0].b64_json;
839
+ if (outputType === "url") {
840
+ return `data:image/jpeg;base64,${base64Data}`;
841
+ }
842
+ return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
843
+ }
844
+ throw new InferenceOutputError("Expected Nebius text-to-image response format");
317
845
  }
318
- return params.baseUrl;
319
- };
320
- var NEBIUS_CONFIG = {
321
- makeBaseUrl: makeBaseUrl8,
322
- makeBody: makeBody8,
323
- makeHeaders: makeHeaders8,
324
- makeUrl: makeUrl8
325
846
  };
326
847
 
327
848
  // src/providers/novita.ts
328
849
  var NOVITA_API_BASE_URL = "https://api.novita.ai";
329
- var makeBaseUrl9 = () => {
330
- return NOVITA_API_BASE_URL;
331
- };
332
- var makeBody9 = (params) => {
333
- return {
334
- ...params.args,
335
- ...params.chatCompletion ? { model: params.model } : void 0
336
- };
337
- };
338
- var makeHeaders9 = (params) => {
339
- return { Authorization: `Bearer ${params.accessToken}` };
850
+ var NovitaTextGenerationTask = class extends BaseTextGenerationTask {
851
+ constructor() {
852
+ super("novita", NOVITA_API_BASE_URL);
853
+ }
854
+ makeRoute() {
855
+ return "/v3/openai/chat/completions";
856
+ }
340
857
  };
341
- var makeUrl9 = (params) => {
342
- if (params.chatCompletion) {
343
- return `${params.baseUrl}/v3/openai/chat/completions`;
344
- } else if (params.task === "text-generation") {
345
- return `${params.baseUrl}/v3/openai/completions`;
346
- } else if (params.task === "text-to-video") {
347
- return `${params.baseUrl}/v3/hf/${params.model}`;
858
+ var NovitaConversationalTask = class extends BaseConversationalTask {
859
+ constructor() {
860
+ super("novita", NOVITA_API_BASE_URL);
861
+ }
862
+ makeRoute() {
863
+ return "/v3/openai/chat/completions";
348
864
  }
349
- return params.baseUrl;
350
865
  };
351
- var NOVITA_CONFIG = {
352
- makeBaseUrl: makeBaseUrl9,
353
- makeBody: makeBody9,
354
- makeHeaders: makeHeaders9,
355
- makeUrl: makeUrl9
866
+
867
+ // src/providers/openai.ts
868
+ var OPENAI_API_BASE_URL = "https://api.openai.com";
869
+ var OpenAIConversationalTask = class extends BaseConversationalTask {
870
+ constructor() {
871
+ super("openai", OPENAI_API_BASE_URL, true);
872
+ }
356
873
  };
357
874
 
358
875
  // src/providers/replicate.ts
359
- var REPLICATE_API_BASE_URL = "https://api.replicate.com";
360
- var makeBaseUrl10 = () => {
361
- return REPLICATE_API_BASE_URL;
362
- };
363
- var makeBody10 = (params) => {
364
- return {
365
- input: params.args,
366
- version: params.model.includes(":") ? params.model.split(":")[1] : void 0
367
- };
876
+ var ReplicateTask = class extends TaskProviderHelper {
877
+ constructor(url) {
878
+ super("replicate", url || "https://api.replicate.com");
879
+ }
880
+ makeRoute(params) {
881
+ if (params.model.includes(":")) {
882
+ return "v1/predictions";
883
+ }
884
+ return `v1/models/${params.model}/predictions`;
885
+ }
886
+ preparePayload(params) {
887
+ return {
888
+ input: {
889
+ ...omit(params.args, ["inputs", "parameters"]),
890
+ ...params.args.parameters,
891
+ prompt: params.args.inputs
892
+ },
893
+ version: params.model.includes(":") ? params.model.split(":")[1] : void 0
894
+ };
895
+ }
896
+ prepareHeaders(params, binary) {
897
+ const headers = { Authorization: `Bearer ${params.accessToken}`, Prefer: "wait" };
898
+ if (!binary) {
899
+ headers["Content-Type"] = "application/json";
900
+ }
901
+ return headers;
902
+ }
903
+ makeUrl(params) {
904
+ const baseUrl = this.makeBaseUrl(params);
905
+ if (params.model.includes(":")) {
906
+ return `${baseUrl}/v1/predictions`;
907
+ }
908
+ return `${baseUrl}/v1/models/${params.model}/predictions`;
909
+ }
368
910
  };
369
- var makeHeaders10 = (params) => {
370
- return { Authorization: `Bearer ${params.accessToken}`, Prefer: "wait" };
911
+ var ReplicateTextToImageTask = class extends ReplicateTask {
912
+ async getResponse(res, url, headers, outputType) {
913
+ if (typeof res === "object" && "output" in res && Array.isArray(res.output) && res.output.length > 0 && typeof res.output[0] === "string") {
914
+ if (outputType === "url") {
915
+ return res.output[0];
916
+ }
917
+ const urlResponse = await fetch(res.output[0]);
918
+ return await urlResponse.blob();
919
+ }
920
+ throw new InferenceOutputError("Expected Replicate text-to-image response format");
921
+ }
371
922
  };
372
- var makeUrl10 = (params) => {
373
- if (params.model.includes(":")) {
374
- return `${params.baseUrl}/v1/predictions`;
923
+ var ReplicateTextToSpeechTask = class extends ReplicateTask {
924
+ preparePayload(params) {
925
+ const payload = super.preparePayload(params);
926
+ const input = payload["input"];
927
+ if (typeof input === "object" && input !== null && "prompt" in input) {
928
+ const inputObj = input;
929
+ inputObj["text"] = inputObj["prompt"];
930
+ delete inputObj["prompt"];
931
+ }
932
+ return payload;
933
+ }
934
+ async getResponse(response) {
935
+ if (response instanceof Blob) {
936
+ return response;
937
+ }
938
+ if (response && typeof response === "object") {
939
+ if ("output" in response) {
940
+ if (typeof response.output === "string") {
941
+ const urlResponse = await fetch(response.output);
942
+ return await urlResponse.blob();
943
+ } else if (Array.isArray(response.output)) {
944
+ const urlResponse = await fetch(response.output[0]);
945
+ return await urlResponse.blob();
946
+ }
947
+ }
948
+ }
949
+ throw new InferenceOutputError("Expected Blob or object with output");
375
950
  }
376
- return `${params.baseUrl}/v1/models/${params.model}/predictions`;
377
951
  };
378
- var REPLICATE_CONFIG = {
379
- makeBaseUrl: makeBaseUrl10,
380
- makeBody: makeBody10,
381
- makeHeaders: makeHeaders10,
382
- makeUrl: makeUrl10
952
+ var ReplicateTextToVideoTask = class extends ReplicateTask {
953
+ async getResponse(response) {
954
+ if (typeof response === "object" && !!response && "output" in response && typeof response.output === "string" && isUrl(response.output)) {
955
+ const urlResponse = await fetch(response.output);
956
+ return await urlResponse.blob();
957
+ }
958
+ throw new InferenceOutputError("Expected { output: string }");
959
+ }
383
960
  };
384
961
 
385
962
  // src/providers/sambanova.ts
386
- var SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
387
- var makeBaseUrl11 = () => {
388
- return SAMBANOVA_API_BASE_URL;
389
- };
390
- var makeBody11 = (params) => {
391
- return {
392
- ...params.args,
393
- ...params.chatCompletion ? { model: params.model } : void 0
394
- };
395
- };
396
- var makeHeaders11 = (params) => {
397
- return { Authorization: `Bearer ${params.accessToken}` };
398
- };
399
- var makeUrl11 = (params) => {
400
- if (params.chatCompletion) {
401
- return `${params.baseUrl}/v1/chat/completions`;
963
+ var SambanovaConversationalTask = class extends BaseConversationalTask {
964
+ constructor() {
965
+ super("sambanova", "https://api.sambanova.ai");
402
966
  }
403
- return params.baseUrl;
404
- };
405
- var SAMBANOVA_CONFIG = {
406
- makeBaseUrl: makeBaseUrl11,
407
- makeBody: makeBody11,
408
- makeHeaders: makeHeaders11,
409
- makeUrl: makeUrl11
410
967
  };
411
968
 
412
969
  // src/providers/together.ts
413
970
  var TOGETHER_API_BASE_URL = "https://api.together.xyz";
414
- var makeBaseUrl12 = () => {
415
- return TOGETHER_API_BASE_URL;
416
- };
417
- var makeBody12 = (params) => {
418
- return {
419
- ...params.args,
420
- model: params.model
421
- };
422
- };
423
- var makeHeaders12 = (params) => {
424
- return { Authorization: `Bearer ${params.accessToken}` };
971
+ var TogetherConversationalTask = class extends BaseConversationalTask {
972
+ constructor() {
973
+ super("together", TOGETHER_API_BASE_URL);
974
+ }
425
975
  };
426
- var makeUrl12 = (params) => {
427
- if (params.task === "text-to-image") {
428
- return `${params.baseUrl}/v1/images/generations`;
976
+ var TogetherTextGenerationTask = class extends BaseTextGenerationTask {
977
+ constructor() {
978
+ super("together", TOGETHER_API_BASE_URL);
429
979
  }
430
- if (params.chatCompletion) {
431
- return `${params.baseUrl}/v1/chat/completions`;
980
+ preparePayload(params) {
981
+ return {
982
+ model: params.model,
983
+ ...params.args,
984
+ prompt: params.args.inputs
985
+ };
432
986
  }
433
- if (params.task === "text-generation") {
434
- return `${params.baseUrl}/v1/completions`;
987
+ async getResponse(response) {
988
+ if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
989
+ const completion = response.choices[0];
990
+ return {
991
+ generated_text: completion.text
992
+ };
993
+ }
994
+ throw new InferenceOutputError("Expected Together text generation response format");
435
995
  }
436
- return params.baseUrl;
437
996
  };
438
- var TOGETHER_CONFIG = {
439
- makeBaseUrl: makeBaseUrl12,
440
- makeBody: makeBody12,
441
- makeHeaders: makeHeaders12,
442
- makeUrl: makeUrl12
997
+ var TogetherTextToImageTask = class extends TaskProviderHelper {
998
+ constructor() {
999
+ super("together", TOGETHER_API_BASE_URL);
1000
+ }
1001
+ makeRoute() {
1002
+ return "v1/images/generations";
1003
+ }
1004
+ preparePayload(params) {
1005
+ return {
1006
+ ...omit(params.args, ["inputs", "parameters"]),
1007
+ ...params.args.parameters,
1008
+ prompt: params.args.inputs,
1009
+ response_format: "base64",
1010
+ model: params.model
1011
+ };
1012
+ }
1013
+ async getResponse(response, outputType) {
1014
+ if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
1015
+ const base64Data = response.data[0].b64_json;
1016
+ if (outputType === "url") {
1017
+ return `data:image/jpeg;base64,${base64Data}`;
1018
+ }
1019
+ return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
1020
+ }
1021
+ throw new InferenceOutputError("Expected Together text-to-image response format");
1022
+ }
443
1023
  };
444
1024
 
445
- // src/providers/openai.ts
446
- var OPENAI_API_BASE_URL = "https://api.openai.com";
447
- var makeBaseUrl13 = () => {
448
- return OPENAI_API_BASE_URL;
449
- };
450
- var makeBody13 = (params) => {
451
- if (!params.chatCompletion) {
452
- throw new Error("OpenAI only supports chat completions.");
1025
+ // src/lib/getProviderHelper.ts
1026
+ var PROVIDERS = {
1027
+ "black-forest-labs": {
1028
+ "text-to-image": new BlackForestLabsTextToImageTask()
1029
+ },
1030
+ cerebras: {
1031
+ conversational: new CerebrasConversationalTask()
1032
+ },
1033
+ cohere: {
1034
+ conversational: new CohereConversationalTask()
1035
+ },
1036
+ "fal-ai": {
1037
+ "text-to-image": new FalAITextToImageTask(),
1038
+ "text-to-speech": new FalAITextToSpeechTask(),
1039
+ "text-to-video": new FalAITextToVideoTask(),
1040
+ "automatic-speech-recognition": new FalAIAutomaticSpeechRecognitionTask()
1041
+ },
1042
+ "hf-inference": {
1043
+ "text-to-image": new HFInferenceTextToImageTask(),
1044
+ conversational: new HFInferenceConversationalTask(),
1045
+ "text-generation": new HFInferenceTextGenerationTask(),
1046
+ "text-classification": new HFInferenceTextClassificationTask(),
1047
+ "question-answering": new HFInferenceQuestionAnsweringTask(),
1048
+ "audio-classification": new HFInferenceAudioClassificationTask(),
1049
+ "automatic-speech-recognition": new HFInferenceAutomaticSpeechRecognitionTask(),
1050
+ "fill-mask": new HFInferenceFillMaskTask(),
1051
+ "feature-extraction": new HFInferenceFeatureExtractionTask(),
1052
+ "image-classification": new HFInferenceImageClassificationTask(),
1053
+ "image-segmentation": new HFInferenceImageSegmentationTask(),
1054
+ "document-question-answering": new HFInferenceDocumentQuestionAnsweringTask(),
1055
+ "image-to-text": new HFInferenceImageToTextTask(),
1056
+ "object-detection": new HFInferenceObjectDetectionTask(),
1057
+ "audio-to-audio": new HFInferenceAudioToAudioTask(),
1058
+ "zero-shot-image-classification": new HFInferenceZeroShotImageClassificationTask(),
1059
+ "zero-shot-classification": new HFInferenceZeroShotClassificationTask(),
1060
+ "image-to-image": new HFInferenceImageToImageTask(),
1061
+ "sentence-similarity": new HFInferenceSentenceSimilarityTask(),
1062
+ "table-question-answering": new HFInferenceTableQuestionAnsweringTask(),
1063
+ "tabular-classification": new HFInferenceTabularClassificationTask(),
1064
+ "text-to-speech": new HFInferenceTextToSpeechTask(),
1065
+ "token-classification": new HFInferenceTokenClassificationTask(),
1066
+ translation: new HFInferenceTranslationTask(),
1067
+ summarization: new HFInferenceSummarizationTask(),
1068
+ "visual-question-answering": new HFInferenceVisualQuestionAnsweringTask(),
1069
+ "tabular-regression": new HFInferenceTabularRegressionTask(),
1070
+ "text-to-audio": new HFInferenceTextToAudioTask()
1071
+ },
1072
+ "fireworks-ai": {
1073
+ conversational: new FireworksConversationalTask()
1074
+ },
1075
+ hyperbolic: {
1076
+ "text-to-image": new HyperbolicTextToImageTask(),
1077
+ conversational: new HyperbolicConversationalTask(),
1078
+ "text-generation": new HyperbolicTextGenerationTask()
1079
+ },
1080
+ nebius: {
1081
+ "text-to-image": new NebiusTextToImageTask(),
1082
+ conversational: new NebiusConversationalTask(),
1083
+ "text-generation": new NebiusTextGenerationTask()
1084
+ },
1085
+ novita: {
1086
+ conversational: new NovitaConversationalTask(),
1087
+ "text-generation": new NovitaTextGenerationTask()
1088
+ },
1089
+ openai: {
1090
+ conversational: new OpenAIConversationalTask()
1091
+ },
1092
+ replicate: {
1093
+ "text-to-image": new ReplicateTextToImageTask(),
1094
+ "text-to-speech": new ReplicateTextToSpeechTask(),
1095
+ "text-to-video": new ReplicateTextToVideoTask()
1096
+ },
1097
+ sambanova: {
1098
+ conversational: new SambanovaConversationalTask()
1099
+ },
1100
+ together: {
1101
+ "text-to-image": new TogetherTextToImageTask(),
1102
+ conversational: new TogetherConversationalTask(),
1103
+ "text-generation": new TogetherTextGenerationTask()
453
1104
  }
454
- return {
455
- ...params.args,
456
- model: params.model
457
- };
458
1105
  };
459
- var makeHeaders13 = (params) => {
460
- return { Authorization: `Bearer ${params.accessToken}` };
461
- };
462
- var makeUrl13 = (params) => {
463
- if (!params.chatCompletion) {
464
- throw new Error("OpenAI only supports chat completions.");
1106
+ function getProviderHelper(provider, task) {
1107
+ if (provider === "hf-inference") {
1108
+ if (!task) {
1109
+ return new HFInferenceTask();
1110
+ }
465
1111
  }
466
- return `${params.baseUrl}/v1/chat/completions`;
467
- };
468
- var OPENAI_CONFIG = {
469
- makeBaseUrl: makeBaseUrl13,
470
- makeBody: makeBody13,
471
- makeHeaders: makeHeaders13,
472
- makeUrl: makeUrl13,
473
- clientSideRoutingOnly: true
474
- };
475
-
476
- // package.json
477
- var name = "@huggingface/inference";
478
- var version = "3.6.2";
1112
+ if (!task) {
1113
+ throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'");
1114
+ }
1115
+ if (!(provider in PROVIDERS)) {
1116
+ throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`);
1117
+ }
1118
+ const providerTasks = PROVIDERS[provider];
1119
+ if (!providerTasks || !(task in providerTasks)) {
1120
+ throw new Error(
1121
+ `Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(providerTasks ?? {})}`
1122
+ );
1123
+ }
1124
+ return providerTasks[task];
1125
+ }
479
1126
 
480
1127
  // src/providers/consts.ts
481
1128
  var HARDCODED_MODEL_ID_MAPPING = {
@@ -545,28 +1192,11 @@ async function getProviderModelId(params, args, options = {}) {
545
1192
  }
546
1193
 
547
1194
  // src/lib/makeRequestOptions.ts
548
- var HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
549
1195
  var tasks = null;
550
- var providerConfigs = {
551
- "black-forest-labs": BLACK_FOREST_LABS_CONFIG,
552
- cerebras: CEREBRAS_CONFIG,
553
- cohere: COHERE_CONFIG,
554
- "fal-ai": FAL_AI_CONFIG,
555
- "fireworks-ai": FIREWORKS_AI_CONFIG,
556
- "hf-inference": HF_INFERENCE_CONFIG,
557
- hyperbolic: HYPERBOLIC_CONFIG,
558
- openai: OPENAI_CONFIG,
559
- nebius: NEBIUS_CONFIG,
560
- novita: NOVITA_CONFIG,
561
- replicate: REPLICATE_CONFIG,
562
- sambanova: SAMBANOVA_CONFIG,
563
- together: TOGETHER_CONFIG
564
- };
565
1196
  async function makeRequestOptions(args, options) {
566
1197
  const { provider: maybeProvider, model: maybeModel } = args;
567
1198
  const provider = maybeProvider ?? "hf-inference";
568
- const providerConfig = providerConfigs[provider];
569
- const { task, chatCompletion: chatCompletion2 } = options ?? {};
1199
+ const { task } = options ?? {};
570
1200
  if (args.endpointUrl && provider !== "hf-inference") {
571
1201
  throw new Error(`Cannot use endpointUrl with a third-party provider.`);
572
1202
  }
@@ -576,19 +1206,16 @@ async function makeRequestOptions(args, options) {
576
1206
  if (!maybeModel && !task) {
577
1207
  throw new Error("No model provided, and no task has been specified.");
578
1208
  }
579
- if (!providerConfig) {
580
- throw new Error(`No provider config found for provider ${provider}`);
581
- }
582
- if (providerConfig.clientSideRoutingOnly && !maybeModel) {
1209
+ const hfModel = maybeModel ?? await loadDefaultModel(task);
1210
+ const providerHelper = getProviderHelper(provider, task);
1211
+ if (providerHelper.clientSideRoutingOnly && !maybeModel) {
583
1212
  throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
584
1213
  }
585
- const hfModel = maybeModel ?? await loadDefaultModel(task);
586
- const resolvedModel = providerConfig.clientSideRoutingOnly ? (
1214
+ const resolvedModel = providerHelper.clientSideRoutingOnly ? (
587
1215
  // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
588
1216
  removeProviderPrefix(maybeModel, provider)
589
1217
  ) : await getProviderModelId({ model: hfModel, provider }, args, {
590
1218
  task,
591
- chatCompletion: chatCompletion2,
592
1219
  fetch: options?.fetch
593
1220
  });
594
1221
  return makeRequestOptionsFromResolvedModel(resolvedModel, args, options);
@@ -596,10 +1223,10 @@ async function makeRequestOptions(args, options) {
596
1223
  function makeRequestOptionsFromResolvedModel(resolvedModel, args, options) {
597
1224
  const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
598
1225
  const provider = maybeProvider ?? "hf-inference";
599
- const providerConfig = providerConfigs[provider];
600
- const { includeCredentials, task, chatCompletion: chatCompletion2, signal } = options ?? {};
1226
+ const { includeCredentials, task, signal, billTo } = options ?? {};
1227
+ const providerHelper = getProviderHelper(provider, task);
601
1228
  const authMethod = (() => {
602
- if (providerConfig.clientSideRoutingOnly) {
1229
+ if (providerHelper.clientSideRoutingOnly) {
603
1230
  if (accessToken && accessToken.startsWith("hf_")) {
604
1231
  throw new Error(`Provider ${provider} is closed-source and does not support HF tokens.`);
605
1232
  }
@@ -613,32 +1240,30 @@ function makeRequestOptionsFromResolvedModel(resolvedModel, args, options) {
613
1240
  }
614
1241
  return "none";
615
1242
  })();
616
- const url = endpointUrl ? chatCompletion2 ? endpointUrl + `/v1/chat/completions` : endpointUrl : providerConfig.makeUrl({
1243
+ const modelId = endpointUrl ?? resolvedModel;
1244
+ const url = providerHelper.makeUrl({
617
1245
  authMethod,
618
- baseUrl: authMethod !== "provider-key" ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider) : providerConfig.makeBaseUrl(task),
619
- model: resolvedModel,
620
- chatCompletion: chatCompletion2,
1246
+ model: modelId,
621
1247
  task
622
1248
  });
623
- const binary = "data" in args && !!args.data;
624
- const headers = providerConfig.makeHeaders({
625
- accessToken,
626
- authMethod
627
- });
628
- if (!binary) {
629
- headers["Content-Type"] = "application/json";
1249
+ const headers = providerHelper.prepareHeaders(
1250
+ {
1251
+ accessToken,
1252
+ authMethod
1253
+ },
1254
+ "data" in args && !!args.data
1255
+ );
1256
+ if (billTo) {
1257
+ headers[HF_HEADER_X_BILL_TO] = billTo;
630
1258
  }
631
1259
  const ownUserAgent = `${name}/${version}`;
632
1260
  const userAgent = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : void 0].filter((x) => x !== void 0).join(" ");
633
1261
  headers["User-Agent"] = userAgent;
634
- const body = binary ? args.data : JSON.stringify(
635
- providerConfig.makeBody({
636
- args: remainingArgs,
637
- model: resolvedModel,
638
- task,
639
- chatCompletion: chatCompletion2
640
- })
641
- );
1262
+ const body = providerHelper.makeBody({
1263
+ args: remainingArgs,
1264
+ model: resolvedModel,
1265
+ task
1266
+ });
642
1267
  let credentials;
643
1268
  if (typeof includeCredentials === "string") {
644
1269
  credentials = includeCredentials;
@@ -678,37 +1303,6 @@ function removeProviderPrefix(model, provider) {
678
1303
  return model.slice(provider.length + 1);
679
1304
  }
680
1305
 
681
- // src/tasks/custom/request.ts
682
- async function request(args, options) {
683
- const { url, info } = await makeRequestOptions(args, options);
684
- const response = await (options?.fetch ?? fetch)(url, info);
685
- if (options?.retry_on_error !== false && response.status === 503) {
686
- return request(args, options);
687
- }
688
- if (!response.ok) {
689
- const contentType = response.headers.get("Content-Type");
690
- if (["application/json", "application/problem+json"].some((ct) => contentType?.startsWith(ct))) {
691
- const output = await response.json();
692
- if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
693
- throw new Error(
694
- `Server ${args.model} does not seem to support chat completion. Error: ${JSON.stringify(output.error)}`
695
- );
696
- }
697
- if (output.error || output.detail) {
698
- throw new Error(JSON.stringify(output.error ?? output.detail));
699
- } else {
700
- throw new Error(output);
701
- }
702
- }
703
- const message = contentType?.startsWith("text/plain;") ? await response.text() : void 0;
704
- throw new Error(message ?? "An error occurred while fetching the blob");
705
- }
706
- if (response.headers.get("Content-Type")?.startsWith("application/json")) {
707
- return await response.json();
708
- }
709
- return await response.blob();
710
- }
711
-
712
1306
  // src/vendor/fetch-event-source/parse.ts
713
1307
  function getLines(onLine) {
714
1308
  let buffer;
@@ -808,12 +1402,44 @@ function newMessage() {
808
1402
  };
809
1403
  }
810
1404
 
811
- // src/tasks/custom/streamingRequest.ts
812
- async function* streamingRequest(args, options) {
1405
+ // src/utils/request.ts
1406
+ async function innerRequest(args, options) {
1407
+ const { url, info } = await makeRequestOptions(args, options);
1408
+ const response = await (options?.fetch ?? fetch)(url, info);
1409
+ const requestContext = { url, info };
1410
+ if (options?.retry_on_error !== false && response.status === 503) {
1411
+ return innerRequest(args, options);
1412
+ }
1413
+ if (!response.ok) {
1414
+ const contentType = response.headers.get("Content-Type");
1415
+ if (["application/json", "application/problem+json"].some((ct) => contentType?.startsWith(ct))) {
1416
+ const output = await response.json();
1417
+ if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
1418
+ throw new Error(
1419
+ `Server ${args.model} does not seem to support chat completion. Error: ${JSON.stringify(output.error)}`
1420
+ );
1421
+ }
1422
+ if (output.error || output.detail) {
1423
+ throw new Error(JSON.stringify(output.error ?? output.detail));
1424
+ } else {
1425
+ throw new Error(output);
1426
+ }
1427
+ }
1428
+ const message = contentType?.startsWith("text/plain;") ? await response.text() : void 0;
1429
+ throw new Error(message ?? "An error occurred while fetching the blob");
1430
+ }
1431
+ if (response.headers.get("Content-Type")?.startsWith("application/json")) {
1432
+ const data = await response.json();
1433
+ return { data, requestContext };
1434
+ }
1435
+ const blob = await response.blob();
1436
+ return { data: blob, requestContext };
1437
+ }
1438
+ async function* innerStreamingRequest(args, options) {
813
1439
  const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
814
1440
  const response = await (options?.fetch ?? fetch)(url, info);
815
1441
  if (options?.retry_on_error !== false && response.status === 503) {
816
- return yield* streamingRequest(args, options);
1442
+ return yield* innerStreamingRequest(args, options);
817
1443
  }
818
1444
  if (!response.ok) {
819
1445
  if (response.headers.get("Content-Type")?.startsWith("application/json")) {
@@ -827,6 +1453,9 @@ async function* streamingRequest(args, options) {
827
1453
  if (output.error && "message" in output.error && typeof output.error.message === "string") {
828
1454
  throw new Error(output.error.message);
829
1455
  }
1456
+ if (typeof output.message === "string") {
1457
+ throw new Error(output.message);
1458
+ }
830
1459
  }
831
1460
  throw new Error(`Server response contains error: ${response.status}`);
832
1461
  }
@@ -879,28 +1508,21 @@ async function* streamingRequest(args, options) {
879
1508
  }
880
1509
  }
881
1510
 
882
- // src/utils/pick.ts
883
- function pick(o, props) {
884
- return Object.assign(
885
- {},
886
- ...props.map((prop) => {
887
- if (o[prop] !== void 0) {
888
- return { [prop]: o[prop] };
889
- }
890
- })
1511
+ // src/tasks/custom/request.ts
1512
+ async function request(args, options) {
1513
+ console.warn(
1514
+ "The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
891
1515
  );
1516
+ const result = await innerRequest(args, options);
1517
+ return result.data;
892
1518
  }
893
1519
 
894
- // src/utils/typedInclude.ts
895
- function typedInclude(arr, v) {
896
- return arr.includes(v);
897
- }
898
-
899
- // src/utils/omit.ts
900
- function omit(o, props) {
901
- const propsArr = Array.isArray(props) ? props : [props];
902
- const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
903
- return pick(o, letsKeep);
1520
+ // src/tasks/custom/streamingRequest.ts
1521
+ async function* streamingRequest(args, options) {
1522
+ console.warn(
1523
+ "The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
1524
+ );
1525
+ yield* innerStreamingRequest(args, options);
904
1526
  }
905
1527
 
906
1528
  // src/tasks/audio/utils.ts
@@ -913,16 +1535,24 @@ function preparePayload(args) {
913
1535
 
914
1536
  // src/tasks/audio/audioClassification.ts
915
1537
  async function audioClassification(args, options) {
1538
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
916
1539
  const payload = preparePayload(args);
917
- const res = await request(payload, {
1540
+ const { data: res } = await innerRequest(payload, {
918
1541
  ...options,
919
1542
  task: "audio-classification"
920
1543
  });
921
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
922
- if (!isValidOutput) {
923
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
924
- }
925
- return res;
1544
+ return providerHelper.getResponse(res);
1545
+ }
1546
+
1547
+ // src/tasks/audio/audioToAudio.ts
1548
+ async function audioToAudio(args, options) {
1549
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
1550
+ const payload = preparePayload(args);
1551
+ const { data: res } = await innerRequest(payload, {
1552
+ ...options,
1553
+ task: "audio-to-audio"
1554
+ });
1555
+ return providerHelper.getResponse(res);
926
1556
  }
927
1557
 
928
1558
  // src/utils/base64FromBytes.ts
@@ -940,8 +1570,9 @@ function base64FromBytes(arr) {
940
1570
 
941
1571
  // src/tasks/audio/automaticSpeechRecognition.ts
942
1572
  async function automaticSpeechRecognition(args, options) {
1573
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
943
1574
  const payload = await buildPayload(args);
944
- const res = await request(payload, {
1575
+ const { data: res } = await innerRequest(payload, {
945
1576
  ...options,
946
1577
  task: "automatic-speech-recognition"
947
1578
  });
@@ -949,9 +1580,8 @@ async function automaticSpeechRecognition(args, options) {
949
1580
  if (!isValidOutput) {
950
1581
  throw new InferenceOutputError("Expected {text: string}");
951
1582
  }
952
- return res;
1583
+ return providerHelper.getResponse(res);
953
1584
  }
954
- var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
955
1585
  async function buildPayload(args) {
956
1586
  if (args.provider === "fal-ai") {
957
1587
  const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : void 0;
@@ -976,57 +1606,17 @@ async function buildPayload(args) {
976
1606
  } else {
977
1607
  return preparePayload(args);
978
1608
  }
979
- }
980
-
981
- // src/tasks/audio/textToSpeech.ts
982
- async function textToSpeech(args, options) {
983
- const payload = args.provider === "replicate" ? {
984
- ...omit(args, ["inputs", "parameters"]),
985
- ...args.parameters,
986
- text: args.inputs
987
- } : args;
988
- const res = await request(payload, {
989
- ...options,
990
- task: "text-to-speech"
991
- });
992
- if (res instanceof Blob) {
993
- return res;
994
- }
995
- if (res && typeof res === "object") {
996
- if ("output" in res) {
997
- if (typeof res.output === "string") {
998
- const urlResponse = await fetch(res.output);
999
- const blob = await urlResponse.blob();
1000
- return blob;
1001
- } else if (Array.isArray(res.output)) {
1002
- const urlResponse = await fetch(res.output[0]);
1003
- const blob = await urlResponse.blob();
1004
- return blob;
1005
- }
1006
- }
1007
- }
1008
- throw new InferenceOutputError("Expected Blob or object with output");
1009
- }
1010
-
1011
- // src/tasks/audio/audioToAudio.ts
1012
- async function audioToAudio(args, options) {
1013
- const payload = preparePayload(args);
1014
- const res = await request(payload, {
1609
+ }
1610
+
1611
+ // src/tasks/audio/textToSpeech.ts
1612
+ async function textToSpeech(args, options) {
1613
+ const provider = args.provider ?? "hf-inference";
1614
+ const providerHelper = getProviderHelper(provider, "text-to-speech");
1615
+ const { data: res } = await innerRequest(args, {
1015
1616
  ...options,
1016
- task: "audio-to-audio"
1617
+ task: "text-to-speech"
1017
1618
  });
1018
- return validateOutput(res);
1019
- }
1020
- function validateOutput(output) {
1021
- if (!Array.isArray(output)) {
1022
- throw new InferenceOutputError("Expected Array");
1023
- }
1024
- if (!output.every((elem) => {
1025
- return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
1026
- })) {
1027
- throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
1028
- }
1029
- return output;
1619
+ return providerHelper.getResponse(res);
1030
1620
  }
1031
1621
 
1032
1622
  // src/tasks/cv/utils.ts
@@ -1036,183 +1626,95 @@ function preparePayload2(args) {
1036
1626
 
1037
1627
  // src/tasks/cv/imageClassification.ts
1038
1628
  async function imageClassification(args, options) {
1629
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
1039
1630
  const payload = preparePayload2(args);
1040
- const res = await request(payload, {
1631
+ const { data: res } = await innerRequest(payload, {
1041
1632
  ...options,
1042
1633
  task: "image-classification"
1043
1634
  });
1044
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
1045
- if (!isValidOutput) {
1046
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
1047
- }
1048
- return res;
1635
+ return providerHelper.getResponse(res);
1049
1636
  }
1050
1637
 
1051
1638
  // src/tasks/cv/imageSegmentation.ts
1052
1639
  async function imageSegmentation(args, options) {
1640
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
1053
1641
  const payload = preparePayload2(args);
1054
- const res = await request(payload, {
1642
+ const { data: res } = await innerRequest(payload, {
1055
1643
  ...options,
1056
1644
  task: "image-segmentation"
1057
1645
  });
1058
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
1059
- if (!isValidOutput) {
1060
- throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
1646
+ return providerHelper.getResponse(res);
1647
+ }
1648
+
1649
+ // src/tasks/cv/imageToImage.ts
1650
+ async function imageToImage(args, options) {
1651
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
1652
+ let reqArgs;
1653
+ if (!args.parameters) {
1654
+ reqArgs = {
1655
+ accessToken: args.accessToken,
1656
+ model: args.model,
1657
+ data: args.inputs
1658
+ };
1659
+ } else {
1660
+ reqArgs = {
1661
+ ...args,
1662
+ inputs: base64FromBytes(
1663
+ new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await args.inputs.arrayBuffer())
1664
+ )
1665
+ };
1061
1666
  }
1062
- return res;
1667
+ const { data: res } = await innerRequest(reqArgs, {
1668
+ ...options,
1669
+ task: "image-to-image"
1670
+ });
1671
+ return providerHelper.getResponse(res);
1063
1672
  }
1064
1673
 
1065
1674
  // src/tasks/cv/imageToText.ts
1066
1675
  async function imageToText(args, options) {
1676
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
1067
1677
  const payload = preparePayload2(args);
1068
- const res = (await request(payload, {
1678
+ const { data: res } = await innerRequest(payload, {
1069
1679
  ...options,
1070
1680
  task: "image-to-text"
1071
- }))?.[0];
1072
- if (typeof res?.generated_text !== "string") {
1073
- throw new InferenceOutputError("Expected {generated_text: string}");
1074
- }
1075
- return res;
1681
+ });
1682
+ return providerHelper.getResponse(res[0]);
1076
1683
  }
1077
1684
 
1078
1685
  // src/tasks/cv/objectDetection.ts
1079
1686
  async function objectDetection(args, options) {
1687
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
1080
1688
  const payload = preparePayload2(args);
1081
- const res = await request(payload, {
1689
+ const { data: res } = await innerRequest(payload, {
1082
1690
  ...options,
1083
1691
  task: "object-detection"
1084
1692
  });
1085
- const isValidOutput = Array.isArray(res) && res.every(
1086
- (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
1087
- );
1088
- if (!isValidOutput) {
1089
- throw new InferenceOutputError(
1090
- "Expected Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>"
1091
- );
1092
- }
1093
- return res;
1693
+ return providerHelper.getResponse(res);
1094
1694
  }
1095
1695
 
1096
1696
  // src/tasks/cv/textToImage.ts
1097
- function getResponseFormatArg(provider) {
1098
- switch (provider) {
1099
- case "fal-ai":
1100
- return { sync_mode: true };
1101
- case "nebius":
1102
- return { response_format: "b64_json" };
1103
- case "replicate":
1104
- return void 0;
1105
- case "together":
1106
- return { response_format: "base64" };
1107
- default:
1108
- return void 0;
1109
- }
1110
- }
1111
1697
  async function textToImage(args, options) {
1112
- const payload = !args.provider || args.provider === "hf-inference" || args.provider === "sambanova" ? args : {
1113
- ...omit(args, ["inputs", "parameters"]),
1114
- ...args.parameters,
1115
- ...getResponseFormatArg(args.provider),
1116
- prompt: args.inputs
1117
- };
1118
- const res = await request(payload, {
1698
+ const provider = args.provider ?? "hf-inference";
1699
+ const providerHelper = getProviderHelper(provider, "text-to-image");
1700
+ const { data: res } = await innerRequest(args, {
1119
1701
  ...options,
1120
1702
  task: "text-to-image"
1121
1703
  });
1122
- if (res && typeof res === "object") {
1123
- if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") {
1124
- return await pollBflResponse(res.polling_url, options?.outputType);
1125
- }
1126
- if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
1127
- if (options?.outputType === "url") {
1128
- return res.images[0].url;
1129
- } else {
1130
- const image = await fetch(res.images[0].url);
1131
- return await image.blob();
1132
- }
1133
- }
1134
- if (args.provider === "hyperbolic" && "images" in res && Array.isArray(res.images) && res.images[0] && typeof res.images[0].image === "string") {
1135
- if (options?.outputType === "url") {
1136
- return `data:image/jpeg;base64,${res.images[0].image}`;
1137
- }
1138
- const base64Response = await fetch(`data:image/jpeg;base64,${res.images[0].image}`);
1139
- return await base64Response.blob();
1140
- }
1141
- if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
1142
- const base64Data = res.data[0].b64_json;
1143
- if (options?.outputType === "url") {
1144
- return `data:image/jpeg;base64,${base64Data}`;
1145
- }
1146
- const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
1147
- return await base64Response.blob();
1148
- }
1149
- if ("output" in res && Array.isArray(res.output)) {
1150
- if (options?.outputType === "url") {
1151
- return res.output[0];
1152
- }
1153
- const urlResponse = await fetch(res.output[0]);
1154
- const blob = await urlResponse.blob();
1155
- return blob;
1156
- }
1157
- }
1158
- const isValidOutput = res && res instanceof Blob;
1159
- if (!isValidOutput) {
1160
- throw new InferenceOutputError("Expected Blob");
1161
- }
1162
- if (options?.outputType === "url") {
1163
- const b64 = await res.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
1164
- return `data:image/jpeg;base64,${b64}`;
1165
- }
1166
- return res;
1167
- }
1168
- async function pollBflResponse(url, outputType) {
1169
- const urlObj = new URL(url);
1170
- for (let step = 0; step < 5; step++) {
1171
- await delay(1e3);
1172
- console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
1173
- urlObj.searchParams.set("attempt", step.toString(10));
1174
- const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
1175
- if (!resp.ok) {
1176
- throw new InferenceOutputError("Failed to fetch result from black forest labs API");
1177
- }
1178
- const payload = await resp.json();
1179
- if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
1180
- if (outputType === "url") {
1181
- return payload.result.sample;
1182
- }
1183
- const image = await fetch(payload.result.sample);
1184
- return await image.blob();
1185
- }
1186
- }
1187
- throw new InferenceOutputError("Failed to fetch result from black forest labs API");
1704
+ const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-image" });
1705
+ return providerHelper.getResponse(res, url, info.headers, options?.outputType);
1188
1706
  }
1189
1707
 
1190
- // src/tasks/cv/imageToImage.ts
1191
- async function imageToImage(args, options) {
1192
- let reqArgs;
1193
- if (!args.parameters) {
1194
- reqArgs = {
1195
- accessToken: args.accessToken,
1196
- model: args.model,
1197
- data: args.inputs
1198
- };
1199
- } else {
1200
- reqArgs = {
1201
- ...args,
1202
- inputs: base64FromBytes(
1203
- new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await args.inputs.arrayBuffer())
1204
- )
1205
- };
1206
- }
1207
- const res = await request(reqArgs, {
1708
+ // src/tasks/cv/textToVideo.ts
1709
+ async function textToVideo(args, options) {
1710
+ const provider = args.provider ?? "hf-inference";
1711
+ const providerHelper = getProviderHelper(provider, "text-to-video");
1712
+ const { data: response } = await innerRequest(args, {
1208
1713
  ...options,
1209
- task: "image-to-image"
1714
+ task: "text-to-video"
1210
1715
  });
1211
- const isValidOutput = res && res instanceof Blob;
1212
- if (!isValidOutput) {
1213
- throw new InferenceOutputError("Expected Blob");
1214
- }
1215
- return res;
1716
+ const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-video" });
1717
+ return providerHelper.getResponse(response, url, info.headers);
1216
1718
  }
1217
1719
 
1218
1720
  // src/tasks/cv/zeroShotImageClassification.ts
@@ -1238,235 +1740,117 @@ async function preparePayload3(args) {
1238
1740
  }
1239
1741
  }
1240
1742
  async function zeroShotImageClassification(args, options) {
1743
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
1241
1744
  const payload = await preparePayload3(args);
1242
- const res = await request(payload, {
1745
+ const { data: res } = await innerRequest(payload, {
1243
1746
  ...options,
1244
1747
  task: "zero-shot-image-classification"
1245
1748
  });
1246
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
1247
- if (!isValidOutput) {
1248
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
1249
- }
1250
- return res;
1749
+ return providerHelper.getResponse(res);
1251
1750
  }
1252
1751
 
1253
- // src/tasks/cv/textToVideo.ts
1254
- var SUPPORTED_PROVIDERS = ["fal-ai", "novita", "replicate"];
1255
- async function textToVideo(args, options) {
1256
- if (!args.provider || !typedInclude(SUPPORTED_PROVIDERS, args.provider)) {
1257
- throw new Error(
1258
- `textToVideo inference is only supported for the following providers: ${SUPPORTED_PROVIDERS.join(", ")}`
1259
- );
1260
- }
1261
- const payload = args.provider === "fal-ai" || args.provider === "replicate" || args.provider === "novita" ? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs } : args;
1262
- const res = await request(payload, {
1752
+ // src/tasks/nlp/chatCompletion.ts
1753
+ async function chatCompletion(args, options) {
1754
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
1755
+ const { data: response } = await innerRequest(args, {
1263
1756
  ...options,
1264
- task: "text-to-video"
1757
+ task: "conversational"
1758
+ });
1759
+ return providerHelper.getResponse(response);
1760
+ }
1761
+
1762
+ // src/tasks/nlp/chatCompletionStream.ts
1763
+ async function* chatCompletionStream(args, options) {
1764
+ yield* innerStreamingRequest(args, {
1765
+ ...options,
1766
+ task: "conversational"
1265
1767
  });
1266
- if (args.provider === "fal-ai") {
1267
- const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-video" });
1268
- return await pollFalResponse(res, url, info.headers);
1269
- } else if (args.provider === "novita") {
1270
- const isValidOutput = typeof res === "object" && !!res && "video" in res && typeof res.video === "object" && !!res.video && "video_url" in res.video && typeof res.video.video_url === "string" && isUrl(res.video.video_url);
1271
- if (!isValidOutput) {
1272
- throw new InferenceOutputError("Expected { video: { video_url: string } }");
1273
- }
1274
- const urlResponse = await fetch(res.video.video_url);
1275
- return await urlResponse.blob();
1276
- } else {
1277
- const isValidOutput = typeof res === "object" && !!res && "output" in res && typeof res.output === "string" && isUrl(res.output);
1278
- if (!isValidOutput) {
1279
- throw new InferenceOutputError("Expected { output: string }");
1280
- }
1281
- const urlResponse = await fetch(res.output);
1282
- return await urlResponse.blob();
1283
- }
1284
1768
  }
1285
1769
 
1286
1770
  // src/tasks/nlp/featureExtraction.ts
1287
1771
  async function featureExtraction(args, options) {
1288
- const res = await request(args, {
1772
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "feature-extraction");
1773
+ const { data: res } = await innerRequest(args, {
1289
1774
  ...options,
1290
1775
  task: "feature-extraction"
1291
1776
  });
1292
- let isValidOutput = true;
1293
- const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
1294
- if (curDepth > maxDepth)
1295
- return false;
1296
- if (arr.every((x) => Array.isArray(x))) {
1297
- return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
1298
- } else {
1299
- return arr.every((x) => typeof x === "number");
1300
- }
1301
- };
1302
- isValidOutput = Array.isArray(res) && isNumArrayRec(res, 3, 0);
1303
- if (!isValidOutput) {
1304
- throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
1305
- }
1306
- return res;
1777
+ return providerHelper.getResponse(res);
1307
1778
  }
1308
1779
 
1309
1780
  // src/tasks/nlp/fillMask.ts
1310
1781
  async function fillMask(args, options) {
1311
- const res = await request(args, {
1782
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "fill-mask");
1783
+ const { data: res } = await innerRequest(args, {
1312
1784
  ...options,
1313
1785
  task: "fill-mask"
1314
1786
  });
1315
- const isValidOutput = Array.isArray(res) && res.every(
1316
- (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
1317
- );
1318
- if (!isValidOutput) {
1319
- throw new InferenceOutputError(
1320
- "Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
1321
- );
1322
- }
1323
- return res;
1787
+ return providerHelper.getResponse(res);
1324
1788
  }
1325
1789
 
1326
1790
  // src/tasks/nlp/questionAnswering.ts
1327
1791
  async function questionAnswering(args, options) {
1328
- const res = await request(args, {
1792
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering");
1793
+ const { data: res } = await innerRequest(args, {
1329
1794
  ...options,
1330
1795
  task: "question-answering"
1331
1796
  });
1332
- const isValidOutput = Array.isArray(res) ? res.every(
1333
- (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
1334
- ) : typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";
1335
- if (!isValidOutput) {
1336
- throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
1337
- }
1338
- return Array.isArray(res) ? res[0] : res;
1797
+ return providerHelper.getResponse(res);
1339
1798
  }
1340
1799
 
1341
1800
  // src/tasks/nlp/sentenceSimilarity.ts
1342
1801
  async function sentenceSimilarity(args, options) {
1343
- const res = await request(prepareInput(args), {
1802
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "sentence-similarity");
1803
+ const { data: res } = await innerRequest(args, {
1344
1804
  ...options,
1345
1805
  task: "sentence-similarity"
1346
1806
  });
1347
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
1348
- if (!isValidOutput) {
1349
- throw new InferenceOutputError("Expected number[]");
1350
- }
1351
- return res;
1352
- }
1353
- function prepareInput(args) {
1354
- return {
1355
- ...omit(args, ["inputs", "parameters"]),
1356
- inputs: { ...omit(args.inputs, "sourceSentence") },
1357
- parameters: { source_sentence: args.inputs.sourceSentence, ...args.parameters }
1358
- };
1807
+ return providerHelper.getResponse(res);
1359
1808
  }
1360
1809
 
1361
1810
  // src/tasks/nlp/summarization.ts
1362
1811
  async function summarization(args, options) {
1363
- const res = await request(args, {
1812
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization");
1813
+ const { data: res } = await innerRequest(args, {
1364
1814
  ...options,
1365
1815
  task: "summarization"
1366
1816
  });
1367
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string");
1368
- if (!isValidOutput) {
1369
- throw new InferenceOutputError("Expected Array<{summary_text: string}>");
1370
- }
1371
- return res?.[0];
1817
+ return providerHelper.getResponse(res);
1372
1818
  }
1373
1819
 
1374
1820
  // src/tasks/nlp/tableQuestionAnswering.ts
1375
1821
  async function tableQuestionAnswering(args, options) {
1376
- const res = await request(args, {
1822
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
1823
+ const { data: res } = await innerRequest(args, {
1377
1824
  ...options,
1378
1825
  task: "table-question-answering"
1379
1826
  });
1380
- const isValidOutput = Array.isArray(res) ? res.every((elem) => validate(elem)) : validate(res);
1381
- if (!isValidOutput) {
1382
- throw new InferenceOutputError(
1383
- "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
1384
- );
1385
- }
1386
- return Array.isArray(res) ? res[0] : res;
1387
- }
1388
- function validate(elem) {
1389
- return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
1390
- (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
1391
- );
1827
+ return providerHelper.getResponse(res);
1392
1828
  }
1393
1829
 
1394
1830
  // src/tasks/nlp/textClassification.ts
1395
1831
  async function textClassification(args, options) {
1396
- const res = (await request(args, {
1832
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
1833
+ const { data: res } = await innerRequest(args, {
1397
1834
  ...options,
1398
1835
  task: "text-classification"
1399
- }))?.[0];
1400
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number");
1401
- if (!isValidOutput) {
1402
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
1403
- }
1404
- return res;
1405
- }
1406
-
1407
- // src/utils/toArray.ts
1408
- function toArray(obj) {
1409
- if (Array.isArray(obj)) {
1410
- return obj;
1411
- }
1412
- return [obj];
1836
+ });
1837
+ return providerHelper.getResponse(res);
1413
1838
  }
1414
1839
 
1415
1840
  // src/tasks/nlp/textGeneration.ts
1416
1841
  async function textGeneration(args, options) {
1417
- if (args.provider === "together") {
1418
- args.prompt = args.inputs;
1419
- const raw = await request(args, {
1420
- ...options,
1421
- task: "text-generation"
1422
- });
1423
- const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
1424
- if (!isValidOutput) {
1425
- throw new InferenceOutputError("Expected ChatCompletionOutput");
1426
- }
1427
- const completion = raw.choices[0];
1428
- return {
1429
- generated_text: completion.text
1430
- };
1431
- } else if (args.provider === "hyperbolic") {
1432
- const payload = {
1433
- messages: [{ content: args.inputs, role: "user" }],
1434
- ...args.parameters ? {
1435
- max_tokens: args.parameters.max_new_tokens,
1436
- ...omit(args.parameters, "max_new_tokens")
1437
- } : void 0,
1438
- ...omit(args, ["inputs", "parameters"])
1439
- };
1440
- const raw = await request(payload, {
1441
- ...options,
1442
- task: "text-generation"
1443
- });
1444
- const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
1445
- if (!isValidOutput) {
1446
- throw new InferenceOutputError("Expected ChatCompletionOutput");
1447
- }
1448
- const completion = raw.choices[0];
1449
- return {
1450
- generated_text: completion.message.content
1451
- };
1452
- } else {
1453
- const res = toArray(
1454
- await request(args, {
1455
- ...options,
1456
- task: "text-generation"
1457
- })
1458
- );
1459
- const isValidOutput = Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
1460
- if (!isValidOutput) {
1461
- throw new InferenceOutputError("Expected Array<{generated_text: string}>");
1462
- }
1463
- return res?.[0];
1464
- }
1842
+ const provider = args.provider ?? "hf-inference";
1843
+ const providerHelper = getProviderHelper(provider, "text-generation");
1844
+ const { data: response } = await innerRequest(args, {
1845
+ ...options,
1846
+ task: "text-generation"
1847
+ });
1848
+ return providerHelper.getResponse(response);
1465
1849
  }
1466
1850
 
1467
1851
  // src/tasks/nlp/textGenerationStream.ts
1468
1852
  async function* textGenerationStream(args, options) {
1469
- yield* streamingRequest(args, {
1853
+ yield* innerStreamingRequest(args, {
1470
1854
  ...options,
1471
1855
  task: "text-generation"
1472
1856
  });
@@ -1474,79 +1858,37 @@ async function* textGenerationStream(args, options) {
1474
1858
 
1475
1859
  // src/tasks/nlp/tokenClassification.ts
1476
1860
  async function tokenClassification(args, options) {
1477
- const res = toArray(
1478
- await request(args, {
1479
- ...options,
1480
- task: "token-classification"
1481
- })
1482
- );
1483
- const isValidOutput = Array.isArray(res) && res.every(
1484
- (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
1485
- );
1486
- if (!isValidOutput) {
1487
- throw new InferenceOutputError(
1488
- "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
1489
- );
1490
- }
1491
- return res;
1861
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "token-classification");
1862
+ const { data: res } = await innerRequest(args, {
1863
+ ...options,
1864
+ task: "token-classification"
1865
+ });
1866
+ return providerHelper.getResponse(res);
1492
1867
  }
1493
1868
 
1494
1869
  // src/tasks/nlp/translation.ts
1495
1870
  async function translation(args, options) {
1496
- const res = await request(args, {
1871
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation");
1872
+ const { data: res } = await innerRequest(args, {
1497
1873
  ...options,
1498
1874
  task: "translation"
1499
1875
  });
1500
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string");
1501
- if (!isValidOutput) {
1502
- throw new InferenceOutputError("Expected type Array<{translation_text: string}>");
1503
- }
1504
- return res?.length === 1 ? res?.[0] : res;
1876
+ return providerHelper.getResponse(res);
1505
1877
  }
1506
1878
 
1507
1879
  // src/tasks/nlp/zeroShotClassification.ts
1508
1880
  async function zeroShotClassification(args, options) {
1509
- const res = toArray(
1510
- await request(args, {
1511
- ...options,
1512
- task: "zero-shot-classification"
1513
- })
1514
- );
1515
- const isValidOutput = Array.isArray(res) && res.every(
1516
- (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
1517
- );
1518
- if (!isValidOutput) {
1519
- throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
1520
- }
1521
- return res;
1522
- }
1523
-
1524
- // src/tasks/nlp/chatCompletion.ts
1525
- async function chatCompletion(args, options) {
1526
- const res = await request(args, {
1527
- ...options,
1528
- task: "text-generation",
1529
- chatCompletion: true
1530
- });
1531
- const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
1532
- (res.system_fingerprint === void 0 || res.system_fingerprint === null || typeof res.system_fingerprint === "string") && typeof res?.usage === "object";
1533
- if (!isValidOutput) {
1534
- throw new InferenceOutputError("Expected ChatCompletionOutput");
1535
- }
1536
- return res;
1537
- }
1538
-
1539
- // src/tasks/nlp/chatCompletionStream.ts
1540
- async function* chatCompletionStream(args, options) {
1541
- yield* streamingRequest(args, {
1881
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification");
1882
+ const { data: res } = await innerRequest(args, {
1542
1883
  ...options,
1543
- task: "text-generation",
1544
- chatCompletion: true
1884
+ task: "zero-shot-classification"
1545
1885
  });
1886
+ return providerHelper.getResponse(res);
1546
1887
  }
1547
1888
 
1548
1889
  // src/tasks/multimodal/documentQuestionAnswering.ts
1549
1890
  async function documentQuestionAnswering(args, options) {
1891
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "document-question-answering");
1550
1892
  const reqArgs = {
1551
1893
  ...args,
1552
1894
  inputs: {
@@ -1555,23 +1897,19 @@ async function documentQuestionAnswering(args, options) {
1555
1897
  image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer()))
1556
1898
  }
1557
1899
  };
1558
- const res = toArray(
1559
- await request(reqArgs, {
1900
+ const { data: res } = await innerRequest(
1901
+ reqArgs,
1902
+ {
1560
1903
  ...options,
1561
1904
  task: "document-question-answering"
1562
- })
1563
- );
1564
- const isValidOutput = Array.isArray(res) && res.every(
1565
- (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
1905
+ }
1566
1906
  );
1567
- if (!isValidOutput) {
1568
- throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>");
1569
- }
1570
- return res[0];
1907
+ return providerHelper.getResponse(res);
1571
1908
  }
1572
1909
 
1573
1910
  // src/tasks/multimodal/visualQuestionAnswering.ts
1574
1911
  async function visualQuestionAnswering(args, options) {
1912
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "visual-question-answering");
1575
1913
  const reqArgs = {
1576
1914
  ...args,
1577
1915
  inputs: {
@@ -1580,43 +1918,31 @@ async function visualQuestionAnswering(args, options) {
1580
1918
  image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer()))
1581
1919
  }
1582
1920
  };
1583
- const res = await request(reqArgs, {
1921
+ const { data: res } = await innerRequest(reqArgs, {
1584
1922
  ...options,
1585
1923
  task: "visual-question-answering"
1586
1924
  });
1587
- const isValidOutput = Array.isArray(res) && res.every(
1588
- (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
1589
- );
1590
- if (!isValidOutput) {
1591
- throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
1592
- }
1593
- return res[0];
1925
+ return providerHelper.getResponse(res);
1594
1926
  }
1595
1927
 
1596
- // src/tasks/tabular/tabularRegression.ts
1597
- async function tabularRegression(args, options) {
1598
- const res = await request(args, {
1928
+ // src/tasks/tabular/tabularClassification.ts
1929
+ async function tabularClassification(args, options) {
1930
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification");
1931
+ const { data: res } = await innerRequest(args, {
1599
1932
  ...options,
1600
- task: "tabular-regression"
1933
+ task: "tabular-classification"
1601
1934
  });
1602
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
1603
- if (!isValidOutput) {
1604
- throw new InferenceOutputError("Expected number[]");
1605
- }
1606
- return res;
1935
+ return providerHelper.getResponse(res);
1607
1936
  }
1608
1937
 
1609
- // src/tasks/tabular/tabularClassification.ts
1610
- async function tabularClassification(args, options) {
1611
- const res = await request(args, {
1938
+ // src/tasks/tabular/tabularRegression.ts
1939
+ async function tabularRegression(args, options) {
1940
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression");
1941
+ const { data: res } = await innerRequest(args, {
1612
1942
  ...options,
1613
- task: "tabular-classification"
1943
+ task: "tabular-regression"
1614
1944
  });
1615
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
1616
- if (!isValidOutput) {
1617
- throw new InferenceOutputError("Expected number[]");
1618
- }
1619
- return res;
1945
+ return providerHelper.getResponse(res);
1620
1946
  }
1621
1947
 
1622
1948
  // src/InferenceClient.ts
@@ -1685,11 +2011,11 @@ __export(snippets_exports, {
1685
2011
  });
1686
2012
 
1687
2013
  // src/snippets/getInferenceSnippets.ts
2014
+ import { Template } from "@huggingface/jinja";
1688
2015
  import {
1689
- inferenceSnippetLanguages,
1690
- getModelInputSnippet
2016
+ getModelInputSnippet,
2017
+ inferenceSnippetLanguages
1691
2018
  } from "@huggingface/tasks";
1692
- import { Template } from "@huggingface/jinja";
1693
2019
 
1694
2020
  // src/snippets/templates.exported.ts
1695
2021
  var templates = {
@@ -1699,7 +2025,7 @@ var templates = {
1699
2025
  "basicAudio": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "audio/flac"\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
1700
2026
  "basicImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "image/jpeg"\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
1701
2027
  "textToAudio": '{% if model.library_name == "transformers" %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n // Returns a byte object of the Audio wavform. Use it directly!\n});\n{% else %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});\n{% endif %} ',
1702
- "textToImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n // Use image\n});',
2028
+ "textToImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\n\nquery({ {{ providerInputs.asTsString }} }).then((response) => {\n // Use image\n});',
1703
2029
  "zeroShotClassification": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({\n inputs: {{ providerInputs.asObj.inputs }},\n parameters: { candidate_labels: ["refund", "legal", "faq"] }\n}).then((response) => {\n console.log(JSON.stringify(response));\n});'
1704
2030
  },
1705
2031
  "huggingface.js": {
@@ -1732,7 +2058,7 @@ const image = await client.textToVideo({
1732
2058
  },
1733
2059
  "openai": {
1734
2060
  "conversational": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\nconst chatCompletion = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n});\n\nconsole.log(chatCompletion.choices[0].message);',
1735
- "conversationalStream": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\nlet out = "";\n\nconst stream = await client.chat.completions.create({\n provider: "{{ provider }}",\n model: "{{ model.id }}",\n{{ inputs.asTsString }}\n});\n\nfor await (const chunk of stream) {\n if (chunk.choices && chunk.choices.length > 0) {\n const newContent = chunk.choices[0].delta.content;\n out += newContent;\n console.log(newContent);\n } \n}'
2061
+ "conversationalStream": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\nconst stream = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n stream: true,\n});\n\nfor await (const chunk of stream) {\n process.stdout.write(chunk.choices[0]?.delta?.content || "");\n}'
1736
2062
  }
1737
2063
  },
1738
2064
  "python": {
@@ -1864,15 +2190,23 @@ var HF_JS_METHODS = {
1864
2190
  };
1865
2191
  var snippetGenerator = (templateName, inputPreparationFn) => {
1866
2192
  return (model, accessToken, provider, providerModelId, opts) => {
2193
+ let task = model.pipeline_tag;
1867
2194
  if (model.pipeline_tag && ["text-generation", "image-text-to-text"].includes(model.pipeline_tag) && model.tags.includes("conversational")) {
1868
2195
  templateName = opts?.streaming ? "conversationalStream" : "conversational";
1869
2196
  inputPreparationFn = prepareConversationalInput;
2197
+ task = "conversational";
1870
2198
  }
1871
2199
  const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) };
1872
2200
  const request2 = makeRequestOptionsFromResolvedModel(
1873
2201
  providerModelId ?? model.id,
1874
- { accessToken, provider, ...inputs },
1875
- { chatCompletion: templateName.includes("conversational"), task: model.pipeline_tag }
2202
+ {
2203
+ accessToken,
2204
+ provider,
2205
+ ...inputs
2206
+ },
2207
+ {
2208
+ task
2209
+ }
1876
2210
  );
1877
2211
  let providerInputs = inputs;
1878
2212
  const bodyAsObj = request2.info.body;
@@ -1959,7 +2293,7 @@ var prepareConversationalInput = (model, opts) => {
1959
2293
  return {
1960
2294
  messages: opts?.messages ?? getModelInputSnippet(model),
1961
2295
  ...opts?.temperature ? { temperature: opts?.temperature } : void 0,
1962
- max_tokens: opts?.max_tokens ?? 500,
2296
+ max_tokens: opts?.max_tokens ?? 512,
1963
2297
  ...opts?.top_p ? { top_p: opts?.top_p } : void 0
1964
2298
  };
1965
2299
  };