@huggingface/inference 3.7.0 → 3.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. package/dist/index.cjs +1152 -839
  2. package/dist/index.js +1154 -841
  3. package/dist/src/lib/getProviderHelper.d.ts +37 -0
  4. package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
  5. package/dist/src/lib/makeRequestOptions.d.ts +0 -2
  6. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  7. package/dist/src/providers/black-forest-labs.d.ts +14 -18
  8. package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
  9. package/dist/src/providers/cerebras.d.ts +4 -2
  10. package/dist/src/providers/cerebras.d.ts.map +1 -1
  11. package/dist/src/providers/cohere.d.ts +5 -2
  12. package/dist/src/providers/cohere.d.ts.map +1 -1
  13. package/dist/src/providers/fal-ai.d.ts +50 -3
  14. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  15. package/dist/src/providers/fireworks-ai.d.ts +5 -2
  16. package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
  17. package/dist/src/providers/hf-inference.d.ts +125 -2
  18. package/dist/src/providers/hf-inference.d.ts.map +1 -1
  19. package/dist/src/providers/hyperbolic.d.ts +31 -2
  20. package/dist/src/providers/hyperbolic.d.ts.map +1 -1
  21. package/dist/src/providers/nebius.d.ts +20 -18
  22. package/dist/src/providers/nebius.d.ts.map +1 -1
  23. package/dist/src/providers/novita.d.ts +21 -18
  24. package/dist/src/providers/novita.d.ts.map +1 -1
  25. package/dist/src/providers/openai.d.ts +4 -2
  26. package/dist/src/providers/openai.d.ts.map +1 -1
  27. package/dist/src/providers/providerHelper.d.ts +182 -0
  28. package/dist/src/providers/providerHelper.d.ts.map +1 -0
  29. package/dist/src/providers/replicate.d.ts +23 -19
  30. package/dist/src/providers/replicate.d.ts.map +1 -1
  31. package/dist/src/providers/sambanova.d.ts +4 -2
  32. package/dist/src/providers/sambanova.d.ts.map +1 -1
  33. package/dist/src/providers/together.d.ts +32 -2
  34. package/dist/src/providers/together.d.ts.map +1 -1
  35. package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
  36. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
  37. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  38. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  39. package/dist/src/tasks/audio/utils.d.ts +2 -1
  40. package/dist/src/tasks/audio/utils.d.ts.map +1 -1
  41. package/dist/src/tasks/custom/request.d.ts +0 -2
  42. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  43. package/dist/src/tasks/custom/streamingRequest.d.ts +0 -2
  44. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  45. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
  46. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
  47. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
  48. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
  49. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
  50. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  51. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
  52. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
  53. package/dist/src/tasks/index.d.ts +6 -6
  54. package/dist/src/tasks/index.d.ts.map +1 -1
  55. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  56. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
  57. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  58. package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
  59. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  60. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
  61. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  62. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  63. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
  64. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  65. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
  66. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  67. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  68. package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
  69. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  70. package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
  71. package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
  72. package/dist/src/types.d.ts +3 -13
  73. package/dist/src/types.d.ts.map +1 -1
  74. package/package.json +3 -3
  75. package/src/lib/getProviderHelper.ts +270 -0
  76. package/src/lib/makeRequestOptions.ts +34 -91
  77. package/src/providers/black-forest-labs.ts +73 -22
  78. package/src/providers/cerebras.ts +6 -27
  79. package/src/providers/cohere.ts +9 -28
  80. package/src/providers/fal-ai.ts +195 -77
  81. package/src/providers/fireworks-ai.ts +8 -29
  82. package/src/providers/hf-inference.ts +555 -34
  83. package/src/providers/hyperbolic.ts +107 -29
  84. package/src/providers/nebius.ts +65 -29
  85. package/src/providers/novita.ts +68 -32
  86. package/src/providers/openai.ts +6 -32
  87. package/src/providers/providerHelper.ts +354 -0
  88. package/src/providers/replicate.ts +124 -34
  89. package/src/providers/sambanova.ts +5 -30
  90. package/src/providers/together.ts +92 -28
  91. package/src/snippets/getInferenceSnippets.ts +16 -9
  92. package/src/snippets/templates.exported.ts +1 -1
  93. package/src/tasks/audio/audioClassification.ts +4 -7
  94. package/src/tasks/audio/audioToAudio.ts +3 -26
  95. package/src/tasks/audio/automaticSpeechRecognition.ts +4 -3
  96. package/src/tasks/audio/textToSpeech.ts +5 -29
  97. package/src/tasks/audio/utils.ts +2 -1
  98. package/src/tasks/custom/request.ts +0 -2
  99. package/src/tasks/custom/streamingRequest.ts +0 -2
  100. package/src/tasks/cv/imageClassification.ts +3 -7
  101. package/src/tasks/cv/imageSegmentation.ts +3 -8
  102. package/src/tasks/cv/imageToImage.ts +3 -6
  103. package/src/tasks/cv/imageToText.ts +3 -6
  104. package/src/tasks/cv/objectDetection.ts +3 -18
  105. package/src/tasks/cv/textToImage.ts +9 -137
  106. package/src/tasks/cv/textToVideo.ts +11 -62
  107. package/src/tasks/cv/zeroShotImageClassification.ts +3 -7
  108. package/src/tasks/index.ts +6 -6
  109. package/src/tasks/multimodal/documentQuestionAnswering.ts +3 -19
  110. package/src/tasks/multimodal/visualQuestionAnswering.ts +3 -11
  111. package/src/tasks/nlp/chatCompletion.ts +5 -20
  112. package/src/tasks/nlp/chatCompletionStream.ts +1 -2
  113. package/src/tasks/nlp/featureExtraction.ts +3 -18
  114. package/src/tasks/nlp/fillMask.ts +3 -16
  115. package/src/tasks/nlp/questionAnswering.ts +3 -22
  116. package/src/tasks/nlp/sentenceSimilarity.ts +3 -7
  117. package/src/tasks/nlp/summarization.ts +3 -6
  118. package/src/tasks/nlp/tableQuestionAnswering.ts +3 -27
  119. package/src/tasks/nlp/textClassification.ts +3 -8
  120. package/src/tasks/nlp/textGeneration.ts +12 -79
  121. package/src/tasks/nlp/tokenClassification.ts +3 -18
  122. package/src/tasks/nlp/translation.ts +3 -6
  123. package/src/tasks/nlp/zeroShotClassification.ts +3 -16
  124. package/src/tasks/tabular/tabularClassification.ts +3 -6
  125. package/src/tasks/tabular/tabularRegression.ts +3 -6
  126. package/src/types.ts +3 -14
