@huggingface/inference 3.7.0 → 3.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. package/dist/index.cjs +1369 -941
  2. package/dist/index.js +1371 -943
  3. package/dist/src/lib/getInferenceProviderMapping.d.ts +21 -0
  4. package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -0
  5. package/dist/src/lib/getProviderHelper.d.ts +37 -0
  6. package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
  7. package/dist/src/lib/makeRequestOptions.d.ts +5 -5
  8. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  9. package/dist/src/providers/black-forest-labs.d.ts +14 -18
  10. package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
  11. package/dist/src/providers/cerebras.d.ts +4 -2
  12. package/dist/src/providers/cerebras.d.ts.map +1 -1
  13. package/dist/src/providers/cohere.d.ts +5 -2
  14. package/dist/src/providers/cohere.d.ts.map +1 -1
  15. package/dist/src/providers/consts.d.ts +2 -3
  16. package/dist/src/providers/consts.d.ts.map +1 -1
  17. package/dist/src/providers/fal-ai.d.ts +50 -3
  18. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  19. package/dist/src/providers/fireworks-ai.d.ts +5 -2
  20. package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
  21. package/dist/src/providers/hf-inference.d.ts +126 -2
  22. package/dist/src/providers/hf-inference.d.ts.map +1 -1
  23. package/dist/src/providers/hyperbolic.d.ts +31 -2
  24. package/dist/src/providers/hyperbolic.d.ts.map +1 -1
  25. package/dist/src/providers/nebius.d.ts +20 -18
  26. package/dist/src/providers/nebius.d.ts.map +1 -1
  27. package/dist/src/providers/novita.d.ts +21 -18
  28. package/dist/src/providers/novita.d.ts.map +1 -1
  29. package/dist/src/providers/openai.d.ts +4 -2
  30. package/dist/src/providers/openai.d.ts.map +1 -1
  31. package/dist/src/providers/providerHelper.d.ts +182 -0
  32. package/dist/src/providers/providerHelper.d.ts.map +1 -0
  33. package/dist/src/providers/replicate.d.ts +23 -19
  34. package/dist/src/providers/replicate.d.ts.map +1 -1
  35. package/dist/src/providers/sambanova.d.ts +4 -2
  36. package/dist/src/providers/sambanova.d.ts.map +1 -1
  37. package/dist/src/providers/together.d.ts +32 -2
  38. package/dist/src/providers/together.d.ts.map +1 -1
  39. package/dist/src/snippets/getInferenceSnippets.d.ts +2 -1
  40. package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
  41. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
  42. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  43. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  44. package/dist/src/tasks/audio/utils.d.ts +2 -1
  45. package/dist/src/tasks/audio/utils.d.ts.map +1 -1
  46. package/dist/src/tasks/custom/request.d.ts +0 -2
  47. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  48. package/dist/src/tasks/custom/streamingRequest.d.ts +0 -2
  49. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  50. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
  51. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
  52. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
  53. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
  54. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
  55. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  56. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
  57. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
  58. package/dist/src/tasks/index.d.ts +6 -6
  59. package/dist/src/tasks/index.d.ts.map +1 -1
  60. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  61. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
  62. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  63. package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
  64. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  65. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
  66. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  67. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  68. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
  69. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  70. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
  71. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  72. package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
  73. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  74. package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
  75. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  76. package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
  77. package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
  78. package/dist/src/types.d.ts +5 -13
  79. package/dist/src/types.d.ts.map +1 -1
  80. package/dist/src/utils/request.d.ts +3 -2
  81. package/dist/src/utils/request.d.ts.map +1 -1
  82. package/package.json +3 -3
  83. package/src/lib/getInferenceProviderMapping.ts +96 -0
  84. package/src/lib/getProviderHelper.ts +270 -0
  85. package/src/lib/makeRequestOptions.ts +78 -97
  86. package/src/providers/black-forest-labs.ts +73 -22
  87. package/src/providers/cerebras.ts +6 -27
  88. package/src/providers/cohere.ts +9 -28
  89. package/src/providers/consts.ts +5 -2
  90. package/src/providers/fal-ai.ts +224 -77
  91. package/src/providers/fireworks-ai.ts +8 -29
  92. package/src/providers/hf-inference.ts +557 -34
  93. package/src/providers/hyperbolic.ts +107 -29
  94. package/src/providers/nebius.ts +65 -29
  95. package/src/providers/novita.ts +68 -32
  96. package/src/providers/openai.ts +6 -32
  97. package/src/providers/providerHelper.ts +354 -0
  98. package/src/providers/replicate.ts +124 -34
  99. package/src/providers/sambanova.ts +5 -30
  100. package/src/providers/together.ts +92 -28
  101. package/src/snippets/getInferenceSnippets.ts +39 -14
  102. package/src/snippets/templates.exported.ts +25 -25
  103. package/src/tasks/audio/audioClassification.ts +5 -8
  104. package/src/tasks/audio/audioToAudio.ts +4 -27
  105. package/src/tasks/audio/automaticSpeechRecognition.ts +5 -4
  106. package/src/tasks/audio/textToSpeech.ts +5 -29
  107. package/src/tasks/audio/utils.ts +2 -1
  108. package/src/tasks/custom/request.ts +3 -3
  109. package/src/tasks/custom/streamingRequest.ts +4 -3
  110. package/src/tasks/cv/imageClassification.ts +4 -8
  111. package/src/tasks/cv/imageSegmentation.ts +4 -9
  112. package/src/tasks/cv/imageToImage.ts +4 -7
  113. package/src/tasks/cv/imageToText.ts +4 -7
  114. package/src/tasks/cv/objectDetection.ts +4 -19
  115. package/src/tasks/cv/textToImage.ts +9 -137
  116. package/src/tasks/cv/textToVideo.ts +17 -64
  117. package/src/tasks/cv/zeroShotImageClassification.ts +4 -8
  118. package/src/tasks/index.ts +6 -6
  119. package/src/tasks/multimodal/documentQuestionAnswering.ts +4 -19
  120. package/src/tasks/multimodal/visualQuestionAnswering.ts +4 -12
  121. package/src/tasks/nlp/chatCompletion.ts +5 -20
  122. package/src/tasks/nlp/chatCompletionStream.ts +4 -3
  123. package/src/tasks/nlp/featureExtraction.ts +4 -19
  124. package/src/tasks/nlp/fillMask.ts +4 -17
  125. package/src/tasks/nlp/questionAnswering.ts +11 -26
  126. package/src/tasks/nlp/sentenceSimilarity.ts +4 -8
  127. package/src/tasks/nlp/summarization.ts +4 -7
  128. package/src/tasks/nlp/tableQuestionAnswering.ts +10 -30
  129. package/src/tasks/nlp/textClassification.ts +4 -9
  130. package/src/tasks/nlp/textGeneration.ts +11 -79
  131. package/src/tasks/nlp/textGenerationStream.ts +3 -1
  132. package/src/tasks/nlp/tokenClassification.ts +11 -23
  133. package/src/tasks/nlp/translation.ts +4 -7
  134. package/src/tasks/nlp/zeroShotClassification.ts +11 -21
  135. package/src/tasks/tabular/tabularClassification.ts +4 -7
  136. package/src/tasks/tabular/tabularRegression.ts +4 -7
  137. package/src/types.ts +5 -14
  138. package/src/utils/request.ts +7 -4
  139. package/dist/src/lib/getProviderModelId.d.ts +0 -10
  140. package/dist/src/lib/getProviderModelId.d.ts.map +0 -1
  141. package/src/lib/getProviderModelId.ts +0 -74
