@huggingface/inference 3.6.2 → 3.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140) hide show
  1. package/README.md +0 -25
  2. package/dist/index.cjs +1232 -898
  3. package/dist/index.js +1234 -900
  4. package/dist/src/config.d.ts +1 -0
  5. package/dist/src/config.d.ts.map +1 -1
  6. package/dist/src/lib/getProviderHelper.d.ts +37 -0
  7. package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
  8. package/dist/src/lib/makeRequestOptions.d.ts +0 -2
  9. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  10. package/dist/src/providers/black-forest-labs.d.ts +14 -18
  11. package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
  12. package/dist/src/providers/cerebras.d.ts +4 -2
  13. package/dist/src/providers/cerebras.d.ts.map +1 -1
  14. package/dist/src/providers/cohere.d.ts +5 -2
  15. package/dist/src/providers/cohere.d.ts.map +1 -1
  16. package/dist/src/providers/fal-ai.d.ts +50 -3
  17. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  18. package/dist/src/providers/fireworks-ai.d.ts +5 -2
  19. package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
  20. package/dist/src/providers/hf-inference.d.ts +125 -2
  21. package/dist/src/providers/hf-inference.d.ts.map +1 -1
  22. package/dist/src/providers/hyperbolic.d.ts +31 -2
  23. package/dist/src/providers/hyperbolic.d.ts.map +1 -1
  24. package/dist/src/providers/nebius.d.ts +20 -18
  25. package/dist/src/providers/nebius.d.ts.map +1 -1
  26. package/dist/src/providers/novita.d.ts +21 -18
  27. package/dist/src/providers/novita.d.ts.map +1 -1
  28. package/dist/src/providers/openai.d.ts +4 -2
  29. package/dist/src/providers/openai.d.ts.map +1 -1
  30. package/dist/src/providers/providerHelper.d.ts +182 -0
  31. package/dist/src/providers/providerHelper.d.ts.map +1 -0
  32. package/dist/src/providers/replicate.d.ts +23 -19
  33. package/dist/src/providers/replicate.d.ts.map +1 -1
  34. package/dist/src/providers/sambanova.d.ts +4 -2
  35. package/dist/src/providers/sambanova.d.ts.map +1 -1
  36. package/dist/src/providers/together.d.ts +32 -2
  37. package/dist/src/providers/together.d.ts.map +1 -1
  38. package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
  39. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
  40. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  41. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  42. package/dist/src/tasks/audio/utils.d.ts +2 -1
  43. package/dist/src/tasks/audio/utils.d.ts.map +1 -1
  44. package/dist/src/tasks/custom/request.d.ts +1 -2
  45. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  46. package/dist/src/tasks/custom/streamingRequest.d.ts +1 -2
  47. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  48. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
  49. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
  50. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
  51. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
  52. package/dist/src/tasks/cv/objectDetection.d.ts +1 -1
  53. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
  54. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  55. package/dist/src/tasks/cv/textToVideo.d.ts +1 -1
  56. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
  57. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts +1 -1
  58. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
  59. package/dist/src/tasks/index.d.ts +6 -6
  60. package/dist/src/tasks/index.d.ts.map +1 -1
  61. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts +1 -1
  62. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  63. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
  64. package/dist/src/tasks/nlp/chatCompletion.d.ts +1 -1
  65. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  66. package/dist/src/tasks/nlp/chatCompletionStream.d.ts +1 -1
  67. package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
  68. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  69. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
  70. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  71. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  72. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
  73. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  74. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
  75. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  76. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  77. package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
  78. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  79. package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
  80. package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
  81. package/dist/src/types.d.ts +10 -13
  82. package/dist/src/types.d.ts.map +1 -1
  83. package/dist/src/utils/request.d.ts +27 -0
  84. package/dist/src/utils/request.d.ts.map +1 -0
  85. package/package.json +3 -3
  86. package/src/config.ts +1 -0
  87. package/src/lib/getProviderHelper.ts +270 -0
  88. package/src/lib/makeRequestOptions.ts +36 -90
  89. package/src/providers/black-forest-labs.ts +73 -22
  90. package/src/providers/cerebras.ts +6 -27
  91. package/src/providers/cohere.ts +9 -28
  92. package/src/providers/fal-ai.ts +195 -77
  93. package/src/providers/fireworks-ai.ts +8 -29
  94. package/src/providers/hf-inference.ts +555 -34
  95. package/src/providers/hyperbolic.ts +107 -29
  96. package/src/providers/nebius.ts +65 -29
  97. package/src/providers/novita.ts +68 -32
  98. package/src/providers/openai.ts +6 -32
  99. package/src/providers/providerHelper.ts +354 -0
  100. package/src/providers/replicate.ts +124 -34
  101. package/src/providers/sambanova.ts +5 -30
  102. package/src/providers/together.ts +92 -28
  103. package/src/snippets/getInferenceSnippets.ts +16 -9
  104. package/src/snippets/templates.exported.ts +2 -2
  105. package/src/tasks/audio/audioClassification.ts +6 -9
  106. package/src/tasks/audio/audioToAudio.ts +5 -28
  107. package/src/tasks/audio/automaticSpeechRecognition.ts +7 -6
  108. package/src/tasks/audio/textToSpeech.ts +6 -30
  109. package/src/tasks/audio/utils.ts +2 -1
  110. package/src/tasks/custom/request.ts +7 -34
  111. package/src/tasks/custom/streamingRequest.ts +5 -87
  112. package/src/tasks/cv/imageClassification.ts +5 -9
  113. package/src/tasks/cv/imageSegmentation.ts +5 -10
  114. package/src/tasks/cv/imageToImage.ts +5 -8
  115. package/src/tasks/cv/imageToText.ts +8 -13
  116. package/src/tasks/cv/objectDetection.ts +6 -21
  117. package/src/tasks/cv/textToImage.ts +10 -138
  118. package/src/tasks/cv/textToVideo.ts +11 -59
  119. package/src/tasks/cv/zeroShotImageClassification.ts +7 -12
  120. package/src/tasks/index.ts +6 -6
  121. package/src/tasks/multimodal/documentQuestionAnswering.ts +10 -26
  122. package/src/tasks/multimodal/visualQuestionAnswering.ts +6 -12
  123. package/src/tasks/nlp/chatCompletion.ts +7 -23
  124. package/src/tasks/nlp/chatCompletionStream.ts +4 -5
  125. package/src/tasks/nlp/featureExtraction.ts +5 -20
  126. package/src/tasks/nlp/fillMask.ts +5 -18
  127. package/src/tasks/nlp/questionAnswering.ts +5 -23
  128. package/src/tasks/nlp/sentenceSimilarity.ts +5 -18
  129. package/src/tasks/nlp/summarization.ts +5 -8
  130. package/src/tasks/nlp/tableQuestionAnswering.ts +5 -29
  131. package/src/tasks/nlp/textClassification.ts +8 -14
  132. package/src/tasks/nlp/textGeneration.ts +13 -80
  133. package/src/tasks/nlp/textGenerationStream.ts +2 -2
  134. package/src/tasks/nlp/tokenClassification.ts +8 -24
  135. package/src/tasks/nlp/translation.ts +5 -8
  136. package/src/tasks/nlp/zeroShotClassification.ts +8 -22
  137. package/src/tasks/tabular/tabularClassification.ts +5 -8
  138. package/src/tasks/tabular/tabularRegression.ts +5 -8
  139. package/src/types.ts +11 -14
  140. package/src/utils/request.ts +161 -0
