@huggingface/inference 3.7.0 → 3.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141) hide show
  1. package/dist/index.cjs +1369 -941
  2. package/dist/index.js +1371 -943
  3. package/dist/src/lib/getInferenceProviderMapping.d.ts +21 -0
  4. package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -0
  5. package/dist/src/lib/getProviderHelper.d.ts +37 -0
  6. package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
  7. package/dist/src/lib/makeRequestOptions.d.ts +5 -5
  8. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  9. package/dist/src/providers/black-forest-labs.d.ts +14 -18
  10. package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
  11. package/dist/src/providers/cerebras.d.ts +4 -2
  12. package/dist/src/providers/cerebras.d.ts.map +1 -1
  13. package/dist/src/providers/cohere.d.ts +5 -2
  14. package/dist/src/providers/cohere.d.ts.map +1 -1
  15. package/dist/src/providers/consts.d.ts +2 -3
  16. package/dist/src/providers/consts.d.ts.map +1 -1
  17. package/dist/src/providers/fal-ai.d.ts +50 -3
  18. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  19. package/dist/src/providers/fireworks-ai.d.ts +5 -2
  20. package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
  21. package/dist/src/providers/hf-inference.d.ts +126 -2
  22. package/dist/src/providers/hf-inference.d.ts.map +1 -1
  23. package/dist/src/providers/hyperbolic.d.ts +31 -2
  24. package/dist/src/providers/hyperbolic.d.ts.map +1 -1
  25. package/dist/src/providers/nebius.d.ts +20 -18
  26. package/dist/src/providers/nebius.d.ts.map +1 -1
  27. package/dist/src/providers/novita.d.ts +21 -18
  28. package/dist/src/providers/novita.d.ts.map +1 -1
  29. package/dist/src/providers/openai.d.ts +4 -2
  30. package/dist/src/providers/openai.d.ts.map +1 -1
  31. package/dist/src/providers/providerHelper.d.ts +182 -0
  32. package/dist/src/providers/providerHelper.d.ts.map +1 -0
  33. package/dist/src/providers/replicate.d.ts +23 -19
  34. package/dist/src/providers/replicate.d.ts.map +1 -1
  35. package/dist/src/providers/sambanova.d.ts +4 -2
  36. package/dist/src/providers/sambanova.d.ts.map +1 -1
  37. package/dist/src/providers/together.d.ts +32 -2
  38. package/dist/src/providers/together.d.ts.map +1 -1
  39. package/dist/src/snippets/getInferenceSnippets.d.ts +2 -1
  40. package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
  41. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
  42. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  43. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  44. package/dist/src/tasks/audio/utils.d.ts +2 -1
  45. package/dist/src/tasks/audio/utils.d.ts.map +1 -1
  46. package/dist/src/tasks/custom/request.d.ts +0 -2
  47. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  48. package/dist/src/tasks/custom/streamingRequest.d.ts +0 -2
  49. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  50. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
  51. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
  52. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
  53. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
  54. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
  55. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  56. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
  57. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
  58. package/dist/src/tasks/index.d.ts +6 -6
  59. package/dist/src/tasks/index.d.ts.map +1 -1
  60. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  61. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
  62. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  63. package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
  64. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  65. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
  66. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  67. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  68. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
  69. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  70. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
  71. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  72. package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
  73. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  74. package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
  75. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  76. package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
  77. package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
  78. package/dist/src/types.d.ts +5 -13
  79. package/dist/src/types.d.ts.map +1 -1
  80. package/dist/src/utils/request.d.ts +3 -2
  81. package/dist/src/utils/request.d.ts.map +1 -1
  82. package/package.json +3 -3
  83. package/src/lib/getInferenceProviderMapping.ts +96 -0
  84. package/src/lib/getProviderHelper.ts +270 -0
  85. package/src/lib/makeRequestOptions.ts +78 -97
  86. package/src/providers/black-forest-labs.ts +73 -22
  87. package/src/providers/cerebras.ts +6 -27
  88. package/src/providers/cohere.ts +9 -28
  89. package/src/providers/consts.ts +5 -2
  90. package/src/providers/fal-ai.ts +224 -77
  91. package/src/providers/fireworks-ai.ts +8 -29
  92. package/src/providers/hf-inference.ts +557 -34
  93. package/src/providers/hyperbolic.ts +107 -29
  94. package/src/providers/nebius.ts +65 -29
  95. package/src/providers/novita.ts +68 -32
  96. package/src/providers/openai.ts +6 -32
  97. package/src/providers/providerHelper.ts +354 -0
  98. package/src/providers/replicate.ts +124 -34
  99. package/src/providers/sambanova.ts +5 -30
  100. package/src/providers/together.ts +92 -28
  101. package/src/snippets/getInferenceSnippets.ts +39 -14
  102. package/src/snippets/templates.exported.ts +25 -25
  103. package/src/tasks/audio/audioClassification.ts +5 -8
  104. package/src/tasks/audio/audioToAudio.ts +4 -27
  105. package/src/tasks/audio/automaticSpeechRecognition.ts +5 -4
  106. package/src/tasks/audio/textToSpeech.ts +5 -29
  107. package/src/tasks/audio/utils.ts +2 -1
  108. package/src/tasks/custom/request.ts +3 -3
  109. package/src/tasks/custom/streamingRequest.ts +4 -3
  110. package/src/tasks/cv/imageClassification.ts +4 -8
  111. package/src/tasks/cv/imageSegmentation.ts +4 -9
  112. package/src/tasks/cv/imageToImage.ts +4 -7
  113. package/src/tasks/cv/imageToText.ts +4 -7
  114. package/src/tasks/cv/objectDetection.ts +4 -19
  115. package/src/tasks/cv/textToImage.ts +9 -137
  116. package/src/tasks/cv/textToVideo.ts +17 -64
  117. package/src/tasks/cv/zeroShotImageClassification.ts +4 -8
  118. package/src/tasks/index.ts +6 -6
  119. package/src/tasks/multimodal/documentQuestionAnswering.ts +4 -19
  120. package/src/tasks/multimodal/visualQuestionAnswering.ts +4 -12
  121. package/src/tasks/nlp/chatCompletion.ts +5 -20
  122. package/src/tasks/nlp/chatCompletionStream.ts +4 -3
  123. package/src/tasks/nlp/featureExtraction.ts +4 -19
  124. package/src/tasks/nlp/fillMask.ts +4 -17
  125. package/src/tasks/nlp/questionAnswering.ts +11 -26
  126. package/src/tasks/nlp/sentenceSimilarity.ts +4 -8
  127. package/src/tasks/nlp/summarization.ts +4 -7
  128. package/src/tasks/nlp/tableQuestionAnswering.ts +10 -30
  129. package/src/tasks/nlp/textClassification.ts +4 -9
  130. package/src/tasks/nlp/textGeneration.ts +11 -79
  131. package/src/tasks/nlp/textGenerationStream.ts +3 -1
  132. package/src/tasks/nlp/tokenClassification.ts +11 -23
  133. package/src/tasks/nlp/translation.ts +4 -7
  134. package/src/tasks/nlp/zeroShotClassification.ts +11 -21
  135. package/src/tasks/tabular/tabularClassification.ts +4 -7
  136. package/src/tasks/tabular/tabularRegression.ts +4 -7
  137. package/src/types.ts +5 -14
  138. package/src/utils/request.ts +7 -4
  139. package/dist/src/lib/getProviderModelId.d.ts +0 -10
  140. package/dist/src/lib/getProviderModelId.d.ts.map +0 -1
  141. package/src/lib/getProviderModelId.ts +0 -74
