@huggingface/inference 3.6.2 → 3.7.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +0 -25
- package/dist/index.cjs +1232 -898
- package/dist/index.js +1234 -900
- package/dist/src/config.d.ts +1 -0
- package/dist/src/config.d.ts.map +1 -1
- package/dist/src/lib/getProviderHelper.d.ts +37 -0
- package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
- package/dist/src/lib/makeRequestOptions.d.ts +0 -2
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/black-forest-labs.d.ts +14 -18
- package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
- package/dist/src/providers/cerebras.d.ts +4 -2
- package/dist/src/providers/cerebras.d.ts.map +1 -1
- package/dist/src/providers/cohere.d.ts +5 -2
- package/dist/src/providers/cohere.d.ts.map +1 -1
- package/dist/src/providers/fal-ai.d.ts +50 -3
- package/dist/src/providers/fal-ai.d.ts.map +1 -1
- package/dist/src/providers/fireworks-ai.d.ts +5 -2
- package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
- package/dist/src/providers/hf-inference.d.ts +125 -2
- package/dist/src/providers/hf-inference.d.ts.map +1 -1
- package/dist/src/providers/hyperbolic.d.ts +31 -2
- package/dist/src/providers/hyperbolic.d.ts.map +1 -1
- package/dist/src/providers/nebius.d.ts +20 -18
- package/dist/src/providers/nebius.d.ts.map +1 -1
- package/dist/src/providers/novita.d.ts +21 -18
- package/dist/src/providers/novita.d.ts.map +1 -1
- package/dist/src/providers/openai.d.ts +4 -2
- package/dist/src/providers/openai.d.ts.map +1 -1
- package/dist/src/providers/providerHelper.d.ts +182 -0
- package/dist/src/providers/providerHelper.d.ts.map +1 -0
- package/dist/src/providers/replicate.d.ts +23 -19
- package/dist/src/providers/replicate.d.ts.map +1 -1
- package/dist/src/providers/sambanova.d.ts +4 -2
- package/dist/src/providers/sambanova.d.ts.map +1 -1
- package/dist/src/providers/together.d.ts +32 -2
- package/dist/src/providers/together.d.ts.map +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/audio/utils.d.ts +2 -1
- package/dist/src/tasks/audio/utils.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts +1 -2
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts +1 -2
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/index.d.ts +6 -6
- package/dist/src/tasks/index.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
- package/dist/src/types.d.ts +10 -13
- package/dist/src/types.d.ts.map +1 -1
- package/dist/src/utils/request.d.ts +27 -0
- package/dist/src/utils/request.d.ts.map +1 -0
- package/package.json +3 -3
- package/src/config.ts +1 -0
- package/src/lib/getProviderHelper.ts +270 -0
- package/src/lib/makeRequestOptions.ts +36 -90
- package/src/providers/black-forest-labs.ts +73 -22
- package/src/providers/cerebras.ts +6 -27
- package/src/providers/cohere.ts +9 -28
- package/src/providers/fal-ai.ts +195 -77
- package/src/providers/fireworks-ai.ts +8 -29
- package/src/providers/hf-inference.ts +555 -34
- package/src/providers/hyperbolic.ts +107 -29
- package/src/providers/nebius.ts +65 -29
- package/src/providers/novita.ts +68 -32
- package/src/providers/openai.ts +6 -32
- package/src/providers/providerHelper.ts +354 -0
- package/src/providers/replicate.ts +124 -34
- package/src/providers/sambanova.ts +5 -30
- package/src/providers/together.ts +92 -28
- package/src/snippets/getInferenceSnippets.ts +16 -9
- package/src/snippets/templates.exported.ts +2 -2
- package/src/tasks/audio/audioClassification.ts +6 -9
- package/src/tasks/audio/audioToAudio.ts +5 -28
- package/src/tasks/audio/automaticSpeechRecognition.ts +7 -6
- package/src/tasks/audio/textToSpeech.ts +6 -30
- package/src/tasks/audio/utils.ts +2 -1
- package/src/tasks/custom/request.ts +7 -34
- package/src/tasks/custom/streamingRequest.ts +5 -87
- package/src/tasks/cv/imageClassification.ts +5 -9
- package/src/tasks/cv/imageSegmentation.ts +5 -10
- package/src/tasks/cv/imageToImage.ts +5 -8
- package/src/tasks/cv/imageToText.ts +8 -13
- package/src/tasks/cv/objectDetection.ts +6 -21
- package/src/tasks/cv/textToImage.ts +10 -138
- package/src/tasks/cv/textToVideo.ts +11 -59
- package/src/tasks/cv/zeroShotImageClassification.ts +7 -12
- package/src/tasks/index.ts +6 -6
- package/src/tasks/multimodal/documentQuestionAnswering.ts +10 -26
- package/src/tasks/multimodal/visualQuestionAnswering.ts +6 -12
- package/src/tasks/nlp/chatCompletion.ts +7 -23
- package/src/tasks/nlp/chatCompletionStream.ts +4 -5
- package/src/tasks/nlp/featureExtraction.ts +5 -20
- package/src/tasks/nlp/fillMask.ts +5 -18
- package/src/tasks/nlp/questionAnswering.ts +5 -23
- package/src/tasks/nlp/sentenceSimilarity.ts +5 -18
- package/src/tasks/nlp/summarization.ts +5 -8
- package/src/tasks/nlp/tableQuestionAnswering.ts +5 -29
- package/src/tasks/nlp/textClassification.ts +8 -14
- package/src/tasks/nlp/textGeneration.ts +13 -80
- package/src/tasks/nlp/textGenerationStream.ts +2 -2
- package/src/tasks/nlp/tokenClassification.ts +8 -24
- package/src/tasks/nlp/translation.ts +5 -8
- package/src/tasks/nlp/zeroShotClassification.ts +8 -22
- package/src/tasks/tabular/tabularClassification.ts +5 -8
- package/src/tasks/tabular/tabularRegression.ts +5 -8
- package/src/types.ts +11 -14
- package/src/utils/request.ts +161 -0
|
@@ -14,33 +14,84 @@
|
|
|
14
14
|
*
|
|
15
15
|
* Thanks!
|
|
16
16
|
*/
|
|
17
|
-
import
|
|
17
|
+
import { InferenceOutputError } from "../lib/InferenceOutputError";
|
|
18
|
+
import type { BodyParams, HeaderParams, UrlParams } from "../types";
|
|
19
|
+
import { delay } from "../utils/delay";
|
|
20
|
+
import { omit } from "../utils/omit";
|
|
21
|
+
import { TaskProviderHelper, type TextToImageTaskHelper } from "./providerHelper";
|
|
18
22
|
|
|
19
23
|
/** Base URL of the Black Forest Labs API (the `us1` host). */
const BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";