@@ -14,109 +14,256 @@
14
14
  *
15
15
  * Thanks!
16
16
  */
17
+ import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
17
18
  import { InferenceOutputError } from "../lib/InferenceOutputError";
18
19
  import { isUrl } from "../lib/isUrl";
19
- import type { BodyParams, HeaderParams, InferenceTask, ProviderConfig, UrlParams } from "../types";
20
+ import type { BodyParams, HeaderParams, ModelId, UrlParams } from "../types";
20
21
  import { delay } from "../utils/delay";
22
+ import { omit } from "../utils/omit";
23
+ import {
24
+ type AutomaticSpeechRecognitionTaskHelper,
25
+ TaskProviderHelper,
26
+ type TextToImageTaskHelper,
27
+ type TextToVideoTaskHelper,
28
+ } from "./providerHelper";
29
+ import { HF_HUB_URL } from "../config";
21
30
 
22
- const FAL_AI_API_BASE_URL = "https://fal.run";
23
- const FAL_AI_API_BASE_URL_QUEUE = "https://queue.fal.run";
31
+ export interface FalAiQueueOutput {
32
+ request_id: string;
33
+ status: string;
34
+ response_url: string;
35
+ }
24
36
 
25
- const makeBaseUrl = (task?: InferenceTask): string => {
26
- return task === "text-to-video" ? FAL_AI_API_BASE_URL_QUEUE : FAL_AI_API_BASE_URL;
27
- };
37
+ interface FalAITextToImageOutput {
38
+ images: Array<{
39
+ url: string;
40
+ }>;
41
+ }
28
42
 