@@ -0,0 +1,354 @@
1
+ import type {
2
+ AudioClassificationInput,
3
+ AudioClassificationOutput,
4
+ AutomaticSpeechRecognitionInput,
5
+ AutomaticSpeechRecognitionOutput,
6
+ ChatCompletionInput,
7
+ ChatCompletionOutput,
8
+ DocumentQuestionAnsweringInput,
9
+ DocumentQuestionAnsweringOutput,
10
+ FeatureExtractionInput,
11
+ FeatureExtractionOutput,
12
+ FillMaskInput,
13
+ FillMaskOutput,
14
+ ImageClassificationInput,
15
+ ImageClassificationOutput,
16
+ ImageSegmentationInput,
17
+ ImageSegmentationOutput,
18
+ ImageToImageInput,
19
+ ImageToTextInput,
20
+ ImageToTextOutput,
21
+ ObjectDetectionInput,
22
+ ObjectDetectionOutput,
23
+ QuestionAnsweringInput,
24
+ QuestionAnsweringOutput,
25
+ SentenceSimilarityInput,
26
+ SentenceSimilarityOutput,
27
+ SummarizationInput,
28
+ SummarizationOutput,
29
+ TableQuestionAnsweringInput,
30
+ TableQuestionAnsweringOutput,
31
+ TextClassificationOutput,
32
+ TextGenerationInput,
33
+ TextGenerationOutput,
34
+ TextToImageInput,
35
+ TextToSpeechInput,
36
+ TextToVideoInput,
37
+ TokenClassificationInput,
38
+ TokenClassificationOutput,
39
+ TranslationInput,
40
+ TranslationOutput,
41
+ VisualQuestionAnsweringInput,
42
+ VisualQuestionAnsweringOutput,
43
+ ZeroShotClassificationInput,
44
+ ZeroShotClassificationOutput,
45
+ ZeroShotImageClassificationInput,
46
+ ZeroShotImageClassificationOutput,
47
+ } from "@huggingface/tasks";
48
+ import { HF_ROUTER_URL } from "../config";
49
+ import { InferenceOutputError } from "../lib/InferenceOutputError";
50
+ import type { AudioToAudioOutput } from "../tasks/audio/audioToAudio";
51
+ import type { BaseArgs, BodyParams, HeaderParams, InferenceProvider, UrlParams } from "../types";
52
+ import { toArray } from "../utils/toArray";
53
+
54
/**
 * Base class for task-specific provider helpers.
 *
 * A helper builds the URL, headers and body of a request for one provider and
 * converts the provider's raw response into the shape expected for the task.
 */
