@huggingface/inference 3.6.2 → 3.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140):
  1. package/README.md +0 -25
  2. package/dist/index.cjs +1232 -898
  3. package/dist/index.js +1234 -900
  4. package/dist/src/config.d.ts +1 -0
  5. package/dist/src/config.d.ts.map +1 -1
  6. package/dist/src/lib/getProviderHelper.d.ts +37 -0
  7. package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
  8. package/dist/src/lib/makeRequestOptions.d.ts +0 -2
  9. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  10. package/dist/src/providers/black-forest-labs.d.ts +14 -18
  11. package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
  12. package/dist/src/providers/cerebras.d.ts +4 -2
  13. package/dist/src/providers/cerebras.d.ts.map +1 -1
  14. package/dist/src/providers/cohere.d.ts +5 -2
  15. package/dist/src/providers/cohere.d.ts.map +1 -1
  16. package/dist/src/providers/fal-ai.d.ts +50 -3
  17. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  18. package/dist/src/providers/fireworks-ai.d.ts +5 -2
  19. package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
  20. package/dist/src/providers/hf-inference.d.ts +125 -2
  21. package/dist/src/providers/hf-inference.d.ts.map +1 -1
  22. package/dist/src/providers/hyperbolic.d.ts +31 -2
  23. package/dist/src/providers/hyperbolic.d.ts.map +1 -1
  24. package/dist/src/providers/nebius.d.ts +20 -18
  25. package/dist/src/providers/nebius.d.ts.map +1 -1
  26. package/dist/src/providers/novita.d.ts +21 -18
  27. package/dist/src/providers/novita.d.ts.map +1 -1
  28. package/dist/src/providers/openai.d.ts +4 -2
  29. package/dist/src/providers/openai.d.ts.map +1 -1
  30. package/dist/src/providers/providerHelper.d.ts +182 -0
  31. package/dist/src/providers/providerHelper.d.ts.map +1 -0
  32. package/dist/src/providers/replicate.d.ts +23 -19
  33. package/dist/src/providers/replicate.d.ts.map +1 -1
  34. package/dist/src/providers/sambanova.d.ts +4 -2
  35. package/dist/src/providers/sambanova.d.ts.map +1 -1
  36. package/dist/src/providers/together.d.ts +32 -2
  37. package/dist/src/providers/together.d.ts.map +1 -1
  38. package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
  39. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
  40. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  41. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  42. package/dist/src/tasks/audio/utils.d.ts +2 -1
  43. package/dist/src/tasks/audio/utils.d.ts.map +1 -1
  44. package/dist/src/tasks/custom/request.d.ts +1 -2
  45. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  46. package/dist/src/tasks/custom/streamingRequest.d.ts +1 -2
  47. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  48. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
  49. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
  50. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
  51. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
  52. package/dist/src/tasks/cv/objectDetection.d.ts +1 -1
  53. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
  54. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  55. package/dist/src/tasks/cv/textToVideo.d.ts +1 -1
  56. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
  57. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts +1 -1
  58. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
  59. package/dist/src/tasks/index.d.ts +6 -6
  60. package/dist/src/tasks/index.d.ts.map +1 -1
  61. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts +1 -1
  62. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  63. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
  64. package/dist/src/tasks/nlp/chatCompletion.d.ts +1 -1
  65. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  66. package/dist/src/tasks/nlp/chatCompletionStream.d.ts +1 -1
  67. package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
  68. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  69. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
  70. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  71. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  72. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
  73. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  74. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
  75. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  76. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  77. package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
  78. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  79. package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
  80. package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
  81. package/dist/src/types.d.ts +10 -13
  82. package/dist/src/types.d.ts.map +1 -1
  83. package/dist/src/utils/request.d.ts +27 -0
  84. package/dist/src/utils/request.d.ts.map +1 -0
  85. package/package.json +3 -3
  86. package/src/config.ts +1 -0
  87. package/src/lib/getProviderHelper.ts +270 -0
  88. package/src/lib/makeRequestOptions.ts +36 -90
  89. package/src/providers/black-forest-labs.ts +73 -22
  90. package/src/providers/cerebras.ts +6 -27
  91. package/src/providers/cohere.ts +9 -28
  92. package/src/providers/fal-ai.ts +195 -77
  93. package/src/providers/fireworks-ai.ts +8 -29
  94. package/src/providers/hf-inference.ts +555 -34
  95. package/src/providers/hyperbolic.ts +107 -29
  96. package/src/providers/nebius.ts +65 -29
  97. package/src/providers/novita.ts +68 -32
  98. package/src/providers/openai.ts +6 -32
  99. package/src/providers/providerHelper.ts +354 -0
  100. package/src/providers/replicate.ts +124 -34
  101. package/src/providers/sambanova.ts +5 -30
  102. package/src/providers/together.ts +92 -28
  103. package/src/snippets/getInferenceSnippets.ts +16 -9
  104. package/src/snippets/templates.exported.ts +2 -2
  105. package/src/tasks/audio/audioClassification.ts +6 -9
  106. package/src/tasks/audio/audioToAudio.ts +5 -28
  107. package/src/tasks/audio/automaticSpeechRecognition.ts +7 -6
  108. package/src/tasks/audio/textToSpeech.ts +6 -30
  109. package/src/tasks/audio/utils.ts +2 -1
  110. package/src/tasks/custom/request.ts +7 -34
  111. package/src/tasks/custom/streamingRequest.ts +5 -87
  112. package/src/tasks/cv/imageClassification.ts +5 -9
  113. package/src/tasks/cv/imageSegmentation.ts +5 -10
  114. package/src/tasks/cv/imageToImage.ts +5 -8
  115. package/src/tasks/cv/imageToText.ts +8 -13
  116. package/src/tasks/cv/objectDetection.ts +6 -21
  117. package/src/tasks/cv/textToImage.ts +10 -138
  118. package/src/tasks/cv/textToVideo.ts +11 -59
  119. package/src/tasks/cv/zeroShotImageClassification.ts +7 -12
  120. package/src/tasks/index.ts +6 -6
  121. package/src/tasks/multimodal/documentQuestionAnswering.ts +10 -26
  122. package/src/tasks/multimodal/visualQuestionAnswering.ts +6 -12
  123. package/src/tasks/nlp/chatCompletion.ts +7 -23
  124. package/src/tasks/nlp/chatCompletionStream.ts +4 -5
  125. package/src/tasks/nlp/featureExtraction.ts +5 -20
  126. package/src/tasks/nlp/fillMask.ts +5 -18
  127. package/src/tasks/nlp/questionAnswering.ts +5 -23
  128. package/src/tasks/nlp/sentenceSimilarity.ts +5 -18
  129. package/src/tasks/nlp/summarization.ts +5 -8
  130. package/src/tasks/nlp/tableQuestionAnswering.ts +5 -29
  131. package/src/tasks/nlp/textClassification.ts +8 -14
  132. package/src/tasks/nlp/textGeneration.ts +13 -80
  133. package/src/tasks/nlp/textGenerationStream.ts +2 -2
  134. package/src/tasks/nlp/tokenClassification.ts +8 -24
  135. package/src/tasks/nlp/translation.ts +5 -8
  136. package/src/tasks/nlp/zeroShotClassification.ts +8 -22
  137. package/src/tasks/tabular/tabularClassification.ts +5 -8
  138. package/src/tasks/tabular/tabularRegression.ts +5 -8
  139. package/src/types.ts +11 -14
  140. package/src/utils/request.ts +161 -0