29
- const makeBody = (params: BodyParams): Record<string, unknown> => {
30
- return params.args;
31
- };
43
+ interface FalAIAutomaticSpeechRecognitionOutput {
44
+ text: string;
45
+ }
32
46
 
33
- const makeHeaders = (params: HeaderParams): Record<string, string> => {
34
- return {
35
- Authorization: params.authMethod === "provider-key" ? `Key ${params.accessToken}` : `Bearer ${params.accessToken}`,
47
+ interface FalAITextToSpeechOutput {
48
+ audio: {
49
+ url: string;
50
+ content_type: string;
36
51
  };
37
- };
52
+ }
53
+ export const FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
38
54
 
39
- const makeUrl = (params: UrlParams): string => {
40
- const baseUrl = `${params.baseUrl}/${params.model}`;
41
- if (params.authMethod !== "provider-key" && params.task === "text-to-video") {
42
- return `${baseUrl}?_subdomain=queue`;
55
+ abstract class FalAITask extends TaskProviderHelper {
56
+ constructor(url?: string) {
57
+ super("fal-ai", url || "https://fal.run");
43
58
  }
44
- return baseUrl;
45
- };
46
59
 
47
- export const FAL_AI_CONFIG: ProviderConfig = {
48
- makeBaseUrl,
49
- makeBody,
50
- makeHeaders,
51
- makeUrl,
52
- };
60
+ preparePayload(params: BodyParams): Record<string, unknown> {
61
+ return params.args;
62
+ }
63
+ makeRoute(params: UrlParams): string {
64
+ return `/${params.model}`;
65
+ }
66
+ override prepareHeaders(params: HeaderParams, binary: boolean): Record<string, string> {
67
+ const headers: Record<string, string> = {
68
+ Authorization:
69
+ params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}`,
70
+ };
71
+ if (!binary) {
72
+ headers["Content-Type"] = "application/json";
73
+ }
74
+ return headers;
75
+ }
76
+ }
53
77
 
54
- export interface FalAiQueueOutput {
55
- request_id: string;
56
- status: string;
57
- response_url: string;
78
+ function buildLoraPath(modelId: ModelId, adapterWeightsPath: string): string {
79
+ return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
80
+ }
81
+
82
+ export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHelper {
83
+ override preparePayload(params: BodyParams): Record<string, unknown> {
84
+ const payload: Record<string, unknown> = {
85
+ ...omit(params.args, ["inputs", "parameters"]),
86
+ ...(params.args.parameters as Record<string, unknown>),
87
+ sync_mode: true,
88
+ prompt: params.args.inputs,
89
+ ...(params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath
90
+ ? {
91
+ loras: [
92
+ {
93
+ path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
94
+ scale: 1,
95
+ },
96
+ ],
97
+ }
98
+ : undefined),
99
+ };
100
+
101
+ if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
102
+ payload.loras = [
103
+ {
104
+ path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
105
+ scale: 1,
106
+ },
107
+ ];
108
+ if (params.mapping.providerId === "fal-ai/lora") {
109
+ payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0";
110
+ }
111
+ }
112
+
113
+ return payload;
114
+ }
115
+
116
+ override async getResponse(response: FalAITextToImageOutput, outputType?: "url" | "blob"): Promise<string | Blob> {
117
+ if (
118
+ typeof response === "object" &&
119
+ "images" in response &&
120
+ Array.isArray(response.images) &&
121
+ response.images.length > 0 &&
122
+ "url" in response.images[0] &&
123
+ typeof response.images[0].url === "string"
124
+ ) {
125
+ if (outputType === "url") {
126
+ return response.images[0].url;
127
+ }
128
+ const urlResponse = await fetch(response.images[0].url);
129
+ return await urlResponse.blob();
130
+ }
131
+
132
+ throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
133
+ }
58
134
  }
59
135
 
60
- export async function pollFalResponse(
61
- res: FalAiQueueOutput,
62
- url: string,
63
- headers: Record<string, string>
64
- ): Promise<Blob> {
65
- const requestId = res.request_id;
66
- if (!requestId) {
67
- throw new InferenceOutputError("No request ID found in the response");
136
+ export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHelper {
137
+ constructor() {
138
+ super("https://queue.fal.run");
139
+ }
140
+ override makeRoute(params: UrlParams): string {
141
+ if (params.authMethod !== "provider-key") {
142
+ return `/${params.model}?_subdomain=queue`;
143
+ }
144
+ return `/${params.model}`;
68
145
  }
69
- let status = res.status;
146
+ override preparePayload(params: BodyParams): Record<string, unknown> {
147
+ return {
148
+ ...omit(params.args, ["inputs", "parameters"]),
149
+ ...(params.args.parameters as Record<string, unknown>),
150
+ prompt: params.args.inputs,
151
+ };
152
+ }
153
+
154
+ override async getResponse(
155
+ response: FalAiQueueOutput,
156
+ url?: string,
157
+ headers?: Record<string, string>
158
+ ): Promise<Blob> {
159
+ if (!url || !headers) {
160
+ throw new InferenceOutputError("URL and headers are required for text-to-video task");
161
+ }
162
+ const requestId = response.request_id;
163
+ if (!requestId) {
164
+ throw new InferenceOutputError("No request ID found in the response");
165
+ }
166
+ let status = response.status;
70
167
 
71
- const parsedUrl = new URL(url);
72
- const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
73
- parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
74
- }`;
168
+ const parsedUrl = new URL(url);
169
+ const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
170
+ parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
171
+ }`;
75
172
 
76
- // extracting the provider model id for status and result urls
77
- // from the response as it might be different from the mapped model in `url`
78
- const modelId = new URL(res.response_url).pathname;
79
- const queryParams = parsedUrl.search;
173
+ // extracting the provider model id for status and result urls
174
+ // from the response as it might be different from the mapped model in `url`
175
+ const modelId = new URL(response.response_url).pathname;
176
+ const queryParams = parsedUrl.search;
80
177
 
81
- const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
82
- const resultUrl = `${baseUrl}${modelId}${queryParams}`;
178
+ const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
179
+ const resultUrl = `${baseUrl}${modelId}${queryParams}`;
83
180
 
84
- while (status !== "COMPLETED") {
85
- await delay(500);
86
- const statusResponse = await fetch(statusUrl, { headers });
181
+ while (status !== "COMPLETED") {
182
+ await delay(500);
183
+ const statusResponse = await fetch(statusUrl, { headers });
87
184
 
88
- if (!statusResponse.ok) {
89
- throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
185
+ if (!statusResponse.ok) {
186
+ throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
187
+ }
188
+ try {
189
+ status = (await statusResponse.json()).status;
190
+ } catch (error) {
191
+ throw new InferenceOutputError("Failed to parse status response from fal-ai API");
192
+ }
90
193
  }
194
+
195
+ const resultResponse = await fetch(resultUrl, { headers });
196
+ let result: unknown;
91
197
  try {
92
- status = (await statusResponse.json()).status;
198
+ result = await resultResponse.json();
93
199
  } catch (error) {
94
- throw new InferenceOutputError("Failed to parse status response from fal-ai API");
200
+ throw new InferenceOutputError("Failed to parse result response from fal-ai API");
201
+ }
202
+ if (
203
+ typeof result === "object" &&
204
+ !!result &&
205
+ "video" in result &&
206
+ typeof result.video === "object" &&
207
+ !!result.video &&
208
+ "url" in result.video &&
209
+ typeof result.video.url === "string" &&
210
+ isUrl(result.video.url)
211
+ ) {
212
+ const urlResponse = await fetch(result.video.url);
213
+ return await urlResponse.blob();
214
+ } else {
215
+ throw new InferenceOutputError(
216
+ "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
217
+ );
95
218
  }
96
219
  }
220
+ }
221
+
222
+ export class FalAIAutomaticSpeechRecognitionTask extends FalAITask implements AutomaticSpeechRecognitionTaskHelper {
223
+ override prepareHeaders(params: HeaderParams, binary: boolean): Record<string, string> {
224
+ const headers = super.prepareHeaders(params, binary);
225
+ headers["Content-Type"] = "application/json";
226
+ return headers;
227
+ }
228
+ override async getResponse(response: unknown): Promise<AutomaticSpeechRecognitionOutput> {
229
+ const res = response as FalAIAutomaticSpeechRecognitionOutput;
230
+ if (typeof res?.text !== "string") {
231
+ throw new InferenceOutputError(
232
+ `Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
233
+ );
234
+ }
235
+ return { text: res.text };
236
+ }
237
+ }
97
238
 