export abstract class TaskProviderHelper {
	/**
	 * @param provider - Provider identifier; appended to `HF_ROUTER_URL` when the request is not authenticated with a provider key.
	 * @param baseUrl - Base URL used when `authMethod` is `"provider-key"`.
	 * @param clientSideRoutingOnly - Stored as a read-only property; it is not read inside this class.
	 */
	constructor(
		private provider: InferenceProvider,
		private baseUrl: string,
		readonly clientSideRoutingOnly: boolean = false
	) {}

	/**
	 * Return the response in the expected format.
	 * Needs to be implemented in the subclasses.
	 */
	abstract getResponse(
		response: unknown,
		url?: string,
		headers?: HeadersInit,
		outputType?: "url" | "blob"
	): Promise<unknown>;

	/**
	 * Prepare the route for the request.
	 * Needs to be implemented in the subclasses.
	 */
	abstract makeRoute(params: UrlParams): string;

	/**
	 * Prepare the payload for the request.
	 * Needs to be implemented in the subclasses.
	 */
	abstract preparePayload(params: BodyParams): unknown;

	/**
	 * Prepare the base URL for the request.
	 *
	 * @returns `${HF_ROUTER_URL}/${provider}` unless `params.authMethod` is
	 * `"provider-key"`, in which case the `baseUrl` given to the constructor.
	 */
	makeBaseUrl(params: UrlParams): string {
		return params.authMethod !== "provider-key" ? `${HF_ROUTER_URL}/${this.provider}` : this.baseUrl;
	}

	/**
	 * Prepare the body for the request.
	 *
	 * @returns `params.args.data` as-is when it is present and truthy, otherwise
	 * the JSON serialisation of `preparePayload(params)`.
	 */
	makeBody(params: BodyParams): BodyInit {
		if ("data" in params.args && !!params.args.data) {
			return params.args.data as BodyInit;
		}
		return JSON.stringify(this.preparePayload(params));
	}

	/**
	 * Prepare the URL for the request by joining the base URL and the route
	 * with exactly one slash.
	 */
	makeUrl(params: UrlParams): string {
		// Trim slashes on both sides of the join so that a base URL ending in "/"
		// (e.g. a custom endpoint) does not produce "host//route".
		const baseUrl = this.makeBaseUrl(params).replace(/\/+$/, "");
		const route = this.makeRoute(params).replace(/^\/+/, "");
		return `${baseUrl}/${route}`;
	}