@@ -1,15 +1,15 @@
1
- import type { PipelineType, WidgetType } from "@huggingface/tasks/src/pipelines.js";
2
- import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingface/tasks/src/tasks/index.js";
1
+ import { Template } from "@huggingface/jinja";
3
2
  import {
4
3
  type InferenceSnippet,
5
4
  type InferenceSnippetLanguage,
6
5
  type ModelDataMinimal,
7
- inferenceSnippetLanguages,
8
6
  getModelInputSnippet,
7
+ inferenceSnippetLanguages,
9
8
  } from "@huggingface/tasks";
10
- import type { InferenceProvider, InferenceTask, RequestArgs } from "../types";
11
- import { Template } from "@huggingface/jinja";
9
+ import type { PipelineType, WidgetType } from "@huggingface/tasks/src/pipelines.js";
10
+ import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingface/tasks/src/tasks/index.js";
12
11
  import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions";
12
+ import type { InferenceProvider, InferenceTask, RequestArgs } from "../types";
13
13
  import { templates } from "./templates.exported";
14
14
 
15
15
  const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const;
@@ -120,6 +120,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
120
120
  opts?: Record<string, unknown>
121
121
  ): InferenceSnippet[] => {
122
122
  /// Hacky: hard-code conversational templates here
123
+ let task = model.pipeline_tag as InferenceTask;
123
124
  if (
124
125
  model.pipeline_tag &&
125
126
  ["text-generation", "image-text-to-text"].includes(model.pipeline_tag) &&
@@ -127,14 +128,20 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
127
128
  ) {
128
129
  templateName = opts?.streaming ? "conversationalStream" : "conversational";
129
130
  inputPreparationFn = prepareConversationalInput;
131
+ task = "conversational";
130
132
  }
131
-
132
133
  /// Prepare inputs + make request
133
134
  const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) };