/**
 * Reply to a Black Forest Labs generation request: the job identifier and the
 * URL that must be polled to obtain the generated sample.
 */
type BlackForestLabsResponse = {
	id: string;
	polling_url: string;
};
|
|
20
28
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
29
|
+
/**
 * Text-to-image helper for Black Forest Labs.
 *
 * Generation is asynchronous: the submission reply carries a `polling_url`,
 * which `getResponse` polls until the sample is ready.
 */
export class BlackForestLabsTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
	constructor() {
		super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
	}

	/**
	 * Builds the request body: `parameters` are flattened to the top level and
	 * `inputs` is sent as `prompt`.
	 */
	preparePayload(params: BodyParams): Record<string, unknown> {
		return {
			...omit(params.args, ["inputs", "parameters"]),
			...(params.args.parameters as Record<string, unknown>),
			prompt: params.args.inputs,
		};
	}

	/**
	 * Provider keys are sent with the `X-Key` scheme; any other auth method is sent
	 * as a Bearer token. A JSON content type is added only for non-binary bodies.
	 */
	override prepareHeaders(params: HeaderParams, binary: boolean): Record<string, string> {
		const headers: Record<string, string> = {
			Authorization:
				params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`,
		};
		if (!binary) {
			headers["Content-Type"] = "application/json";
		}
		return headers;
	}

	/**
	 * Routes the request to `/v1/<model>`.
	 * @throws Error when `params` is missing.
	 */
	makeRoute(params: UrlParams): string {
		if (!params) {
			throw new Error("Params are required");
		}
		return `/v1/${params.model}`;
	}

	/**
	 * Polls `response.polling_url` (one second apart, a bounded number of times)
	 * until the job reports `Ready` with a string `result.sample`.
	 *
	 * @param response - submission reply holding the `polling_url`.
	 * @param url - unused; kept for the positional text-to-image helper signature.
	 * @param headers - unused; polling requests are sent without credentials.
	 * @param outputType - `"url"` returns the sample URL; otherwise the sample is downloaded as a Blob.
	 * @returns the sample URL or the downloaded image.
	 * @throws InferenceOutputError if a poll request fails, the API reports a terminal
	 *   failure status, the image download fails, or no result is ready after the last attempt.
	 */
	async getResponse(
		response: BlackForestLabsResponse,
		url?: string,
		headers?: HeadersInit,
		outputType?: "url" | "blob"
	): Promise<string | Blob> {
		const maxAttempts = 5;
		// Statuses after which further polling cannot succeed; fail fast instead of
		// exhausting every attempt and reporting a generic error.
		const failureStatuses = new Set(["Error", "Content Moderated", "Request Moderated"]);
		const urlObj = new URL(response.polling_url);
		for (let step = 0; step < maxAttempts; step++) {
			await delay(1000);
			console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/${maxAttempts}`);
			urlObj.searchParams.set("attempt", step.toString(10));
			const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
			if (!resp.ok) {
				throw new InferenceOutputError("Failed to fetch result from black forest labs API");
			}
			const payload = await resp.json();
			const status: unknown = typeof payload === "object" && payload ? payload.status : undefined;
			if (typeof status === "string" && failureStatuses.has(status)) {
				throw new InferenceOutputError(`Black Forest Labs API reported status "${status}"`);
			}
			if (
				status === "Ready" &&
				typeof payload.result === "object" &&
				payload.result &&
				typeof payload.result.sample === "string"
			) {
				if (outputType === "url") {
					return payload.result.sample;
				}
				const image = await fetch(payload.result.sample);
				// Without this check an HTTP error body would be returned as the "image".
				if (!image.ok) {
					throw new InferenceOutputError("Failed to fetch generated image from black forest labs API");
				}
				return await image.blob();
			}
		}
		throw new InferenceOutputError("Failed to fetch result from black forest labs API");
	}
}
|
|
@@ -14,32 +14,11 @@
|
|
|
14
14
|
*
|
|
15
15
|
* Thanks!
|
|
16
16
|
*/
|
|
17
|
-
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
|
|
18
17
|
|
|
19
|
-
|
|
18
|
+
import { BaseConversationalTask } from "./providerHelper";
|
|
20
19
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
return {
|
|
27
|
-
...params.args,
|
|
28
|
-
model: params.model,
|
|
29
|
-
};
|
|
30
|
-
};
|
|
31
|
-
|
|
32
|
-
const makeHeaders = (params: HeaderParams): Record<string, string> => {
|
|
33
|
-
return { Authorization: `Bearer ${params.accessToken}` };
|
|
34
|
-
};
|
|
35
|
-
|
|
36
|
-
const makeUrl = (params: UrlParams): string => {
|
|
37
|
-
return `${params.baseUrl}/v1/chat/completions`;
|
|
38
|
-
};
|
|
39
|
-
|
|
40
|
-
export const CEREBRAS_CONFIG: ProviderConfig = {
|
|
41
|
-
makeBaseUrl,
|
|
42
|
-
makeBody,
|
|
43
|
-
makeHeaders,
|
|
44
|
-
makeUrl,
|
|
45
|
-
};
|
|
20
|
+
/** Root URL of the Cerebras API. */
const CEREBRAS_API_BASE_URL = "https://api.cerebras.ai";