	/**
	 * Prepare the headers for the request.
	 *
	 * @param params - Supplies the access token placed in the `Authorization` header.
	 * @param isBinary - When `false`, a JSON `Content-Type` header is added.
	 */
	prepareHeaders(params: HeaderParams, isBinary: boolean): Record<string, string> {
		const headers: Record<string, string> = { Authorization: `Bearer ${params.accessToken}` };
		if (!isBinary) {
			headers["Content-Type"] = "application/json";
		}
		return headers;
	}
}
123
+
124
+ // PER-TASK PROVIDER HELPER INTERFACES
125
+
126
+ // CV Tasks
127
+ export interface TextToImageTaskHelper {
128
+ getResponse(
129
+ response: unknown,
130
+ url?: string,
131
+ headers?: HeadersInit,
132
+ outputType?: "url" | "blob"
133
+ ): Promise<string | Blob>;
134
+ preparePayload(params: BodyParams<TextToImageInput & BaseArgs>): Record<string, unknown>;
135
+ }
136
+
137
+ export interface TextToVideoTaskHelper {
138
+ getResponse(response: unknown, url?: string, headers?: Record<string, string>): Promise<Blob>;
139
+ preparePayload(params: BodyParams<TextToVideoInput & BaseArgs>): Record<string, unknown>;
140
+ }
141
+
142
+ export interface ImageToImageTaskHelper {
143
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<Blob>;
144
+ preparePayload(params: BodyParams<ImageToImageInput & BaseArgs>): Record<string, unknown>;
145
+ }
146
+
147
+ export interface ImageSegmentationTaskHelper {
148
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ImageSegmentationOutput>;
149
+ preparePayload(params: BodyParams<ImageSegmentationInput & BaseArgs>): Record<string, unknown> | BodyInit;
150
+ }
151
+
152
+ export interface ImageClassificationTaskHelper {
153
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ImageClassificationOutput>;
154
+ preparePayload(params: BodyParams<ImageClassificationInput & BaseArgs>): Record<string, unknown> | BodyInit;
155
+ }
156
+
157
+ export interface ObjectDetectionTaskHelper {
158
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ObjectDetectionOutput>;
159
+ preparePayload(params: BodyParams<ObjectDetectionInput & BaseArgs>): Record<string, unknown> | BodyInit;
160
+ }
161
+
162
+ export interface ImageToTextTaskHelper {
163
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ImageToTextOutput>;
164
+ preparePayload(params: BodyParams<ImageToTextInput & BaseArgs>): Record<string, unknown> | BodyInit;
165
+ }
166
+
167
+ export interface ZeroShotImageClassificationTaskHelper {
168
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ZeroShotImageClassificationOutput>;
169
+ preparePayload(params: BodyParams<ZeroShotImageClassificationInput & BaseArgs>): Record<string, unknown> | BodyInit;
170
+ }
171
+
172
+ // NLP Tasks
173
+ export interface TextGenerationTaskHelper {
174
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<TextGenerationOutput>;
175
+ preparePayload(params: BodyParams<TextGenerationInput & BaseArgs>): Record<string, unknown>;
176
+ }
177
+
178
+ export interface ConversationalTaskHelper {
179
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ChatCompletionOutput>;
180
+ preparePayload(params: BodyParams<ChatCompletionInput & BaseArgs>): Record<string, unknown>;
181
+ }
182
+
183
+ export interface TextClassificationTaskHelper {
184
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<TextClassificationOutput>;
185
+ preparePayload(params: BodyParams<ZeroShotClassificationInput & BaseArgs>): Record<string, unknown>;
186
+ }
187
+
188
+ export interface QuestionAnsweringTaskHelper {
189
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<QuestionAnsweringOutput[number]>;
190
+ preparePayload(params: BodyParams<QuestionAnsweringInput & BaseArgs>): Record<string, unknown>;
191
+ }
192
+
193
+ export interface FillMaskTaskHelper {
194
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<FillMaskOutput>;
195
+ preparePayload(params: BodyParams<FillMaskInput & BaseArgs>): Record<string, unknown>;
196
+ }
197
+
198
+ export interface ZeroShotClassificationTaskHelper {
199
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ZeroShotClassificationOutput>;
200
+ preparePayload(params: BodyParams<ZeroShotClassificationInput & BaseArgs>): Record<string, unknown>;
201
+ }
202
+
203
+ export interface SentenceSimilarityTaskHelper {
204
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<SentenceSimilarityOutput>;
205
+ preparePayload(params: BodyParams<SentenceSimilarityInput & BaseArgs>): Record<string, unknown>;
206
+ }
207
+
208
+ export interface TableQuestionAnsweringTaskHelper {
209
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<TableQuestionAnsweringOutput[number]>;
210
+ preparePayload(params: BodyParams<TableQuestionAnsweringInput & BaseArgs>): Record<string, unknown>;
211
+ }
212
+
213
+ export interface TokenClassificationTaskHelper {
214
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<TokenClassificationOutput>;
215
+ preparePayload(params: BodyParams<TokenClassificationInput & BaseArgs>): Record<string, unknown>;
216
+ }
217
+
218
+ export interface TranslationTaskHelper {
219
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<TranslationOutput>;
220
+ preparePayload(params: BodyParams<TranslationInput & BaseArgs>): Record<string, unknown>;
221
+ }
222
+
223
+ export interface SummarizationTaskHelper {
224
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<SummarizationOutput>;
225
+ preparePayload(params: BodyParams<SummarizationInput & BaseArgs>): Record<string, unknown>;
226
+ }
227
+
228
+ // Audio Tasks
229
+ export interface TextToSpeechTaskHelper {
230
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<Blob>;
231
+ preparePayload(params: BodyParams<TextToSpeechInput & BaseArgs>): Record<string, unknown>;
232
+ }
233
+
234
+ export interface TextToAudioTaskHelper {
235
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<Blob>;
236
+ preparePayload(params: BodyParams<Record<string, unknown> & BaseArgs>): Record<string, unknown>;
237
+ }
238
+
239
+ export interface AudioToAudioTaskHelper {
240
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<AudioToAudioOutput[]>;
241
+ preparePayload(
242
+ params: BodyParams<BaseArgs & { inputs: Blob } & Record<string, unknown>>
243
+ ): Record<string, unknown> | BodyInit;
244
+ }
245
+ export interface AutomaticSpeechRecognitionTaskHelper {
246
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<AutomaticSpeechRecognitionOutput>;
247
+ preparePayload(params: BodyParams<AutomaticSpeechRecognitionInput & BaseArgs>): Record<string, unknown> | BodyInit;
248
+ }
249
+
250
+ export interface AudioClassificationTaskHelper {
251
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<AudioClassificationOutput>;
252
+ preparePayload(params: BodyParams<AudioClassificationInput & BaseArgs>): Record<string, unknown> | BodyInit;
253
+ }
254
+
255
+ // Multimodal Tasks
256
+ export interface DocumentQuestionAnsweringTaskHelper {
257
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<DocumentQuestionAnsweringOutput[number]>;
258
+ preparePayload(params: BodyParams<DocumentQuestionAnsweringInput & BaseArgs>): Record<string, unknown> | BodyInit;
259
+ }
260
+
261
+ export interface FeatureExtractionTaskHelper {
262
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<FeatureExtractionOutput>;
263
+ preparePayload(params: BodyParams<FeatureExtractionInput & BaseArgs>): Record<string, unknown>;
264
+ }
265
+
266
+ export interface VisualQuestionAnsweringTaskHelper {
267
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<VisualQuestionAnsweringOutput[number]>;
268
+ preparePayload(params: BodyParams<VisualQuestionAnsweringInput & BaseArgs>): Record<string, unknown> | BodyInit;
269
+ }
270
+
271
+ export interface TabularClassificationTaskHelper {
272
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<number[]>;
273
+ preparePayload(
274
+ params: BodyParams<BaseArgs & { inputs: { data: Record<string, string[]> } } & Record<string, unknown>>
275
+ ): Record<string, unknown> | BodyInit;
276
+ }
277
+
278
+ export interface TabularRegressionTaskHelper {
279
+ getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<number[]>;
280
+ preparePayload(
281
+ params: BodyParams<BaseArgs & { inputs: { data: Record<string, string[]> } } & Record<string, unknown>>
282
+ ): Record<string, unknown> | BodyInit;
283
+ }
284
+
285
+ // BASE IMPLEMENTATIONS FOR COMMON PATTERNS
286
+
287
/**
 * Shared implementation for providers whose chat endpoint is served at
 * `v1/chat/completions` and whose payload is the caller's arguments plus the
 * resolved model id.
 */