134
135
  const request = makeRequestOptionsFromResolvedModel(
135
136
  providerModelId ?? model.id,
136
- { accessToken: accessToken, provider: provider, ...inputs } as RequestArgs,
137
- { chatCompletion: templateName.includes("conversational"), task: model.pipeline_tag as InferenceTask }
137
+ {
138
+ accessToken: accessToken,
139
+ provider: provider,
140
+ ...inputs,
141
+ } as RequestArgs,
142
+ {
143
+ task: task,
144
+ }
138
145
  );
139
146
 
140
147
  /// Parse request.info.body if not a binary.
@@ -247,7 +254,7 @@ const prepareConversationalInput = (
247
254
  return {
248
255
  messages: opts?.messages ?? getModelInputSnippet(model),
249
256
  ...(opts?.temperature ? { temperature: opts?.temperature } : undefined),
250
- max_tokens: opts?.max_tokens ?? 500,
257
+ max_tokens: opts?.max_tokens ?? 512,
251
258
  ...(opts?.top_p ? { top_p: opts?.top_p } : undefined),
252
259
  };
253
260
  };
@@ -6,7 +6,7 @@ export const templates: Record<string, Record<string, Record<string, string>>> =
6
6
  "basicAudio": "async function query(data) {\n\tconst response = await fetch(\n\t\t\"{{ fullUrl }}\",\n\t\t{\n\t\t\theaders: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n\t\t\t\t\"Content-Type\": \"audio/flac\"\n\t\t\t},\n\t\t\tmethod: \"POST\",\n\t\t\tbody: JSON.stringify(data),\n\t\t}\n\t);\n\tconst result = await response.json();\n\treturn result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});",
7
7
  "basicImage": "async function query(data) {\n\tconst response = await fetch(\n\t\t\"{{ fullUrl }}\",\n\t\t{\n\t\t\theaders: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n\t\t\t\t\"Content-Type\": \"image/jpeg\"\n\t\t\t},\n\t\t\tmethod: \"POST\",\n\t\t\tbody: JSON.stringify(data),\n\t\t}\n\t);\n\tconst result = await response.json();\n\treturn result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});",
8
8
  "textToAudio": "{% if model.library_name == \"transformers\" %}\nasync function query(data) {\n\tconst response = await fetch(\n\t\t\"{{ fullUrl }}\",\n\t\t{\n\t\t\theaders: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n\t\t\t\t\"Content-Type\": \"application/json\",\n\t\t\t},\n\t\t\tmethod: \"POST\",\n\t\t\tbody: JSON.stringify(data),\n\t\t}\n\t);\n\tconst result = await response.blob();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n // Returns a byte object of the Audio wavform. Use it directly!\n});\n{% else %}\nasync function query(data) {\n\tconst response = await fetch(\n\t\t\"{{ fullUrl }}\",\n\t\t{\n\t\t\theaders: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n\t\t\t\t\"Content-Type\": \"application/json\",\n\t\t\t},\n\t\t\tmethod: \"POST\",\n\t\t\tbody: JSON.stringify(data),\n\t\t}\n\t);\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});\n{% endif %} ",