@@ -1,15 +1,15 @@
1
- import type { PipelineType, WidgetType } from "@huggingface/tasks/src/pipelines.js";
2
- import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingface/tasks/src/tasks/index.js";
1
+ import { Template } from "@huggingface/jinja";
3
2
  import {
4
3
  type InferenceSnippet,
5
4
  type InferenceSnippetLanguage,
6
5
  type ModelDataMinimal,
7
- inferenceSnippetLanguages,
8
6
  getModelInputSnippet,
7
+ inferenceSnippetLanguages,
9
8
  } from "@huggingface/tasks";
10
- import type { InferenceProvider, InferenceTask, RequestArgs } from "../types";
11
- import { Template } from "@huggingface/jinja";
9
+ import type { PipelineType, WidgetType } from "@huggingface/tasks/src/pipelines.js";
10
+ import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingface/tasks/src/tasks/index.js";
12
11
  import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions";
12
+ import type { InferenceProvider, InferenceTask, RequestArgs } from "../types";
13
13
  import { templates } from "./templates.exported";
14
14
 
15
15
  const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const;
@@ -120,6 +120,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
120
120
  opts?: Record<string, unknown>
121
121
  ): InferenceSnippet[] => {
122
122
  /// Hacky: hard-code conversational templates here
123
+ let task = model.pipeline_tag as InferenceTask;
123
124
  if (
124
125
  model.pipeline_tag &&
125
126
  ["text-generation", "image-text-to-text"].includes(model.pipeline_tag) &&
@@ -127,14 +128,20 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
127
128
  ) {
128
129
  templateName = opts?.streaming ? "conversationalStream" : "conversational";
129
130
  inputPreparationFn = prepareConversationalInput;
131
+ task = "conversational";
130
132
  }
131
-
132
133
  /// Prepare inputs + make request
133
134
  const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) };