export class BaseConversationalTask extends TaskProviderHelper implements ConversationalTaskHelper {
	constructor(provider: InferenceProvider, baseUrl: string, clientSideRoutingOnly: boolean = false) {
		super(provider, baseUrl, clientSideRoutingOnly);
	}

	/** Route of the chat-completions endpoint, relative to the base URL. */
	makeRoute(): string {
		return "v1/chat/completions";
	}

	/** Copy the caller's arguments and set `model` to the resolved model id. */
	preparePayload(params: BodyParams): Record<string, unknown> {
		const { args, model } = params;
		return { ...args, model };
	}

	/**
	 * Validate the shape of a chat-completion response and hand it back unchanged.
	 *
	 * @throws InferenceOutputError when a checked field is missing or has the wrong type.
	 */
	async getResponse(response: ChatCompletionOutput): Promise<ChatCompletionOutput> {
		// Evaluated lazily, after the earlier checks have rejected null/undefined.
		const fingerprintIsAcceptable = (): boolean => {
			/// Together.ai and Nebius do not output a system_fingerprint
			const fingerprint = response.system_fingerprint;
			return fingerprint === undefined || fingerprint === null || typeof fingerprint === "string";
		};

		const looksLikeChatCompletion =
			typeof response === "object" &&
			Array.isArray(response?.choices) &&
			typeof response?.created === "number" &&
			typeof response?.id === "string" &&
			typeof response?.model === "string" &&
			fingerprintIsAcceptable() &&
			typeof response?.usage === "object";

		if (!looksLikeChatCompletion) {
			throw new InferenceOutputError("Expected ChatCompletionOutput");
		}
		return response;
	}
}
322
+
323
/**
 * Shared implementation for providers whose text-completion endpoint is served
 * at `v1/completions` and whose payload is the caller's arguments plus the
 * resolved model id.
 */
export class BaseTextGenerationTask extends TaskProviderHelper implements TextGenerationTaskHelper {
	constructor(provider: InferenceProvider, baseUrl: string, clientSideRoutingOnly: boolean = false) {
		super(provider, baseUrl, clientSideRoutingOnly);
	}

	/** Copy the caller's arguments and set `model` to the resolved model id. */
	preparePayload(params: BodyParams): Record<string, unknown> {
		const { args, model } = params;
		return { ...args, model };
	}

	/** Route of the completions endpoint, relative to the base URL. */
	makeRoute(): string {
		return "v1/completions";
	}

	/**
	 * Return the first generated item of the response.
	 *
	 * The response is passed through `toArray`; every element must be an object
	 * carrying a string `generated_text`.
	 *
	 * @throws InferenceOutputError when the list is empty or any element has a different shape.
	 */
	async getResponse(response: unknown): Promise<TextGenerationOutput> {
		const hasGeneratedText = (item: unknown): item is { generated_text: string } =>
			typeof item === "object" && !!item && "generated_text" in item && typeof item.generated_text === "string";

		const candidates = toArray(response);
		if (Array.isArray(candidates) && candidates.length > 0 && candidates.every(hasGeneratedText)) {
			return candidates[0];
		}

		throw new InferenceOutputError("Expected Array<{generated_text: string}>");
	}
}
@@ -14,37 +14,127 @@
14
14
  *