/**
 * Chat-completion helper for Cerebras. Route, payload and header handling are
 * all inherited from `BaseConversationalTask`; only the provider name and base
 * URL are specific to Cerebras.
 */
export class CerebrasConversationalTask extends BaseConversationalTask {
	constructor() {
		super("cerebras", CEREBRAS_API_BASE_URL);
	}
}
|
package/src/providers/cohere.ts
CHANGED
|
@@ -14,32 +14,13 @@
|
|
|
14
14
|
*
|
|
15
15
|
* Thanks!
|
|
16
16
|
*/
|
|
17
|
-
import
|
|
17
|
+
import { BaseConversationalTask } from "./providerHelper";
|
|
18
18
|
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
...params.args,
|
|
28
|
-
model: params.model,
|
|
29
|
-
};
|
|
30
|
-
};
|
|
31
|
-
|
|
32
|
-
const makeHeaders = (params: HeaderParams): Record<string, string> => {
|
|
33
|
-
return { Authorization: `Bearer ${params.accessToken}` };
|
|
34
|
-
};
|
|
35
|
-
|
|
36
|
-
const makeUrl = (params: UrlParams): string => {
|
|
37
|
-
return `${params.baseUrl}/compatibility/v1/chat/completions`;
|
|
38
|
-
};
|
|
39
|
-
|
|
40
|
-
export const COHERE_CONFIG: ProviderConfig = {
|
|
41
|
-
makeBaseUrl,
|
|
42
|
-
makeBody,
|
|
43
|
-
makeHeaders,
|
|
44
|
-
makeUrl,
|
|
45
|
-
};
|
|
19
|
+
/** Root URL of the Cohere API. */
const COHERE_API_BASE_URL = "https://api.cohere.com";