98
- const resultResponse = await fetch(resultUrl, { headers });
99
- let result: unknown;
100
- try {
101
- result = await resultResponse.json();
102
- } catch (error) {
103
- throw new InferenceOutputError("Failed to parse result response from fal-ai API");
239
+ export class FalAITextToSpeechTask extends FalAITask {
240
+ override preparePayload(params: BodyParams): Record<string, unknown> {
241
+ return {
242
+ ...omit(params.args, ["inputs", "parameters"]),
243
+ ...(params.args.parameters as Record<string, unknown>),
244
+ lyrics: params.args.inputs,
245
+ };
104
246
  }
105
- if (
106
- typeof result === "object" &&
107
- !!result &&
108
- "video" in result &&
109
- typeof result.video === "object" &&
110
- !!result.video &&
111
- "url" in result.video &&
112
- typeof result.video.url === "string" &&
113
- isUrl(result.video.url)
114
- ) {
115
- const urlResponse = await fetch(result.video.url);
116
- return await urlResponse.blob();
117
- } else {
118
- throw new InferenceOutputError(
119
- "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
120
- );
247
+
248
+ override async getResponse(response: unknown): Promise<Blob> {
249
+ const res = response as FalAITextToSpeechOutput;
250
+ if (typeof res?.audio?.url !== "string") {
251
+ throw new InferenceOutputError(
252
+ `Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
253
+ );
254
+ }
255
+ try {
256
+ const urlResponse = await fetch(res.audio.url);
257
+ if (!urlResponse.ok) {
258
+ throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`);
259
+ }
260
+ return await urlResponse.blob();
261
+ } catch (error) {
262
+ throw new InferenceOutputError(
263
+ `Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${
264
+ error instanceof Error ? error.message : String(error)
265
+ }`
266
+ );
267
+ }
121
268
  }
122
269
  }
@@ -14,35 +14,14 @@
14
14
  *
15
15
  * Thanks!
16
16
  */
17
- import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
17
+ import { BaseConversationalTask } from "./providerHelper";
18
18
 
19
- const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai";
20
-
21
- const makeBaseUrl = (): string => {
22
- return FIREWORKS_AI_API_BASE_URL;
23
- };
24
-
25
- const makeBody = (params: BodyParams): Record<string, unknown> => {
26
- return {
27
- ...params.args,
28
- ...(params.chatCompletion ? { model: params.model } : undefined),
29
- };
30
- };
31
-
32
- const makeHeaders = (params: HeaderParams): Record<string, string> => {
33
- return { Authorization: `Bearer ${params.accessToken}` };
34
- };
35
-
36
- const makeUrl = (params: UrlParams): string => {
37
- if (params.chatCompletion) {
38
- return `${params.baseUrl}/inference/v1/chat/completions`;
19
+ export class FireworksConversationalTask extends BaseConversationalTask {
20
+ constructor() {
21
+ super("fireworks-ai", "https://api.fireworks.ai");
39
22
  }
40
- return `${params.baseUrl}/inference`;
41
- };
42
23
 
43
- export const FIREWORKS_AI_CONFIG: ProviderConfig = {
44
- makeBaseUrl,
45
- makeBody,
46
- makeHeaders,
47
- makeUrl,
48
- };
24
+ override makeRoute(): string {
25
+ return "/inference/v1/chat/completions";
26
+ }
27
+ }