15
15
  * Thanks!
16
16
  */
17
- import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
18
-
19
- export const REPLICATE_API_BASE_URL = "https://api.replicate.com";
20
-
21
- const makeBaseUrl = (): string => {
22
- return REPLICATE_API_BASE_URL;
23
- };
24
-
25
- const makeBody = (params: BodyParams): Record<string, unknown> => {
26
- return {
27
- input: params.args,
28
- version: params.model.includes(":") ? params.model.split(":")[1] : undefined,
29
- };
30
- };
31
-
32
- const makeHeaders = (params: HeaderParams): Record<string, string> => {
33
- return { Authorization: `Bearer ${params.accessToken}`, Prefer: "wait" };
34
- };
35
-
36
- const makeUrl = (params: UrlParams): string => {
37
- if (params.model.includes(":")) {
38
- /// Versioned model
39
- return `${params.baseUrl}/v1/predictions`;
40
- }
41
- /// Evergreen / Canonical model
42
- return `${params.baseUrl}/v1/models/${params.model}/predictions`;
43
- };
44
-
45
- export const REPLICATE_CONFIG: ProviderConfig = {
46
- makeBaseUrl,
47
- makeBody,
48
- makeHeaders,
49
- makeUrl,
50
- };
17
+ import { InferenceOutputError } from "../lib/InferenceOutputError";
18
+ import { isUrl } from "../lib/isUrl";
19
+ import type { BodyParams, HeaderParams, UrlParams } from "../types";
20
+ import { omit } from "../utils/omit";
21
+ import { TaskProviderHelper, type TextToImageTaskHelper, type TextToVideoTaskHelper } from "./providerHelper";
22
+
23
+ export interface ReplicateOutput {
24
+ output?: string | string[];
25
+ }
26
+
27
/**
 * Common request construction for Replicate.
 *
 * A model id containing ":" is treated as versioned (`owner/name:version`) and
 * goes to `v1/predictions`; any other id goes to `v1/models/<id>/predictions`.
 */
abstract class ReplicateTask extends TaskProviderHelper {
	/** @param url - Optional base URL; a falsy value selects `https://api.replicate.com`. */
	constructor(url?: string) {
		super("replicate", url || "https://api.replicate.com");
	}

	/** Route of the predictions endpoint for the given model id. */
	makeRoute(params: UrlParams): string {
		return params.model.includes(":") ? "v1/predictions" : `v1/models/${params.model}/predictions`;
	}

	/**
	 * Build the prediction payload: `parameters` is flattened into `input`,
	 * `inputs` is sent as `prompt`, and `version` is the part of the model id
	 * after the first ":" (or `undefined` when there is none).
	 */
	preparePayload(params: BodyParams): Record<string, unknown> {
		const { args, model } = params;
		const version = model.includes(":") ? model.split(":")[1] : undefined;
		const input = {
			...omit(args, ["inputs", "parameters"]),
			...(args.parameters as Record<string, unknown>),
			prompt: args.inputs,
		};
		return { input, version };
	}

	/** Same headers as the base helper plus `Prefer: wait`. */
	override prepareHeaders(params: HeaderParams, binary: boolean): Record<string, string> {
		const jsonHeader: Record<string, string> = binary ? {} : { "Content-Type": "application/json" };
		return {
			Authorization: `Bearer ${params.accessToken}`,
			Prefer: "wait",
			...jsonHeader,
		};
	}

	/** Join the base URL with the versioned or per-model predictions route. */
	override makeUrl(params: UrlParams): string {
		const baseUrl = this.makeBaseUrl(params);
		const route = params.model.includes(":") ? "v1/predictions" : `v1/models/${params.model}/predictions`;
		return `${baseUrl}/${route}`;
	}
}
64
+
65
/**
 * Text-to-image helper for Replicate.
 */
export class ReplicateTextToImageTask extends ReplicateTask implements TextToImageTaskHelper {
	/**
	 * Extract the generated image from a Replicate prediction response.
	 *
	 * `output` may be a non-empty array whose first element is a string, or a
	 * single string that `isUrl` accepts.
	 *
	 * @param res - Response returned by Replicate.
	 * @param url - Unused.
	 * @param headers - Unused.
	 * @param outputType - `"url"` returns the image URL itself; otherwise the URL is fetched and its body returned as a Blob.
	 * @throws InferenceOutputError when no image URL can be read from `res`.
	 */
	override async getResponse(
		res: ReplicateOutput | Blob,
		url?: string,
		headers?: Record<string, string>,
		outputType?: "url" | "blob"
	): Promise<string | Blob> {
		void url;
		void headers;

		// `!!res` keeps a null response from reaching the `in` operator, which would throw a TypeError.
		if (typeof res === "object" && !!res && "output" in res) {
			let imageUrl: string | undefined;
			if (Array.isArray(res.output) && res.output.length > 0 && typeof res.output[0] === "string") {
				imageUrl = res.output[0];
			} else if (typeof res.output === "string" && isUrl(res.output)) {
				// ReplicateOutput allows a bare string as well as an array.
				imageUrl = res.output;
			}

			if (imageUrl !== undefined) {
				if (outputType === "url") {
					return imageUrl;
				}
				const urlResponse = await fetch(imageUrl);
				return await urlResponse.blob();
			}
		}

		throw new InferenceOutputError("Expected Replicate text-to-image response format");
	}
}
91
+
92
/**
 * Text-to-speech helper for Replicate.
 */