@@ -14,33 +14,84 @@
14
14
  *
15
15
  * Thanks!
16
16
  */
17
- import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
17
+ import { InferenceOutputError } from "../lib/InferenceOutputError";
18
+ import type { BodyParams, HeaderParams, UrlParams } from "../types";
19
+ import { delay } from "../utils/delay";
20
+ import { omit } from "../utils/omit";
21
+ import { TaskProviderHelper, type TextToImageTaskHelper } from "./providerHelper";
18
22
 
19
23
  const BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
24
+ interface BlackForestLabsResponse {
25
+ id: string;
26
+ polling_url: string;
27
+ }
20
28
 
21
- const makeBaseUrl = (): string => {
22
- return BLACK_FOREST_LABS_AI_API_BASE_URL;
23
- };
29
+ export class BlackForestLabsTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
30
+ constructor() {
31
+ super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
32
+ }
24
33
 
25
- const makeBody = (params: BodyParams): Record<string, unknown> => {
26
- return params.args;
27
- };
34
+ preparePayload(params: BodyParams): Record<string, unknown> {
35
+ return {
36
+ ...omit(params.args, ["inputs", "parameters"]),
37
+ ...(params.args.parameters as Record<string, unknown>),
38
+ prompt: params.args.inputs,
39
+ };
40
+ }
28
41
 