134
135
  const request = makeRequestOptionsFromResolvedModel(
135
136
  providerModelId ?? model.id,
136
- { accessToken: accessToken, provider: provider, ...inputs } as RequestArgs,
137
- { chatCompletion: templateName.includes("conversational"), task: model.pipeline_tag as InferenceTask }
137
+ {
138
+ accessToken: accessToken,
139
+ provider: provider,
140
+ ...inputs,
141
+ } as RequestArgs,
142
+ {
143
+ task: task,
144
+ }
138
145
  );
139
146
 
140
147
  /// Parse request.info.body if not a binary.
@@ -247,7 +254,7 @@ const prepareConversationalInput = (
247
254
  return {
248
255
  messages: opts?.messages ?? getModelInputSnippet(model),
249
256
  ...(opts?.temperature ? { temperature: opts?.temperature } : undefined),
250
- max_tokens: opts?.max_tokens ?? 500,
257
+ max_tokens: opts?.max_tokens ?? 512,
251
258
  ...(opts?.top_p ? { top_p: opts?.top_p } : undefined),
252
259
  };
253
260
  };
@@ -20,7 +20,7 @@ export const templates: Record<string, Record<string, Record<string, string>>> =
20
20
  },
21
21
  "openai": {
22
22
  "conversational": "import { OpenAI } from \"openai\";\n\nconst client = new OpenAI({\n\tbaseURL: \"{{ baseUrl }}\",\n\tapiKey: \"{{ accessToken }}\",\n});\n\nconst chatCompletion = await client.chat.completions.create({\n\tmodel: \"{{ providerModelId }}\",\n{{ inputs.asTsString }}\n});\n\nconsole.log(chatCompletion.choices[0].message);",
23
- "conversationalStream": "import { OpenAI } from \"openai\";\n\nconst client = new OpenAI({\n\tbaseURL: \"{{ baseUrl }}\",\n\tapiKey: \"{{ accessToken }}\",\n});\n\nlet out = \"\";\n\nconst stream = await client.chat.completions.create({\n provider: \"{{ provider }}\",\n model: \"{{ model.id }}\",\n{{ inputs.asTsString }}\n});\n\nfor await (const chunk of stream) {\n\tif (chunk.choices && chunk.choices.length > 0) {\n\t\tconst newContent = chunk.choices[0].delta.content;\n\t\tout += newContent;\n\t\tconsole.log(newContent);\n\t} \n}"
23
+ "conversationalStream": "import { OpenAI } from \"openai\";\n\nconst client = new OpenAI({\n\tbaseURL: \"{{ baseUrl }}\",\n\tapiKey: \"{{ accessToken }}\",\n});\n\nconst stream = await client.chat.completions.create({\n model: \"{{ providerModelId }}\",\n{{ inputs.asTsString }}\n stream: true,\n});\n\nfor await (const chunk of stream) {\n process.stdout.write(chunk.choices[0]?.delta?.content || \"\");\n}"
24
24
  }
25
25
  },