export class ReplicateTextToSpeechTask extends ReplicateTask {
	/**
	 * Build the payload like the parent class, then send the text under
	 * `input.text` instead of `input.prompt`.
	 */
	override preparePayload(params: BodyParams): Record<string, unknown> {
		const payload = super.preparePayload(params);

		const input = payload["input"];
		if (typeof input === "object" && input !== null && "prompt" in input) {
			const inputObj = input as Record<string, unknown>;
			inputObj["text"] = inputObj["prompt"];
			delete inputObj["prompt"];
		}

		return payload;
	}

	/**
	 * Return the generated audio as a Blob.
	 *
	 * A Blob response is returned as-is. Otherwise `output` must be a string, or
	 * an array whose first element is a string; that string is fetched and its
	 * body returned.
	 *
	 * @throws InferenceOutputError when no audio location can be read from the response.
	 */
	override async getResponse(response: ReplicateOutput): Promise<Blob> {
		if (response instanceof Blob) {
			return response;
		}
		if (response && typeof response === "object" && "output" in response) {
			const audioUrl = Array.isArray(response.output) ? response.output[0] : response.output;
			// An empty array (or a non-string entry) must fail validation rather than reach fetch().
			if (typeof audioUrl === "string") {
				const urlResponse = await fetch(audioUrl);
				return await urlResponse.blob();
			}
		}
		throw new InferenceOutputError("Expected Blob or object with output");
	}
}
124
+
125
/**
 * Text-to-video helper for Replicate.
 */
export class ReplicateTextToVideoTask extends ReplicateTask implements TextToVideoTaskHelper {
	/**
	 * Fetch the video referenced by `response.output` and return it as a Blob.
	 *
	 * @throws InferenceOutputError unless `output` is a string accepted by `isUrl`.
	 */
	override async getResponse(response: ReplicateOutput): Promise<Blob> {
		const videoUrl =
			typeof response === "object" && !!response && "output" in response ? response.output : undefined;

		if (typeof videoUrl !== "string" || !isUrl(videoUrl)) {
			throw new InferenceOutputError("Expected { output: string }");
		}

		const urlResponse = await fetch(videoUrl);
		return await urlResponse.blob();
	}
}
@@ -14,35 +14,10 @@
14
14
  *
15
15
  * Thanks!
16
16
  */
17
- import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
17
+ import { BaseConversationalTask } from "./providerHelper";
18
18
 
19
- const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
20
-
21
- const makeBaseUrl = (): string => {
22
- return SAMBANOVA_API_BASE_URL;
23
- };
24
-
25
- const makeBody = (params: BodyParams): Record<string, unknown> => {
26
- return {
27
- ...params.args,
28
- ...(params.chatCompletion ? { model: params.model } : undefined),
29
- };
30
- };
31
-
32
- const makeHeaders = (params: HeaderParams): Record<string, string> => {
33
- return { Authorization: `Bearer ${params.accessToken}` };
34
- };
35
-
36
- const makeUrl = (params: UrlParams): string => {
37
- if (params.chatCompletion) {
38
- return `${params.baseUrl}/v1/chat/completions`;
19
+ export class SambanovaConversationalTask extends BaseConversationalTask {
20
+ constructor() {
21
+ super("sambanova", "https://api.sambanova.ai");
39
22
  }
40
- return params.baseUrl;
41
- };
42
-
43
- export const SAMBANOVA_CONFIG: ProviderConfig = {
44
- makeBaseUrl,
45
- makeBody,
46
- makeHeaders,
47
- makeUrl,
48
- };
23
+ }
@@ -14,41 +14,105 @@
14
14
  *
15
15
  * Thanks!
16
16
  */
17
- import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
17
+ import type { ChatCompletionOutput, TextGenerationOutput, TextGenerationOutputFinishReason } from "@huggingface/tasks";
18
+ import { InferenceOutputError } from "../lib/InferenceOutputError";
19
+ import type { BodyParams } from "../types";
20
+ import { omit } from "../utils/omit";
21
+ import {
22
+ BaseConversationalTask,
23
+ BaseTextGenerationTask,
24
+ TaskProviderHelper,
25
+ type TextToImageTaskHelper,
26
+ } from "./providerHelper";
18
27
 
19
28
  const TOGETHER_API_BASE_URL = "https://api.together.xyz";
20
29
 
21
- const makeBaseUrl = (): string => {
22
- return TOGETHER_API_BASE_URL;
23
- };
30
+ interface TogetherTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
31
+ choices: Array<{
32
+ text: string;
33
+ finish_reason: TextGenerationOutputFinishReason;
34
+ seed: number;
35
+ logprobs: unknown;
36
+ index: number;
37
+ }>;
38
+ }
24
39
 