29
- const makeHeaders = (params: HeaderParams): Record<string, string> => {
30
- if (params.authMethod === "provider-key") {
31
- return { "X-Key": `${params.accessToken}` };
32
- } else {
33
- return { Authorization: `Bearer ${params.accessToken}` };
42
+ override prepareHeaders(params: HeaderParams, binary: boolean): Record<string, string> {
43
+ const headers: Record<string, string> = {
44
+ Authorization:
45
+ params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`,
46
+ };
47
+ if (!binary) {
48
+ headers["Content-Type"] = "application/json";
49
+ }
50
+ return headers;
34
51
  }
35
- };
36
52
 
37
- const makeUrl = (params: UrlParams): string => {
38
- return `${params.baseUrl}/v1/${params.model}`;
39
- };
53
+ makeRoute(params: UrlParams): string {
54
+ if (!params) {
55
+ throw new Error("Params are required");
56
+ }
57
+ return `/v1/${params.model}`;
58
+ }
40
59
 
41
- export const BLACK_FOREST_LABS_CONFIG: ProviderConfig = {
42
- makeBaseUrl,
43
- makeBody,
44
- makeHeaders,
45
- makeUrl,
46
- };
60
+ async getResponse(
61
+ response: BlackForestLabsResponse,
62
+ url?: string,
63
+ headers?: HeadersInit,
64
+ outputType?: "url" | "blob"
65
+ ): Promise<string | Blob> {
66
+ const urlObj = new URL(response.polling_url);
67
+ for (let step = 0; step < 5; step++) {
68
+ await delay(1000);
69
+ console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
70
+ urlObj.searchParams.set("attempt", step.toString(10));
71
+ const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
72
+ if (!resp.ok) {
73
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
74
+ }
75
+ const payload = await resp.json();
76
+ if (
77
+ typeof payload === "object" &&
78
+ payload &&
79
+ "status" in payload &&
80
+ typeof payload.status === "string" &&
81
+ payload.status === "Ready" &&
82
+ "result" in payload &&
83
+ typeof payload.result === "object" &&
84
+ payload.result &&
85
+ "sample" in payload.result &&
86
+ typeof payload.result.sample === "string"
87
+ ) {
88
+ if (outputType === "url") {
89
+ return payload.result.sample;
90
+ }
91
+ const image = await fetch(payload.result.sample);
92
+ return await image.blob();
93
+ }
94
+ }
95
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
96
+ }
97
+ }
@@ -14,32 +14,11 @@
14
14
  *
15
15
  * Thanks!
16
16
  */
17
- import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
18
17
 
19
- const CEREBRAS_API_BASE_URL = "https://api.cerebras.ai";
18
+ import { BaseConversationalTask } from "./providerHelper";
20
19
 
21
- const makeBaseUrl = (): string => {
22
- return CEREBRAS_API_BASE_URL;
23
- };
24
-
25
- const makeBody = (params: BodyParams): Record<string, unknown> => {
26
- return {
27
- ...params.args,
28
- model: params.model,
29
- };
30
- };
31
-
32
- const makeHeaders = (params: HeaderParams): Record<string, string> => {
33
- return { Authorization: `Bearer ${params.accessToken}` };
34
- };
35
-
36
- const makeUrl = (params: UrlParams): string => {
37
- return `${params.baseUrl}/v1/chat/completions`;
38
- };
39
-
40
- export const CEREBRAS_CONFIG: ProviderConfig = {
41
- makeBaseUrl,
42
- makeBody,
43
- makeHeaders,
44
- makeUrl,
45
- };
20
+ export class CerebrasConversationalTask extends BaseConversationalTask {
21
+ constructor() {
22
+ super("cerebras", "https://api.cerebras.ai");
23
+ }
24
+ }
@@ -14,32 +14,13 @@
14
14
  *
15
15
  * Thanks!
16
16
  */
17
- import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
17
+ import { BaseConversationalTask } from "./providerHelper";
18
18
 
19
- const COHERE_API_BASE_URL = "https://api.cohere.com";
20
-
21
- const makeBaseUrl = (): string => {
22
- return COHERE_API_BASE_URL;
23
- };
24
-
25
- const makeBody = (params: BodyParams): Record<string, unknown> => {
26
- return {
27
- ...params.args,
28
- model: params.model,
29
- };
30
- };
31
-
32
- const makeHeaders = (params: HeaderParams): Record<string, string> => {
33
- return { Authorization: `Bearer ${params.accessToken}` };
34
- };
35
-
36
- const makeUrl = (params: UrlParams): string => {
37
- return `${params.baseUrl}/compatibility/v1/chat/completions`;
38
- };
39
-
40
- export const COHERE_CONFIG: ProviderConfig = {
41
- makeBaseUrl,
42
- makeBody,
43
- makeHeaders,
44
- makeUrl,
45
- };
19
+ export class CohereConversationalTask extends BaseConversationalTask {
20
+ constructor() {
21
+ super("cohere", "https://api.cohere.com");
22
+ }
23
+ override makeRoute(): string {
24
+ return "/compatibility/v1/chat/completions";
25
+ }
26
+ }
@@ -14,109 +14,227 @@
14
14
  *
15
15
  * Thanks!
16
16
  */
17
+ import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
17
18
  import { InferenceOutputError } from "../lib/InferenceOutputError";
18
19
  import { isUrl } from "../lib/isUrl";
19
- import type { BodyParams, HeaderParams, InferenceTask, ProviderConfig, UrlParams } from "../types";
20
+ import type { BodyParams, HeaderParams, UrlParams } from "../types";
20
21
  import { delay } from "../utils/delay";
22
+ import { omit } from "../utils/omit";
23
+ import {
24
+ type AutomaticSpeechRecognitionTaskHelper,
25
+ TaskProviderHelper,
26
+ type TextToImageTaskHelper,
27
+ type TextToVideoTaskHelper,
28
+ } from "./providerHelper";
21
29
 
22
- const FAL_AI_API_BASE_URL = "https://fal.run";
23
- const FAL_AI_API_BASE_URL_QUEUE = "https://queue.fal.run";
30
+ export interface FalAiQueueOutput {
31
+ request_id: string;
32
+ status: string;
33
+ response_url: string;
34
+ }
24
35
 
25
- const makeBaseUrl = (task?: InferenceTask): string => {
26
- return task === "text-to-video" ? FAL_AI_API_BASE_URL_QUEUE : FAL_AI_API_BASE_URL;
27
- };
36
+ interface FalAITextToImageOutput {
37
+ images: Array<{
38
+ url: string;
39
+ }>;
40
+ }
28
41
 
29
- const makeBody = (params: BodyParams): Record<string, unknown> => {
30
- return params.args;
31
- };
42
+ interface FalAIAutomaticSpeechRecognitionOutput {
43
+ text: string;
44
+ }
32
45
 
33
- const makeHeaders = (params: HeaderParams): Record<string, string> => {
34
- return {
35
- Authorization: params.authMethod === "provider-key" ? `Key ${params.accessToken}` : `Bearer ${params.accessToken}`,
46
+ interface FalAITextToSpeechOutput {
47
+ audio: {
48
+ url: string;
49
+ content_type: string;
36
50
  };
37
- };
51
+ }
52
+ export const FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
38
53
 
39
- const makeUrl = (params: UrlParams): string => {
40
- const baseUrl = `${params.baseUrl}/${params.model}`;
41
- if (params.authMethod !== "provider-key" && params.task === "text-to-video") {
42
- return `${baseUrl}?_subdomain=queue`;
54
+ abstract class FalAITask extends TaskProviderHelper {
55
+ constructor(url?: string) {
56
+ super("fal-ai", url || "https://fal.run");
43
57
  }
44
- return baseUrl;
45
- };
46
58
 
47
- export const FAL_AI_CONFIG: ProviderConfig = {
48
- makeBaseUrl,
49
- makeBody,
50
- makeHeaders,
51
- makeUrl,
52
- };
59
+ preparePayload(params: BodyParams): Record<string, unknown> {
60
+ return params.args;
61
+ }
62
+ makeRoute(params: UrlParams): string {
63
+ return `/${params.model}`;
64
+ }
65
+ override prepareHeaders(params: HeaderParams, binary: boolean): Record<string, string> {
66
+ const headers: Record<string, string> = {
67
+ Authorization:
68
+ params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}`,
69
+ };
70
+ if (!binary) {
71
+ headers["Content-Type"] = "application/json";
72
+ }
73
+ return headers;
74
+ }
75
+ }
53
76
 
54
- export interface FalAiQueueOutput {
55
- request_id: string;
56
- status: string;
57
- response_url: string;
77
+ export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHelper {
78
+ override preparePayload(params: BodyParams): Record<string, unknown> {
79
+ return {
80
+ ...omit(params.args, ["inputs", "parameters"]),
81
+ ...(params.args.parameters as Record<string, unknown>),
82
+ sync_mode: true,
83
+ prompt: params.args.inputs,
84
+ };
85
+ }
86
+
87
+ override async getResponse(response: FalAITextToImageOutput, outputType?: "url" | "blob"): Promise<string | Blob> {
88
+ if (
89
+ typeof response === "object" &&
90
+ "images" in response &&
91
+ Array.isArray(response.images) &&
92
+ response.images.length > 0 &&
93
+ "url" in response.images[0] &&
94
+ typeof response.images[0].url === "string"
95
+ ) {
96
+ if (outputType === "url") {
97
+ return response.images[0].url;
98
+ }
99
+ const urlResponse = await fetch(response.images[0].url);
100
+ return await urlResponse.blob();
101
+ }
102
+
103
+ throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
104
+ }
58
105
  }
59
106
 
60
- export async function pollFalResponse(
61
- res: FalAiQueueOutput,
62
- url: string,
63
- headers: Record<string, string>
64
- ): Promise<Blob> {
65
- const requestId = res.request_id;
66
- if (!requestId) {
67
- throw new InferenceOutputError("No request ID found in the response");
107
+ export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHelper {
108
+ constructor() {
109
+ super("https://queue.fal.run");
110
+ }
111
+ override makeRoute(params: UrlParams): string {
112
+ if (params.authMethod !== "provider-key") {
113
+ return `/${params.model}?_subdomain=queue`;
114
+ }
115
+ return `/${params.model}`;
68
116
  }
69
- let status = res.status;
117
+ override preparePayload(params: BodyParams): Record<string, unknown> {
118
+ return {
119
+ ...omit(params.args, ["inputs", "parameters"]),
120
+ ...(params.args.parameters as Record<string, unknown>),
121
+ prompt: params.args.inputs,
122
+ };
123
+ }
124
+
125
+ override async getResponse(
126
+ response: FalAiQueueOutput,
127
+ url?: string,
128
+ headers?: Record<string, string>
129
+ ): Promise<Blob> {
130
+ if (!url || !headers) {
131
+ throw new InferenceOutputError("URL and headers are required for text-to-video task");
132
+ }
133
+ const requestId = response.request_id;
134
+ if (!requestId) {
135
+ throw new InferenceOutputError("No request ID found in the response");
136
+ }
137
+ let status = response.status;
70
138
 
71
- const parsedUrl = new URL(url);
72
- const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
73
- parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
74
- }`;
139
+ const parsedUrl = new URL(url);
140
+ const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
141
+ parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
142
+ }`;
75
143
 
76
- // extracting the provider model id for status and result urls
77
- // from the response as it might be different from the mapped model in `url`
78
- const modelId = new URL(res.response_url).pathname;
79
- const queryParams = parsedUrl.search;
144
+ // extracting the provider model id for status and result urls
145
+ // from the response as it might be different from the mapped model in `url`
146
+ const modelId = new URL(response.response_url).pathname;
147
+ const queryParams = parsedUrl.search;
80
148
 
81
- const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
82
- const resultUrl = `${baseUrl}${modelId}${queryParams}`;
149
+ const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
150
+ const resultUrl = `${baseUrl}${modelId}${queryParams}`;
83
151
 
84
- while (status !== "COMPLETED") {
85
- await delay(500);
86
- const statusResponse = await fetch(statusUrl, { headers });
152
+ while (status !== "COMPLETED") {
153
+ await delay(500);
154
+ const statusResponse = await fetch(statusUrl, { headers });
87
155
 
88
- if (!statusResponse.ok) {
89
- throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
156
+ if (!statusResponse.ok) {
157
+ throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
158
+ }
159
+ try {
160
+ status = (await statusResponse.json()).status;
161
+ } catch (error) {
162
+ throw new InferenceOutputError("Failed to parse status response from fal-ai API");
163
+ }
90
164
  }
165
+
166
+ const resultResponse = await fetch(resultUrl, { headers });
167
+ let result: unknown;
91
168
  try {
92
- status = (await statusResponse.json()).status;
169
+ result = await resultResponse.json();
93
170
  } catch (error) {
94
- throw new InferenceOutputError("Failed to parse status response from fal-ai API");
171
+ throw new InferenceOutputError("Failed to parse result response from fal-ai API");
172
+ }
173
+ if (
174
+ typeof result === "object" &&
175
+ !!result &&
176
+ "video" in result &&
177
+ typeof result.video === "object" &&
178
+ !!result.video &&
179
+ "url" in result.video &&
180
+ typeof result.video.url === "string" &&
181
+ isUrl(result.video.url)
182
+ ) {
183
+ const urlResponse = await fetch(result.video.url);
184
+ return await urlResponse.blob();
185
+ } else {
186
+ throw new InferenceOutputError(
187
+ "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
188
+ );
189
+ }
190
+ }
191
+ }
192
+
193
+ export class FalAIAutomaticSpeechRecognitionTask extends FalAITask implements AutomaticSpeechRecognitionTaskHelper {
194
+ override prepareHeaders(params: HeaderParams, binary: boolean): Record<string, string> {
195
+ const headers = super.prepareHeaders(params, binary);
196
+ headers["Content-Type"] = "application/json";
197
+ return headers;
198
+ }
199
+ override async getResponse(response: unknown): Promise<AutomaticSpeechRecognitionOutput> {
200
+ const res = response as FalAIAutomaticSpeechRecognitionOutput;
201
+ if (typeof res?.text !== "string") {
202
+ throw new InferenceOutputError(
203
+ `Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
204
+ );
95
205
  }
206
+ return { text: res.text };
96
207
  }
208
+ }
97
209
 
98
- const resultResponse = await fetch(resultUrl, { headers });
99
- let result: unknown;
100
- try {
101
- result = await resultResponse.json();
102
- } catch (error) {
103
- throw new InferenceOutputError("Failed to parse result response from fal-ai API");
210
+ export class FalAITextToSpeechTask extends FalAITask {
211
+ override preparePayload(params: BodyParams): Record<string, unknown> {
212
+ return {
213
+ ...omit(params.args, ["inputs", "parameters"]),
214
+ ...(params.args.parameters as Record<string, unknown>),
215
+ lyrics: params.args.inputs,
216
+ };
104
217
  }
105
- if (
106
- typeof result === "object" &&
107
- !!result &&
108
- "video" in result &&
109
- typeof result.video === "object" &&
110
- !!result.video &&
111
- "url" in result.video &&
112
- typeof result.video.url === "string" &&
113
- isUrl(result.video.url)
114
- ) {
115
- const urlResponse = await fetch(result.video.url);
116
- return await urlResponse.blob();
117
- } else {
118
- throw new InferenceOutputError(
119
- "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
120
- );
218
+
219
+ override async getResponse(response: unknown): Promise<Blob> {
220
+ const res = response as FalAITextToSpeechOutput;
221
+ if (typeof res?.audio?.url !== "string") {
222
+ throw new InferenceOutputError(
223
+ `Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
224
+ );
225
+ }
226
+ try {
227
+ const urlResponse = await fetch(res.audio.url);
228
+ if (!urlResponse.ok) {
229
+ throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`);
230
+ }
231
+ return await urlResponse.blob();
232
+ } catch (error) {
233
+ throw new InferenceOutputError(
234
+ `Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${
235
+ error instanceof Error ? error.message : String(error)
236
+ }`
237
+ );
238
+ }
121
239
  }
122
240
  }
@@ -14,35 +14,14 @@
14
14
  *
15
15
  * Thanks!
16
16
  */
17
- import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
17
+ import { BaseConversationalTask } from "./providerHelper";
18
18
 
19
- const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai";
20
-
21
- const makeBaseUrl = (): string => {
22
- return FIREWORKS_AI_API_BASE_URL;
23
- };
24
-
25
- const makeBody = (params: BodyParams): Record<string, unknown> => {
26
- return {
27
- ...params.args,
28
- ...(params.chatCompletion ? { model: params.model } : undefined),
29
- };
30
- };
31
-
32
- const makeHeaders = (params: HeaderParams): Record<string, string> => {
33
- return { Authorization: `Bearer ${params.accessToken}` };
34
- };
35
-
36
- const makeUrl = (params: UrlParams): string => {
37
- if (params.chatCompletion) {
38
- return `${params.baseUrl}/inference/v1/chat/completions`;
19
+ export class FireworksConversationalTask extends BaseConversationalTask {
20
+ constructor() {
21
+ super("fireworks-ai", "https://api.fireworks.ai");
39
22
  }
40
- return `${params.baseUrl}/inference`;
41
- };
42
23
 
43
- export const FIREWORKS_AI_CONFIG: ProviderConfig = {
44
- makeBaseUrl,
45
- makeBody,
46
- makeHeaders,
47
- makeUrl,
48
- };
24
+ override makeRoute(): string {
25
+ return "/inference/v1/chat/completions";
26
+ }
27
+ }