9
- "textToImage": "async function query(data) {\n\tconst response = await fetch(\n\t\t\"{{ fullUrl }}\",\n\t\t{\n\t\t\theaders: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n\t\t\t\t\"Content-Type\": \"application/json\",\n\t\t\t},\n\t\t\tmethod: \"POST\",\n\t\t\tbody: JSON.stringify(data),\n\t\t}\n\t);\n\tconst result = await response.blob();\n\treturn result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n // Use image\n});",
9
+ "textToImage": "async function query(data) {\n\tconst response = await fetch(\n\t\t\"{{ fullUrl }}\",\n\t\t{\n\t\t\theaders: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n\t\t\t\t\"Content-Type\": \"application/json\",\n\t\t\t},\n\t\t\tmethod: \"POST\",\n\t\t\tbody: JSON.stringify(data),\n\t\t}\n\t);\n\tconst result = await response.blob();\n\treturn result;\n}\n\n\nquery({ {{ providerInputs.asTsString }} }).then((response) => {\n // Use image\n});",
10
10
  "zeroShotClassification": "async function query(data) {\n const response = await fetch(\n\t\t\"{{ fullUrl }}\",\n {\n headers: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n \"Content-Type\": \"application/json\",\n },\n method: \"POST\",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({\n inputs: {{ providerInputs.asObj.inputs }},\n parameters: { candidate_labels: [\"refund\", \"legal\", \"faq\"] }\n}).then((response) => {\n console.log(JSON.stringify(response));\n});"
11
11
  },
12
12
  "huggingface.js": {
@@ -20,7 +20,7 @@ export const templates: Record<string, Record<string, Record<string, string>>> =
20
20
  },
21
21
  "openai": {
22
22
  "conversational": "import { OpenAI } from \"openai\";\n\nconst client = new OpenAI({\n\tbaseURL: \"{{ baseUrl }}\",\n\tapiKey: \"{{ accessToken }}\",\n});\n\nconst chatCompletion = await client.chat.completions.create({\n\tmodel: \"{{ providerModelId }}\",\n{{ inputs.asTsString }}\n});\n\nconsole.log(chatCompletion.choices[0].message);",
23
- "conversationalStream": "import { OpenAI } from \"openai\";\n\nconst client = new OpenAI({\n\tbaseURL: \"{{ baseUrl }}\",\n\tapiKey: \"{{ accessToken }}\",\n});\n\nlet out = \"\";\n\nconst stream = await client.chat.completions.create({\n provider: \"{{ provider }}\",\n model: \"{{ model.id }}\",\n{{ inputs.asTsString }}\n});\n\nfor await (const chunk of stream) {\n\tif (chunk.choices && chunk.choices.length > 0) {\n\t\tconst newContent = chunk.choices[0].delta.content;\n\t\tout += newContent;\n\t\tconsole.log(newContent);\n\t} \n}"
23
+ "conversationalStream": "import { OpenAI } from \"openai\";\n\nconst client = new OpenAI({\n\tbaseURL: \"{{ baseUrl }}\",\n\tapiKey: \"{{ accessToken }}\",\n});\n\nconst stream = await client.chat.completions.create({\n model: \"{{ providerModelId }}\",\n{{ inputs.asTsString }}\n stream: true,\n});\n\nfor await (const chunk of stream) {\n process.stdout.write(chunk.choices[0]?.delta?.content || \"\");\n}"
24
24
  }
25
25
  },