25
- const makeBody = (params: BodyParams): Record<string, unknown> => {
26
- return {
27
- ...params.args,
28
- model: params.model,
29
- };
30
- };
40
+ interface TogetherBase64ImageGeneration {
41
+ data: Array<{
42
+ b64_json: string;
43
+ }>;
44
+ }
31
45
 
32
- const makeHeaders = (params: HeaderParams): Record<string, string> => {
33
- return { Authorization: `Bearer ${params.accessToken}` };
34
- };
46
+ export class TogetherConversationalTask extends BaseConversationalTask {
47
+ constructor() {
48
+ super("together", TOGETHER_API_BASE_URL);
49
+ }
50
+ }
51
+
52
+ export class TogetherTextGenerationTask extends BaseTextGenerationTask {
53
+ constructor() {
54
+ super("together", TOGETHER_API_BASE_URL);
55
+ }
35
56
 
36
- const makeUrl = (params: UrlParams): string => {
37
- if (params.task === "text-to-image") {
38
- return `${params.baseUrl}/v1/images/generations`;
57
+ override preparePayload(params: BodyParams): Record<string, unknown> {
58
+ return {
59
+ model: params.model,
60
+ ...params.args,
61
+ prompt: params.args.inputs,
62
+ };
39
63
  }
40
- if (params.chatCompletion) {
41
- return `${params.baseUrl}/v1/chat/completions`;
64
+
65
+ override async getResponse(response: TogetherTextCompletionOutput): Promise<TextGenerationOutput> {
66
+ if (
67
+ typeof response === "object" &&
68
+ "choices" in response &&
69
+ Array.isArray(response?.choices) &&
70
+ typeof response?.model === "string"
71
+ ) {
72
+ const completion = response.choices[0];
73
+ return {
74
+ generated_text: completion.text,
75
+ };
76
+ }
77
+ throw new InferenceOutputError("Expected Together text generation response format");
42
78
  }
43
- if (params.task === "text-generation") {
44
- return `${params.baseUrl}/v1/completions`;
79
+ }
80
+
81
+ export class TogetherTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
82
+ constructor() {
83
+ super("together", TOGETHER_API_BASE_URL);
45
84
  }
46
- return params.baseUrl;
47
- };
48
85
 
49
- export const TOGETHER_CONFIG: ProviderConfig = {
50
- makeBaseUrl,
51
- makeBody,
52
- makeHeaders,
53
- makeUrl,
54
- };
86
+ makeRoute(): string {
87
+ return "v1/images/generations";
88
+ }
89
+
90
+ preparePayload(params: BodyParams): Record<string, unknown> {
91
+ return {
92
+ ...omit(params.args, ["inputs", "parameters"]),
93
+ ...(params.args.parameters as Record<string, unknown>),
94
+ prompt: params.args.inputs,
95
+ response_format: "base64",
96
+ model: params.model,
97
+ };
98
+ }
99
+
100
/**
 * Convert Together's base64 image response into a data URL or a Blob.
 *
 * The parameter order follows the `getResponse(response, url?, headers?, outputType?)`
 * signature declared by `TaskProviderHelper` and `TextToImageTaskHelper`, so that
 * `outputType` is read from the fourth argument rather than from the request URL.
 *
 * @param response - Response body; `data[0].b64_json` must be a string.
 * @param url - Unused.
 * @param headers - Unused.
 * @param outputType - `"url"` returns a `data:image/jpeg;base64,...` URL; otherwise a Blob decoded from that URL is returned.
 * @throws InferenceOutputError when the response does not carry a base64 image.
 */
async getResponse(
	response: TogetherBase64ImageGeneration,
	url?: string,
	headers?: HeadersInit,
	outputType?: "url" | "blob"
): Promise<string | Blob> {
	void url;
	void headers;

	if (
		typeof response === "object" &&
		!!response &&
		"data" in response &&
		Array.isArray(response.data) &&
		response.data.length > 0 &&
		typeof response.data[0] === "object" &&
		!!response.data[0] &&
		"b64_json" in response.data[0] &&
		typeof response.data[0].b64_json === "string"
	) {
		const dataUrl = `data:image/jpeg;base64,${response.data[0].b64_json}`;
		if (outputType === "url") {
			return dataUrl;
		}
		const decoded = await fetch(dataUrl);
		return await decoded.blob();
	}

	throw new InferenceOutputError("Expected Together text-to-image response format");
}
118
+ }