/**
 * Chat-completion helper for Cohere, pointed at Cohere's OpenAI-compatibility
 * endpoint.
 */
export class CohereConversationalTask extends BaseConversationalTask {
	constructor() {
		super("cohere", COHERE_API_BASE_URL);
	}

	/** Cohere mounts its chat-completions endpoint under the `/compatibility` prefix. */
	override makeRoute(): string {
		return "/compatibility/v1/chat/completions";
	}
}
|
package/src/providers/fal-ai.ts
CHANGED
|
@@ -14,109 +14,227 @@
|
|
|
14
14
|
*
|
|
15
15
|
* Thanks!
|
|
16
16
|
*/
|
|
17
|
+
import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
|
|
17
18
|
import { InferenceOutputError } from "../lib/InferenceOutputError";
|
|
18
19
|
import { isUrl } from "../lib/isUrl";
|
|
19
|
-
import type { BodyParams, HeaderParams,
|
|
20
|
+
import type { BodyParams, HeaderParams, UrlParams } from "../types";
|
|
20
21
|
import { delay } from "../utils/delay";
|
|
22
|
+
import { omit } from "../utils/omit";
|
|
23
|
+
import {
|
|
24
|
+
type AutomaticSpeechRecognitionTaskHelper,
|
|
25
|
+
TaskProviderHelper,
|
|
26
|
+
type TextToImageTaskHelper,
|
|
27
|
+
type TextToVideoTaskHelper,
|
|
28
|
+
} from "./providerHelper";
|
|
21
29
|
|
|
22
|
-
|
|
23
|
-
|
|
30
|
+
/** Handle returned by the fal.ai queue when a job is submitted. */
export interface FalAiQueueOutput {
	request_id: string;
	status: string;
	response_url: string;
}

/** Text-to-image reply: the generated images, each reachable by URL. */
interface FalAITextToImageOutput {
	images: { url: string }[];
}

/** Automatic-speech-recognition reply carrying the transcription. */
interface FalAIAutomaticSpeechRecognitionOutput {
	text: string;
}

/** Text-to-speech reply pointing at the generated audio file. */
interface FalAITextToSpeechOutput {
	audio: {
		url: string;
		content_type: string;
	};
}

/**
 * Audio MIME types listed as supported for fal.ai blob inputs.
 * NOTE(review): consumers live outside this file; confirm against their usage.
 */
export const FAL_AI_SUPPORTED_BLOB_TYPES: string[] = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
|
|
38
53
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
return `${baseUrl}?_subdomain=queue`;
|
|
54
|
+
/**
 * Shared base for fal.ai task helpers: forwards the raw args as the body,
 * routes to `/<model>`, and picks the auth scheme from the auth method.
 */
abstract class FalAITask extends TaskProviderHelper {
	/**
	 * @param url - optional base URL; any falsy value (including "") falls back to `https://fal.run`.
	 */
	constructor(url?: string) {
		super("fal-ai", url ? url : "https://fal.run");
	}

	/** Sends the caller's args unchanged. */
	preparePayload(params: BodyParams): Record<string, unknown> {
		return params.args;
	}

	/** Targets `/<model>` on the base URL. */
	makeRoute(params: UrlParams): string {
		return `/${params.model}`;
	}

	/**
	 * Provider keys use fal's `Key` scheme; every other auth method is sent as a
	 * Bearer token. Non-binary requests are additionally marked as JSON.
	 */
	override prepareHeaders(params: HeaderParams, binary: boolean): Record<string, string> {
		const scheme = params.authMethod === "provider-key" ? "Key" : "Bearer";
		const headers: Record<string, string> = { Authorization: `${scheme} ${params.accessToken}` };
		if (binary) {
			return headers;
		}
		headers["Content-Type"] = "application/json";
		return headers;
	}
}
|
|
53
76
|
|
|
54
|
-
export
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
77
|
+
/**
 * Text-to-image helper for fal.ai using the synchronous endpoint (`sync_mode: true`),
 * so the image URLs come back directly in the reply.
 */
export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHelper {
	/**
	 * Flattens `parameters` into the body, requests sync mode, and sends `inputs` as `prompt`.
	 */
	override preparePayload(params: BodyParams): Record<string, unknown> {
		return {
			...omit(params.args, ["inputs", "parameters"]),
			...(params.args.parameters as Record<string, unknown>),
			sync_mode: true,
			prompt: params.args.inputs,
		};
	}

	/**
	 * Extracts the first generated image.
	 *
	 * `outputType` is the fourth positional parameter, matching the
	 * `(response, url, headers, outputType)` order used by the other text-to-image
	 * helpers. With `outputType` in second position, a positional caller's request
	 * URL would land in it and `"url"` output could never be selected.
	 *
	 * @param response - reply from the fal.ai text-to-image endpoint.
	 * @param url - unused; kept for the positional helper signature.
	 * @param headers - unused; kept for the positional helper signature.
	 * @param outputType - `"url"` returns the image URL; otherwise the image is downloaded as a Blob.
	 * @throws InferenceOutputError when the reply is malformed or the image download fails.
	 */
	override async getResponse(
		response: FalAITextToImageOutput,
		url?: string,
		headers?: HeadersInit,
		outputType?: "url" | "blob"
	): Promise<string | Blob> {
		// The body comes straight from the network, so guard against null and
		// malformed shapes despite the declared type.
		const firstImage =
			typeof response === "object" && response !== null && Array.isArray(response.images)
				? response.images[0]
				: undefined;
		if (typeof firstImage !== "object" || firstImage === null || typeof firstImage.url !== "string") {
			throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
		}

		const imageUrl = firstImage.url;
		if (outputType === "url") {
			return imageUrl;
		}
		const urlResponse = await fetch(imageUrl);
		if (!urlResponse.ok) {
			throw new InferenceOutputError(`Failed to fetch image from Fal.ai text-to-image URL: ${imageUrl}`);
		}
		return await urlResponse.blob();
	}
}
|
|
59
106
|
|
|
60
|
-
export
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
):
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
107
|
+
/**
 * Text-to-video helper for fal.ai. Requests go through fal's queue: the
 * submission returns a request handle, `getResponse` polls its status until it
 * is `COMPLETED`, then fetches the result and downloads the video.
 */