26
26
  "python": {
@@ -1,5 +1,5 @@
1
1
  import type { AudioClassificationInput, AudioClassificationOutput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
3
  import type { BaseArgs, Options } from "../../types";
4
4
  import { innerRequest } from "../../utils/request";
5
5
  import type { LegacyAudioInput } from "./utils";
@@ -15,15 +15,12 @@ export async function audioClassification(
15
15
  args: AudioClassificationArgs,
16
16
  options?: Options
17
17
  ): Promise<AudioClassificationOutput> {
18
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
18
19
  const payload = preparePayload(args);
19
20
  const { data: res } = await innerRequest<AudioClassificationOutput>(payload, {
20
21
  ...options,
21
22
  task: "audio-classification",
22
23
  });
23
- const isValidOutput =
24
- Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
25
- if (!isValidOutput) {
26
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
27
- }
28
- return res;
24
+
25
+ return providerHelper.getResponse(res);
29
26
  }
@@ -1,4 +1,4 @@
1
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
1
+ import { getProviderHelper } from "../../lib/getProviderHelper";
2
2
  import type { BaseArgs, Options } from "../../types";
3
3
  import { innerRequest } from "../../utils/request";
4
4
  import type { LegacyAudioInput } from "./utils";
@@ -36,34 +36,11 @@ export interface AudioToAudioOutput {
36
36
  * Example model: speechbrain/sepformer-wham does audio source separation.
37
37
  */
38
38
  export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
39
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
39
40
  const payload = preparePayload(args);
40
41
  const { data: res } = await innerRequest<AudioToAudioOutput>(payload, {
41
42
  ...options,
42
43
  task: "audio-to-audio",
43
44
  });
44
-
45
- return validateOutput(res);
46
- }
47
-
48
- function validateOutput(output: unknown): AudioToAudioOutput[] {
49
- if (!Array.isArray(output)) {
50
- throw new InferenceOutputError("Expected Array");
51
- }
52
- if (
53
- !output.every((elem): elem is AudioToAudioOutput => {
54
- return (
55
- typeof elem === "object" &&
56
- elem &&
57
- "label" in elem &&
58
- typeof elem.label === "string" &&
59
- "content-type" in elem &&
60
- typeof elem["content-type"] === "string" &&
61
- "blob" in elem &&
62
- typeof elem.blob === "string"
63
- );
64
- })
65
- ) {
66
- throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
67
- }
68
- return output;
45
+ return providerHelper.getResponse(res);
69
46
  }
@@ -1,5 +1,7 @@
1
1
  import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
2
3
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
4
+ import { FAL_AI_SUPPORTED_BLOB_TYPES } from "../../providers/fal-ai";
3
5
  import type { BaseArgs, Options, RequestArgs } from "../../types";
4
6
  import { base64FromBytes } from "../../utils/base64FromBytes";
5
7
  import { omit } from "../../utils/omit";
@@ -16,6 +18,7 @@ export async function automaticSpeechRecognition(
16
18
  args: AutomaticSpeechRecognitionArgs,
17
19
  options?: Options
18
20
  ): Promise<AutomaticSpeechRecognitionOutput> {
21
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
19
22
  const payload = await buildPayload(args);
20
23
  const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, {
21
24
  ...options,
@@ -25,11 +28,9 @@ export async function automaticSpeechRecognition(
25
28
  if (!isValidOutput) {
26
29
  throw new InferenceOutputError("Expected {text: string}");
27
30
  }
28
- return res;
31
+ return providerHelper.getResponse(res);
29
32
  }
30
33
 
31
- const FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
32
-
33
34
  async function buildPayload(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs> {
34
35
  if (args.provider === "fal-ai") {
35
36
  const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : undefined;
@@ -1,7 +1,6 @@
1
1
  import type { TextToSpeechInput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
3
  import type { BaseArgs, Options } from "../../types";
4
- import { omit } from "../../utils/omit";
5
4
  import { innerRequest } from "../../utils/request";
6
5
  type TextToSpeechArgs = BaseArgs & TextToSpeechInput;
7
6
 
@@ -13,34 +12,11 @@ interface OutputUrlTextToSpeechGeneration {
13
12
  * Recommended model: espnet/kan-bayashi_ljspeech_vits
14
13
  */
15
14
  export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
16
- // Replicate models expects "text" instead of "inputs"
17
- const payload =
18
- args.provider === "replicate"
19
- ? {
20
- ...omit(args, ["inputs", "parameters"]),
21
- ...args.parameters,
22
- text: args.inputs,
23
- }
24
- : args;
25
- const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(payload, {
15
+ const provider = args.provider ?? "hf-inference";
16
+ const providerHelper = getProviderHelper(provider, "text-to-speech");
17
+ const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, {
26
18
  ...options,
27
19
  task: "text-to-speech",
28
20
  });
29
- if (res instanceof Blob) {
30
- return res;
31
- }
32
- if (res && typeof res === "object") {
33
- if ("output" in res) {
34
- if (typeof res.output === "string") {
35
- const urlResponse = await fetch(res.output);
36
- const blob = await urlResponse.blob();
37
- return blob;
38
- } else if (Array.isArray(res.output)) {
39
- const urlResponse = await fetch(res.output[0]);
40
- const blob = await urlResponse.blob();
41
- return blob;
42
- }
43
- }
44
- }
45
- throw new InferenceOutputError("Expected Blob or object with output");
21
+ return providerHelper.getResponse(res);
46
22
  }
@@ -1,4 +1,4 @@
1
- import type { BaseArgs, RequestArgs } from "../../types";
1
+ import type { BaseArgs, InferenceProvider, RequestArgs } from "../../types";
2
2
  import { omit } from "../../utils/omit";
3
3
 
4
4
  /**
@@ -6,6 +6,7 @@ import { omit } from "../../utils/omit";
6
6
  */
7
7
  export interface LegacyAudioInput {
8
8
  data: Blob | ArrayBuffer;
9
+ provider?: InferenceProvider;
9
10
  }
10
11
 
11
12
  export function preparePayload(args: BaseArgs & ({ inputs: Blob } | LegacyAudioInput)): RequestArgs {
@@ -10,8 +10,6 @@ export async function request<T>(
10
10
  options?: Options & {
11
11
  /** In most cases (unless we pass a endpointUrl) we know the task */
12
12
  task?: InferenceTask;
13
- /** Is chat completion compatible */
14
- chatCompletion?: boolean;
15
13
  }
16
14
  ): Promise<T> {
17
15
  console.warn(
@@ -9,8 +9,6 @@ export async function* streamingRequest<T>(
9
9
  options?: Options & {
10
10
  /** In most cases (unless we pass a endpointUrl) we know the task */
11
11
  task?: InferenceTask;
12
- /** Is chat completion compatible */
13
- chatCompletion?: boolean;
14
12
  }
15
13
  ): AsyncGenerator<T> {
16
14
  console.warn(
@@ -1,5 +1,5 @@
1
1
  import type { ImageClassificationInput, ImageClassificationOutput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
3
  import type { BaseArgs, Options } from "../../types";
4
4
  import { innerRequest } from "../../utils/request";
5
5
  import { preparePayload, type LegacyImageInput } from "./utils";
@@ -14,15 +14,11 @@ export async function imageClassification(
14
14
  args: ImageClassificationArgs,
15
15
  options?: Options
16
16
  ): Promise<ImageClassificationOutput> {
17
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
17
18
  const payload = preparePayload(args);
18
19
  const { data: res } = await innerRequest<ImageClassificationOutput>(payload, {
19
20
  ...options,
20
21
  task: "image-classification",
21
22
  });
22
- const isValidOutput =
23
- Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
24
- if (!isValidOutput) {
25
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
26
- }
27
- return res;
23
+ return providerHelper.getResponse(res);
28
24
  }
@@ -1,5 +1,5 @@
1
1
  import type { ImageSegmentationInput, ImageSegmentationOutput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
3
  import type { BaseArgs, Options } from "../../types";
4
4
  import { innerRequest } from "../../utils/request";
5
5
  import { preparePayload, type LegacyImageInput } from "./utils";
@@ -14,16 +14,11 @@ export async function imageSegmentation(
14
14
  args: ImageSegmentationArgs,
15
15
  options?: Options
16
16
  ): Promise<ImageSegmentationOutput> {
17
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
17
18
  const payload = preparePayload(args);
18
19
  const { data: res } = await innerRequest<ImageSegmentationOutput>(payload, {
19
20
  ...options,
20
21
  task: "image-segmentation",
21
22
  });
22
- const isValidOutput =
23
- Array.isArray(res) &&
24
- res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
25
- if (!isValidOutput) {
26
- throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
27
- }
28
- return res;
23
+ return providerHelper.getResponse(res);
29
24
  }
@@ -1,5 +1,5 @@
1
1
  import type { ImageToImageInput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
3
  import type { BaseArgs, Options, RequestArgs } from "../../types";
4
4
  import { base64FromBytes } from "../../utils/base64FromBytes";
5
5
  import { innerRequest } from "../../utils/request";
@@ -11,6 +11,7 @@ export type ImageToImageArgs = BaseArgs & ImageToImageInput;
11
11
  * Recommended model: lllyasviel/sd-controlnet-depth
12
12
  */
13
13
  export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
14
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
14
15
  let reqArgs: RequestArgs;
15
16
  if (!args.parameters) {
16
17
  reqArgs = {
@@ -30,9 +31,5 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P
30
31
  ...options,
31
32
  task: "image-to-image",
32
33
  });
33
- const isValidOutput = res && res instanceof Blob;
34
- if (!isValidOutput) {
35
- throw new InferenceOutputError("Expected Blob");
36
- }
37
- return res;
34
+ return providerHelper.getResponse(res);
38
35
  }
@@ -1,5 +1,5 @@
1
1
  import type { ImageToTextInput, ImageToTextOutput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
3
  import type { BaseArgs, Options } from "../../types";
4
4
  import { innerRequest } from "../../utils/request";
5
5
  import type { LegacyImageInput } from "./utils";
@@ -10,15 +10,12 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput);
10
10
  * This task reads some image input and outputs the text caption.
11
11
  */
12
12
  export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
13
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
13
14
  const payload = preparePayload(args);
14
15
  const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, {
15
16
  ...options,
16
17
  task: "image-to-text",
17
18
  });
18
19
 
19
- if (typeof res?.[0]?.generated_text !== "string") {
20
- throw new InferenceOutputError("Expected {generated_text: string}");
21
- }
22
-
23
- return res?.[0];
20
+ return providerHelper.getResponse(res[0]);
24
21
  }
@@ -1,5 +1,5 @@
1
1
  import type { ObjectDetectionInput, ObjectDetectionOutput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
3
  import type { BaseArgs, Options } from "../../types";
4
4
  import { innerRequest } from "../../utils/request";
5
5
  import { preparePayload, type LegacyImageInput } from "./utils";
@@ -11,26 +11,11 @@ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImage
11
11
  * Recommended model: facebook/detr-resnet-50
12
12
  */
13
13
  export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> {
14
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
14
15
  const payload = preparePayload(args);
15
16
  const { data: res } = await innerRequest<ObjectDetectionOutput>(payload, {
16
17
  ...options,
17
18
  task: "object-detection",
18
19
  });
19
- const isValidOutput =
20
- Array.isArray(res) &&
21
- res.every(
22
- (x) =>
23
- typeof x.label === "string" &&
24
- typeof x.score === "number" &&
25
- typeof x.box.xmin === "number" &&
26
- typeof x.box.ymin === "number" &&
27
- typeof x.box.xmax === "number" &&
28
- typeof x.box.ymax === "number"
29
- );
30
- if (!isValidOutput) {
31
- throw new InferenceOutputError(
32
- "Expected Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>"
33
- );
34
- }
35
- return res;
20
+ return providerHelper.getResponse(res);
36
21
  }
@@ -1,48 +1,15 @@
1
- import type { TextToImageInput, TextToImageOutput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
3
- import type { BaseArgs, InferenceProvider, Options } from "../../types";
4
- import { delay } from "../../utils/delay";
5
- import { omit } from "../../utils/omit";
1
+ import type { TextToImageInput } from "@huggingface/tasks";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
+ import { makeRequestOptions } from "../../lib/makeRequestOptions";
4
+ import type { BaseArgs, Options } from "../../types";
6
5
  import { innerRequest } from "../../utils/request";
7
6
 
8
7
  export type TextToImageArgs = BaseArgs & TextToImageInput;
9
8
 
10
- interface Base64ImageGeneration {
11
- data: Array<{
12
- b64_json: string;
13
- }>;
14
- }
15
- interface OutputUrlImageGeneration {
16
- output: string[];
17
- }
18
- interface HyperbolicTextToImageOutput {
19
- images: Array<{ image: string }>;
20
- }
21
-
22
- interface BlackForestLabsResponse {
23
- id: string;
24
- polling_url: string;
25
- }
26
-
27
9
  interface TextToImageOptions extends Options {
28
10
  outputType?: "url" | "blob";
29
11
  }
30
12
 
31
- function getResponseFormatArg(provider: InferenceProvider) {
32
- switch (provider) {
33
- case "fal-ai":
34
- return { sync_mode: true };
35
- case "nebius":
36
- return { response_format: "b64_json" };
37
- case "replicate":
38
- return undefined;
39
- case "together":
40
- return { response_format: "base64" };
41
- default:
42
- return undefined;
43
- }
44
- }
45
-
46
13
  /**
47
14
  * This task reads some text input and outputs an image.
48
15
  * Recommended model: stabilityai/stable-diffusion-2
@@ -56,108 +23,13 @@ export async function textToImage(
56
23
  options?: TextToImageOptions & { outputType?: undefined | "blob" }
57
24
  ): Promise<Blob>;
58
25
  export async function textToImage(args: TextToImageArgs, options?: TextToImageOptions): Promise<Blob | string> {
59
- const payload =
60
- !args.provider || args.provider === "hf-inference" || args.provider === "sambanova"
61
- ? args
62
- : {
63
- ...omit(args, ["inputs", "parameters"]),
64
- ...args.parameters,
65
- ...getResponseFormatArg(args.provider),
66
- prompt: args.inputs,
67
- };
68
- const { data: res } = await innerRequest<
69
- | TextToImageOutput
70
- | Base64ImageGeneration
71
- | OutputUrlImageGeneration
72
- | BlackForestLabsResponse
73
- | HyperbolicTextToImageOutput
74
- >(payload, {
26
+ const provider = args.provider ?? "hf-inference";
27
+ const providerHelper = getProviderHelper(provider, "text-to-image");
28
+ const { data: res } = await innerRequest<Record<string, unknown>>(args, {
75
29
  ...options,
76
30
  task: "text-to-image",
77
31
  });
78
32
 
79
- if (res && typeof res === "object") {
80
- if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") {
81
- return await pollBflResponse(res.polling_url, options?.outputType);
82
- }
83
- if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
84
- if (options?.outputType === "url") {
85
- return res.images[0].url;
86
- } else {
87
- const image = await fetch(res.images[0].url);
88
- return await image.blob();
89
- }
90
- }
91
- if (
92
- args.provider === "hyperbolic" &&
93
- "images" in res &&
94
- Array.isArray(res.images) &&
95
- res.images[0] &&
96
- typeof res.images[0].image === "string"
97
- ) {
98
- if (options?.outputType === "url") {
99
- return `data:image/jpeg;base64,${res.images[0].image}`;
100
- }
101
- const base64Response = await fetch(`data:image/jpeg;base64,${res.images[0].image}`);
102
- return await base64Response.blob();
103
- }
104
- if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
105
- const base64Data = res.data[0].b64_json;
106
- if (options?.outputType === "url") {
107
- return `data:image/jpeg;base64,${base64Data}`;
108
- }
109
- const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
110
- return await base64Response.blob();
111
- }
112
- if ("output" in res && Array.isArray(res.output)) {
113
- if (options?.outputType === "url") {
114
- return res.output[0];
115
- }
116
- const urlResponse = await fetch(res.output[0]);
117
- const blob = await urlResponse.blob();
118
- return blob;
119
- }
120
- }
121
- const isValidOutput = res && res instanceof Blob;
122
- if (!isValidOutput) {
123
- throw new InferenceOutputError("Expected Blob");
124
- }
125
- if (options?.outputType === "url") {
126
- const b64 = await res.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
127
- return `data:image/jpeg;base64,${b64}`;
128
- }
129
- return res;
130
- }
131
-
132
- async function pollBflResponse(url: string, outputType?: "url" | "blob"): Promise<Blob> {
133
- const urlObj = new URL(url);
134
- for (let step = 0; step < 5; step++) {
135
- await delay(1000);
136
- console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
137
- urlObj.searchParams.set("attempt", step.toString(10));
138
- const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
139
- if (!resp.ok) {
140
- throw new InferenceOutputError("Failed to fetch result from black forest labs API");
141
- }
142
- const payload = await resp.json();
143
- if (
144
- typeof payload === "object" &&
145
- payload &&
146
- "status" in payload &&
147
- typeof payload.status === "string" &&
148
- payload.status === "Ready" &&
149
- "result" in payload &&
150
- typeof payload.result === "object" &&
151
- payload.result &&
152
- "sample" in payload.result &&
153
- typeof payload.result.sample === "string"
154
- ) {
155
- if (outputType === "url") {
156
- return payload.result.sample;
157
- }
158
- const image = await fetch(payload.result.sample);
159
- return await image.blob();
160
- }
161
- }
162
- throw new InferenceOutputError("Failed to fetch result from black forest labs API");
33
+ const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-image" });
34
+ return providerHelper.getResponse(res, url, info.headers as Record<string, string>, options?.outputType);
163
35
  }
@@ -1,74 +1,23 @@
1
1
  import type { TextToVideoInput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
3
- import { isUrl } from "../../lib/isUrl";
4
- import { pollFalResponse, type FalAiQueueOutput } from "../../providers/fal-ai";
5
- import type { BaseArgs, InferenceProvider, Options } from "../../types";
6
- import { omit } from "../../utils/omit";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
+ import { makeRequestOptions } from "../../lib/makeRequestOptions";
4
+ import type { FalAiQueueOutput } from "../../providers/fal-ai";
5
+ import type { NovitaOutput } from "../../providers/novita";
6
+ import type { ReplicateOutput } from "../../providers/replicate";
7
+ import type { BaseArgs, Options } from "../../types";
7
8
  import { innerRequest } from "../../utils/request";
8
- import { typedInclude } from "../../utils/typedInclude";
9
9
 
10
10
  export type TextToVideoArgs = BaseArgs & TextToVideoInput;
11
11
 
12
12
  export type TextToVideoOutput = Blob;
13
13
 
14
- interface ReplicateOutput {
15
- output: string;
16
- }
17
-
18
- interface NovitaOutput {
19
- video: {
20
- video_url: string;
21
- };
22
- }
23
-
24
- const SUPPORTED_PROVIDERS = ["fal-ai", "novita", "replicate"] as const satisfies readonly InferenceProvider[];
25
-
26
14
  export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise<TextToVideoOutput> {
27
- if (!args.provider || !typedInclude(SUPPORTED_PROVIDERS, args.provider)) {
28
- throw new Error(
29
- `textToVideo inference is only supported for the following providers: ${SUPPORTED_PROVIDERS.join(", ")}`
30
- );
31
- }
32
-
33
- const payload =
34
- args.provider === "fal-ai" || args.provider === "replicate" || args.provider === "novita"
35
- ? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs }
36
- : args;
37
- const { data, requestContext } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(payload, {
15
+ const provider = args.provider ?? "hf-inference";
16
+ const providerHelper = getProviderHelper(provider, "text-to-video");
17
+ const { data: response } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(args, {
38
18
  ...options,
39
19
  task: "text-to-video",
40
20
  });
41
-
42
- if (args.provider === "fal-ai") {
43
- return await pollFalResponse(
44
- data as FalAiQueueOutput,
45
- requestContext.url,
46
- requestContext.info.headers as Record<string, string>
47
- );
48
- } else if (args.provider === "novita") {
49
- const isValidOutput =
50
- typeof data === "object" &&
51
- !!data &&
52
- "video" in data &&
53
- typeof data.video === "object" &&
54
- !!data.video &&
55
- "video_url" in data.video &&
56
- typeof data.video.video_url === "string" &&
57
- isUrl(data.video.video_url);
58
- if (!isValidOutput) {
59
- throw new InferenceOutputError("Expected { video: { video_url: string } }");
60
- }
61
- const urlResponse = await fetch((data as NovitaOutput).video.video_url);
62
- return await urlResponse.blob();
63
- } else {
64
- /// TODO: Replicate: handle the case where the generation request "times out" / is async (ie output is null)
65
- /// https://replicate.com/docs/topics/predictions/create-a-prediction
66
- const isValidOutput =
67
- typeof data === "object" && !!data && "output" in data && typeof data.output === "string" && isUrl(data.output);
68
- if (!isValidOutput) {
69
- throw new InferenceOutputError("Expected { output: string }");
70
- }
71
- const urlResponse = await fetch(data.output);
72
- return await urlResponse.blob();
73
- }
21
+ const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-video" });
22
+ return providerHelper.getResponse(response, url, info.headers as Record<string, string>);
74
23
  }