26
26
  "python": {
@@ -1,7 +1,7 @@
1
1
  import type { AudioClassificationInput, AudioClassificationOutput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
3
  import type { BaseArgs, Options } from "../../types";
4
- import { request } from "../custom/request";
4
+ import { innerRequest } from "../../utils/request";
5
5
  import type { LegacyAudioInput } from "./utils";
6
6
  import { preparePayload } from "./utils";
7
7
 
@@ -15,15 +15,12 @@ export async function audioClassification(
15
15
  args: AudioClassificationArgs,
16
16
  options?: Options
17
17
  ): Promise<AudioClassificationOutput> {
18
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
18
19
  const payload = preparePayload(args);
19
- const res = await request<AudioClassificationOutput>(payload, {
20
+ const { data: res } = await innerRequest<AudioClassificationOutput>(payload, {
20
21
  ...options,
21
22
  task: "audio-classification",
22
23
  });
23
- const isValidOutput =
24
- Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
25
- if (!isValidOutput) {
26
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
27
- }
28
- return res;
24
+
25
+ return providerHelper.getResponse(res);
29
26
  }
@@ -1,6 +1,6 @@
1
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
1
+ import { getProviderHelper } from "../../lib/getProviderHelper";
2
2
  import type { BaseArgs, Options } from "../../types";
3
- import { request } from "../custom/request";
3
+ import { innerRequest } from "../../utils/request";
4
4
  import type { LegacyAudioInput } from "./utils";
5
5
  import { preparePayload } from "./utils";
6
6
 
@@ -36,34 +36,11 @@ export interface AudioToAudioOutput {
36
36
  * Example model: speechbrain/sepformer-wham does audio source separation.
37
37
  */
38
38
  export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
39
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
39
40
  const payload = preparePayload(args);
40
- const res = await request<AudioToAudioOutput>(payload, {
41
+ const { data: res } = await innerRequest<AudioToAudioOutput>(payload, {
41
42
  ...options,
42
43
  task: "audio-to-audio",
43
44
  });
44
-
45
- return validateOutput(res);
46
- }
47
-
48
- function validateOutput(output: unknown): AudioToAudioOutput[] {
49
- if (!Array.isArray(output)) {
50
- throw new InferenceOutputError("Expected Array");
51
- }
52
- if (
53
- !output.every((elem): elem is AudioToAudioOutput => {
54
- return (
55
- typeof elem === "object" &&
56
- elem &&
57
- "label" in elem &&
58
- typeof elem.label === "string" &&
59
- "content-type" in elem &&
60
- typeof elem["content-type"] === "string" &&
61
- "blob" in elem &&
62
- typeof elem.blob === "string"
63
- );
64
- })
65
- ) {
66
- throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
67
- }
68
- return output;
45
+ return providerHelper.getResponse(res);
69
46
  }
@@ -1,11 +1,13 @@
1
1
  import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
2
3
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
4
+ import { FAL_AI_SUPPORTED_BLOB_TYPES } from "../../providers/fal-ai";
3
5
  import type { BaseArgs, Options, RequestArgs } from "../../types";
4
6
  import { base64FromBytes } from "../../utils/base64FromBytes";
5
- import { request } from "../custom/request";
7
+ import { omit } from "../../utils/omit";
8
+ import { innerRequest } from "../../utils/request";
6
9
  import type { LegacyAudioInput } from "./utils";
7
10
  import { preparePayload } from "./utils";
8
- import { omit } from "../../utils/omit";
9
11
 
10
12
  export type AutomaticSpeechRecognitionArgs = BaseArgs & (AutomaticSpeechRecognitionInput | LegacyAudioInput);
11
13
  /**
@@ -16,8 +18,9 @@ export async function automaticSpeechRecognition(
16
18
  args: AutomaticSpeechRecognitionArgs,
17
19
  options?: Options
18
20
  ): Promise<AutomaticSpeechRecognitionOutput> {
21
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
19
22
  const payload = await buildPayload(args);
20
- const res = await request<AutomaticSpeechRecognitionOutput>(payload, {
23
+ const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, {
21
24
  ...options,
22
25
  task: "automatic-speech-recognition",
23
26
  });
@@ -25,11 +28,9 @@ export async function automaticSpeechRecognition(
25
28
  if (!isValidOutput) {
26
29
  throw new InferenceOutputError("Expected {text: string}");
27
30
  }
28
- return res;
31
+ return providerHelper.getResponse(res);
29
32
  }
30
33
 
31
- const FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
32
-
33
34
  async function buildPayload(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs> {
34
35
  if (args.provider === "fal-ai") {
35
36
  const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : undefined;
@@ -1,8 +1,7 @@
1
1
  import type { TextToSpeechInput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
3
  import type { BaseArgs, Options } from "../../types";
4
- import { omit } from "../../utils/omit";
5
- import { request } from "../custom/request";
4
+ import { innerRequest } from "../../utils/request";
6
5
  type TextToSpeechArgs = BaseArgs & TextToSpeechInput;
7
6
 
8
7
  interface OutputUrlTextToSpeechGeneration {
@@ -13,34 +12,11 @@ interface OutputUrlTextToSpeechGeneration {
13
12
  * Recommended model: espnet/kan-bayashi_ljspeech_vits
14
13
  */
15
14
  export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
16
- // Replicate models expects "text" instead of "inputs"
17
- const payload =
18
- args.provider === "replicate"
19
- ? {
20
- ...omit(args, ["inputs", "parameters"]),
21
- ...args.parameters,
22
- text: args.inputs,
23
- }
24
- : args;
25
- const res = await request<Blob | OutputUrlTextToSpeechGeneration>(payload, {
15
+ const provider = args.provider ?? "hf-inference";
16
+ const providerHelper = getProviderHelper(provider, "text-to-speech");
17
+ const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, {
26
18
  ...options,
27
19
  task: "text-to-speech",
28
20
  });
29
- if (res instanceof Blob) {
30
- return res;
31
- }
32
- if (res && typeof res === "object") {
33
- if ("output" in res) {
34
- if (typeof res.output === "string") {
35
- const urlResponse = await fetch(res.output);
36
- const blob = await urlResponse.blob();
37
- return blob;
38
- } else if (Array.isArray(res.output)) {
39
- const urlResponse = await fetch(res.output[0]);
40
- const blob = await urlResponse.blob();
41
- return blob;
42
- }
43
- }
44
- }
45
- throw new InferenceOutputError("Expected Blob or object with output");
21
+ return providerHelper.getResponse(res);
46
22
  }
@@ -1,4 +1,4 @@
1
- import type { BaseArgs, RequestArgs } from "../../types";
1
+ import type { BaseArgs, InferenceProvider, RequestArgs } from "../../types";
2
2
  import { omit } from "../../utils/omit";
3
3
 
4
4
  /**
@@ -6,6 +6,7 @@ import { omit } from "../../utils/omit";
6
6
  */
7
7
  export interface LegacyAudioInput {
8
8
  data: Blob | ArrayBuffer;
9
+ provider?: InferenceProvider;
9
10
  }
10
11
 
11
12
  export function preparePayload(args: BaseArgs & ({ inputs: Blob } | LegacyAudioInput)): RequestArgs {
@@ -1,47 +1,20 @@
1
1
  import type { InferenceTask, Options, RequestArgs } from "../../types";
2
- import { makeRequestOptions } from "../../lib/makeRequestOptions";
2
+ import { innerRequest } from "../../utils/request";
3
3
 
4
4
  /**
5
5
  * Primitive to make custom calls to the inference provider
6
+ * @deprecated Use specific task functions instead. This function will be removed in a future version.
6
7
  */
7
8
  export async function request<T>(
8
9
  args: RequestArgs,
9
10
  options?: Options & {
10
11
  /** In most cases (unless we pass a endpointUrl) we know the task */
11
12
  task?: InferenceTask;
12
- /** Is chat completion compatible */
13
- chatCompletion?: boolean;
14
13
  }
15
14
  ): Promise<T> {
16
- const { url, info } = await makeRequestOptions(args, options);
17
- const response = await (options?.fetch ?? fetch)(url, info);
18
-
19
- if (options?.retry_on_error !== false && response.status === 503) {
20
- return request(args, options);
21
- }
22
-
23
- if (!response.ok) {
24
- const contentType = response.headers.get("Content-Type");
25
- if (["application/json", "application/problem+json"].some((ct) => contentType?.startsWith(ct))) {
26
- const output = await response.json();
27
- if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
28
- throw new Error(
29
- `Server ${args.model} does not seem to support chat completion. Error: ${JSON.stringify(output.error)}`
30
- );
31
- }
32
- if (output.error || output.detail) {
33
- throw new Error(JSON.stringify(output.error ?? output.detail));
34
- } else {
35
- throw new Error(output);
36
- }
37
- }
38
- const message = contentType?.startsWith("text/plain;") ? await response.text() : undefined;
39
- throw new Error(message ?? "An error occurred while fetching the blob");
40
- }
41
-
42
- if (response.headers.get("Content-Type")?.startsWith("application/json")) {
43
- return await response.json();
44
- }
45
-
46
- return (await response.blob()) as T;
15
+ console.warn(
16
+ "The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
17
+ );
18
+ const result = await innerRequest<T>(args, options);
19
+ return result.data;
47
20
  }
@@ -1,100 +1,18 @@
1
1
  import type { InferenceTask, Options, RequestArgs } from "../../types";
2
- import { makeRequestOptions } from "../../lib/makeRequestOptions";
3
- import type { EventSourceMessage } from "../../vendor/fetch-event-source/parse";
4
- import { getLines, getMessages } from "../../vendor/fetch-event-source/parse";
5
-
2
+ import { innerStreamingRequest } from "../../utils/request";
6
3
  /**
7
4
  * Primitive to make custom inference calls that expect server-sent events, and returns the response through a generator
5
+ * @deprecated Use specific task functions instead. This function will be removed in a future version.
8
6
  */
9
7
  export async function* streamingRequest<T>(
10
8
  args: RequestArgs,
11
9
  options?: Options & {
12
10
  /** In most cases (unless we pass a endpointUrl) we know the task */
13
11
  task?: InferenceTask;
14
- /** Is chat completion compatible */
15
- chatCompletion?: boolean;
16
12
  }
17
13
  ): AsyncGenerator<T> {
18
- const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
19
- const response = await (options?.fetch ?? fetch)(url, info);
20
-
21
- if (options?.retry_on_error !== false && response.status === 503) {
22
- return yield* streamingRequest(args, options);
23
- }
24
- if (!response.ok) {
25
- if (response.headers.get("Content-Type")?.startsWith("application/json")) {
26
- const output = await response.json();
27
- if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
28
- throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
29
- }
30
- if (typeof output.error === "string") {
31
- throw new Error(output.error);
32
- }
33
- if (output.error && "message" in output.error && typeof output.error.message === "string") {
34
- /// OpenAI errors
35
- throw new Error(output.error.message);
36
- }
37
- }
38
-
39
- throw new Error(`Server response contains error: ${response.status}`);
40
- }
41
- if (!response.headers.get("content-type")?.startsWith("text/event-stream")) {
42
- throw new Error(
43
- `Server does not support event stream content type, it returned ` + response.headers.get("content-type")
44
- );
45
- }
46
-
47
- if (!response.body) {
48
- return;
49
- }
50
-
51
- const reader = response.body.getReader();
52
- let events: EventSourceMessage[] = [];
53
-
54
- const onEvent = (event: EventSourceMessage) => {
55
- // accumulate events in array
56
- events.push(event);
57
- };
58
-
59
- const onChunk = getLines(
60
- getMessages(
61
- () => {},
62
- () => {},
63
- onEvent
64
- )
14
+ console.warn(
15
+ "The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
65
16
  );
66
-
67
- try {
68
- while (true) {
69
- const { done, value } = await reader.read();
70
- if (done) {
71
- return;
72
- }
73
- onChunk(value);
74
- for (const event of events) {
75
- if (event.data.length > 0) {
76
- if (event.data === "[DONE]") {
77
- return;
78
- }
79
- const data = JSON.parse(event.data);
80
- if (typeof data === "object" && data !== null && "error" in data) {
81
- const errorStr =
82
- typeof data.error === "string"
83
- ? data.error
84
- : typeof data.error === "object" &&
85
- data.error &&
86
- "message" in data.error &&
87
- typeof data.error.message === "string"
88
- ? data.error.message
89
- : JSON.stringify(data.error);
90
- throw new Error(`Error forwarded from backend: ` + errorStr);
91
- }
92
- yield data as T;
93
- }
94
- }
95
- events = [];
96
- }
97
- } finally {
98
- reader.releaseLock();
99
- }
17
+ yield* innerStreamingRequest(args, options);
100
18
  }
@@ -1,7 +1,7 @@
1
1
  import type { ImageClassificationInput, ImageClassificationOutput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
3
  import type { BaseArgs, Options } from "../../types";
4
- import { request } from "../custom/request";
4
+ import { innerRequest } from "../../utils/request";
5
5
  import { preparePayload, type LegacyImageInput } from "./utils";
6
6
 
7
7
  export type ImageClassificationArgs = BaseArgs & (ImageClassificationInput | LegacyImageInput);
@@ -14,15 +14,11 @@ export async function imageClassification(
14
14
  args: ImageClassificationArgs,
15
15
  options?: Options
16
16
  ): Promise<ImageClassificationOutput> {
17
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
17
18
  const payload = preparePayload(args);
18
- const res = await request<ImageClassificationOutput>(payload, {
19
+ const { data: res } = await innerRequest<ImageClassificationOutput>(payload, {
19
20
  ...options,
20
21
  task: "image-classification",
21
22
  });
22
- const isValidOutput =
23
- Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
24
- if (!isValidOutput) {
25
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
26
- }
27
- return res;
23
+ return providerHelper.getResponse(res);
28
24
  }
@@ -1,7 +1,7 @@
1
1
  import type { ImageSegmentationInput, ImageSegmentationOutput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
3
  import type { BaseArgs, Options } from "../../types";
4
- import { request } from "../custom/request";
4
+ import { innerRequest } from "../../utils/request";
5
5
  import { preparePayload, type LegacyImageInput } from "./utils";
6
6
 
7
7
  export type ImageSegmentationArgs = BaseArgs & (ImageSegmentationInput | LegacyImageInput);
@@ -14,16 +14,11 @@ export async function imageSegmentation(
14
14
  args: ImageSegmentationArgs,
15
15
  options?: Options
16
16
  ): Promise<ImageSegmentationOutput> {
17
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
17
18
  const payload = preparePayload(args);
18
- const res = await request<ImageSegmentationOutput>(payload, {
19
+ const { data: res } = await innerRequest<ImageSegmentationOutput>(payload, {
19
20
  ...options,
20
21
  task: "image-segmentation",
21
22
  });
22
- const isValidOutput =
23
- Array.isArray(res) &&
24
- res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
25
- if (!isValidOutput) {
26
- throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
27
- }
28
- return res;
23
+ return providerHelper.getResponse(res);
29
24
  }
@@ -1,8 +1,8 @@
1
1
  import type { ImageToImageInput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
3
  import type { BaseArgs, Options, RequestArgs } from "../../types";
4
4
  import { base64FromBytes } from "../../utils/base64FromBytes";
5
- import { request } from "../custom/request";
5
+ import { innerRequest } from "../../utils/request";
6
6
 
7
7
  export type ImageToImageArgs = BaseArgs & ImageToImageInput;
8
8
 
@@ -11,6 +11,7 @@ export type ImageToImageArgs = BaseArgs & ImageToImageInput;
11
11
  * Recommended model: lllyasviel/sd-controlnet-depth
12
12
  */
13
13
  export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
14
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
14
15
  let reqArgs: RequestArgs;
15
16
  if (!args.parameters) {
16
17
  reqArgs = {
@@ -26,13 +27,9 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P
26
27
  ),
27
28
  };
28
29
  }
29
- const res = await request<Blob>(reqArgs, {
30
+ const { data: res } = await innerRequest<Blob>(reqArgs, {
30
31
  ...options,
31
32
  task: "image-to-image",
32
33
  });
33
- const isValidOutput = res && res instanceof Blob;
34
- if (!isValidOutput) {
35
- throw new InferenceOutputError("Expected Blob");
36
- }
37
- return res;
34
+ return providerHelper.getResponse(res);
38
35
  }
@@ -1,7 +1,7 @@
1
1
  import type { ImageToTextInput, ImageToTextOutput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
3
  import type { BaseArgs, Options } from "../../types";
4
- import { request } from "../custom/request";
4
+ import { innerRequest } from "../../utils/request";
5
5
  import type { LegacyImageInput } from "./utils";
6
6
  import { preparePayload } from "./utils";
7
7
 
@@ -10,17 +10,12 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput);
10
10
  * This task reads some image input and outputs the text caption.
11
11
  */
12
12
  export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
13
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
13
14
  const payload = preparePayload(args);
14
- const res = (
15
- await request<[ImageToTextOutput]>(payload, {
16
- ...options,
17
- task: "image-to-text",
18
- })
19
- )?.[0];
15
+ const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, {
16
+ ...options,
17
+ task: "image-to-text",
18
+ });
20
19
 
21
- if (typeof res?.generated_text !== "string") {
22
- throw new InferenceOutputError("Expected {generated_text: string}");
23
- }
24
-
25
- return res;
20
+ return providerHelper.getResponse(res[0]);
26
21
  }
@@ -1,7 +1,7 @@
1
- import { request } from "../custom/request";
2
- import type { BaseArgs, Options } from "../../types";
3
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
4
1
  import type { ObjectDetectionInput, ObjectDetectionOutput } from "@huggingface/tasks";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
+ import type { BaseArgs, Options } from "../../types";
4
+ import { innerRequest } from "../../utils/request";
5
5
  import { preparePayload, type LegacyImageInput } from "./utils";
6
6
 
7
7
  export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImageInput);
@@ -11,26 +11,11 @@ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImage
11
11
  * Recommended model: facebook/detr-resnet-50
12
12
  */
13
13
  export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> {
14
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
14
15
  const payload = preparePayload(args);
15
- const res = await request<ObjectDetectionOutput>(payload, {
16
+ const { data: res } = await innerRequest<ObjectDetectionOutput>(payload, {
16
17
  ...options,
17
18
  task: "object-detection",
18
19
  });
19
- const isValidOutput =
20
- Array.isArray(res) &&
21
- res.every(
22
- (x) =>
23
- typeof x.label === "string" &&
24
- typeof x.score === "number" &&
25
- typeof x.box.xmin === "number" &&
26
- typeof x.box.ymin === "number" &&
27
- typeof x.box.xmax === "number" &&
28
- typeof x.box.ymax === "number"
29
- );
30
- if (!isValidOutput) {
31
- throw new InferenceOutputError(
32
- "Expected Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>"
33
- );
34
- }
35
- return res;
20
+ return providerHelper.getResponse(res);
36
21
  }