export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHelper {
	constructor() {
		super("https://queue.fal.run");
	}

	/**
	 * Targets `/<model>`. For anything other than a provider key,
	 * `?_subdomain=queue` is appended (presumably so the routing layer forwards
	 * the call to fal's queue host; confirm).
	 */
	override makeRoute(params: UrlParams): string {
		if (params.authMethod !== "provider-key") {
			return `/${params.model}?_subdomain=queue`;
		}
		return `/${params.model}`;
	}

	/** Flattens `parameters` into the body and sends `inputs` as `prompt`. */
	override preparePayload(params: BodyParams): Record<string, unknown> {
		return {
			...omit(params.args, ["inputs", "parameters"]),
			...(params.args.parameters as Record<string, unknown>),
			prompt: params.args.inputs,
		};
	}

	/**
	 * Polls the queued request every 500 ms until `COMPLETED`, then downloads the video.
	 *
	 * @param response - queue submission reply.
	 * @param url - the URL the submission was sent to; its host and query string are reused.
	 * @param headers - auth headers reused for the status and result requests.
	 * @throws InferenceOutputError on missing inputs, failed or unparsable status/result
	 *   requests, an unexpected result shape, or a failed video download.
	 */
	override async getResponse(
		response: FalAiQueueOutput,
		url?: string,
		headers?: Record<string, string>
	): Promise<Blob> {
		if (!url || !headers) {
			throw new InferenceOutputError("URL and headers are required for text-to-video task");
		}
		const requestId = response.request_id;
		if (!requestId) {
			throw new InferenceOutputError("No request ID found in the response");
		}
		let status = response.status;

		const parsedUrl = new URL(url);
		// Through the Hugging Face router, fal endpoints sit under the `/fal-ai` prefix.
		const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
			parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
		}`;

		// Status and result URLs are built from the path of `response_url` rather than
		// from the model in `url`, since the provider model id may differ from the mapped one.
		const modelId = new URL(response.response_url).pathname;
		const queryParams = parsedUrl.search;

		const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
		const resultUrl = `${baseUrl}${modelId}${queryParams}`;

		while (status !== "COMPLETED") {
			await delay(500);
			const statusResponse = await fetch(statusUrl, { headers });
			if (!statusResponse.ok) {
				throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
			}
			try {
				status = (await statusResponse.json()).status;
			} catch {
				throw new InferenceOutputError("Failed to parse status response from fal-ai API");
			}
			// A payload without a string `status` would otherwise keep this loop spinning forever.
			if (typeof status !== "string") {
				throw new InferenceOutputError("Failed to parse status response from fal-ai API");
			}
		}

		const resultResponse = await fetch(resultUrl, { headers });
		if (!resultResponse.ok) {
			throw new InferenceOutputError("Failed to fetch result from fal-ai API");
		}
		let result: unknown;
		try {
			result = await resultResponse.json();
		} catch {
			throw new InferenceOutputError("Failed to parse result response from fal-ai API");
		}
		if (
			typeof result === "object" &&
			!!result &&
			"video" in result &&
			typeof result.video === "object" &&
			!!result.video &&
			"url" in result.video &&
			typeof result.video.url === "string" &&
			isUrl(result.video.url)
		) {
			const urlResponse = await fetch(result.video.url);
			if (!urlResponse.ok) {
				throw new InferenceOutputError(
					`Failed to download video from ${result.video.url}: ${urlResponse.statusText}`
				);
			}
			return await urlResponse.blob();
		} else {
			throw new InferenceOutputError(
				"Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
			);
		}
	}
}
|
|
192
|
+
|
|
193
|
+
/**
 * Automatic-speech-recognition helper for fal.ai: always sends a JSON content
 * type and returns the transcription as `{ text }`.
 */
export class FalAIAutomaticSpeechRecognitionTask extends FalAITask implements AutomaticSpeechRecognitionTaskHelper {
	/**
	 * Forces `Content-Type: application/json` even when `binary` is true
	 * (presumably because the audio travels inside a JSON body; confirm with the request builder).
	 */
	override prepareHeaders(params: HeaderParams, binary: boolean): Record<string, string> {
		return { ...super.prepareHeaders(params, binary), "Content-Type": "application/json" };
	}

	/**
	 * @throws InferenceOutputError when the reply lacks a string `text` field.
	 */
	override async getResponse(response: unknown): Promise<AutomaticSpeechRecognitionOutput> {
		const text = (response as FalAIAutomaticSpeechRecognitionOutput | undefined)?.text;
		if (typeof text === "string") {
			return { text };
		}
		throw new InferenceOutputError(
			`Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
		);
	}
}
|
|
97
209
|
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
210
|
+
/**
 * Text-to-speech helper for fal.ai: the input text is sent as `lyrics`, and the
 * audio file referenced by the reply is downloaded as a Blob.
 */
export class FalAITextToSpeechTask extends FalAITask {
	/** Flattens `parameters` into the body and sends `inputs` as `lyrics`. */
	override preparePayload(params: BodyParams): Record<string, unknown> {
		const { args } = params;
		const extraParameters = args.parameters as Record<string, unknown>;
		return {
			...omit(args, ["inputs", "parameters"]),
			...extraParameters,
			lyrics: args.inputs,
		};
	}

	/**
	 * Downloads the audio referenced by `audio.url`.
	 * @throws InferenceOutputError when the reply has no audio URL, or the download fails.
	 */
	override async getResponse(response: unknown): Promise<Blob> {
		const audioUrl = (response as FalAITextToSpeechOutput | undefined)?.audio?.url;
		if (typeof audioUrl !== "string") {
			throw new InferenceOutputError(
				`Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
			);
		}
		try {
			const download = await fetch(audioUrl);
			if (!download.ok) {
				throw new Error(`Failed to fetch audio from ${audioUrl}: ${download.statusText}`);
			}
			return await download.blob();
		} catch (error) {
			// Any failure (network or non-OK status) is re-raised as an InferenceOutputError.
			const reason = error instanceof Error ? error.message : String(error);
			throw new InferenceOutputError(
				`Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${audioUrl}. ${reason}`
			);
		}
	}
}
|
|
@@ -14,35 +14,14 @@
|
|
|
14
14
|
*
|
|
15
15
|
* Thanks!
|
|
16
16
|
*/
|
|
17
|
-
import
|
|
17
|
+
import { BaseConversationalTask } from "./providerHelper";
|
|
18
18
|
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
return FIREWORKS_AI_API_BASE_URL;
|
|
23
|
-
};
|
|
24
|
-
|
|
25
|
-
const makeBody = (params: BodyParams): Record<string, unknown> => {
|
|
26
|
-
return {
|
|
27
|
-
...params.args,
|
|
28
|
-
...(params.chatCompletion ? { model: params.model } : undefined),
|
|
29
|
-
};
|
|
30
|
-
};
|
|
31
|
-
|
|
32
|
-
const makeHeaders = (params: HeaderParams): Record<string, string> => {
|
|
33
|
-
return { Authorization: `Bearer ${params.accessToken}` };
|
|
34
|
-
};
|
|
35
|
-
|
|
36
|
-
const makeUrl = (params: UrlParams): string => {
|
|
37
|
-
if (params.chatCompletion) {
|
|
38
|
-
return `${params.baseUrl}/inference/v1/chat/completions`;
|
|
19
|
+
/** Root URL of the Fireworks AI API. */
const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai";

/**
 * Chat-completion helper for Fireworks AI, whose chat-completions endpoint is
 * mounted under the `/inference` prefix.
 */
export class FireworksConversationalTask extends BaseConversationalTask {
	constructor() {
		super("fireworks-ai", FIREWORKS_AI_API_BASE_URL);
	}

	/** Chat completions live under `/inference/v1`. */
	override makeRoute(): string {
		return "/inference/v1/chat/completions";
	}
}
|