@huggingface/inference 3.0.0 → 3.1.0

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. package/README.md +11 -6
  2. package/dist/index.cjs +193 -76
  3. package/dist/index.js +193 -76
  4. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  5. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  6. package/dist/src/providers/replicate.d.ts.map +1 -1
  7. package/dist/src/providers/together.d.ts.map +1 -1
  8. package/dist/src/tasks/audio/audioClassification.d.ts +4 -18
  9. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
  10. package/dist/src/tasks/audio/audioToAudio.d.ts +10 -9
  11. package/dist/src/tasks/audio/audioToAudio.d.ts.map +1 -1
  12. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts +3 -12
  13. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  14. package/dist/src/tasks/audio/textToSpeech.d.ts +4 -8
  15. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  16. package/dist/src/tasks/audio/utils.d.ts +11 -0
  17. package/dist/src/tasks/audio/utils.d.ts.map +1 -0
  18. package/dist/src/tasks/cv/imageClassification.d.ts +3 -17
  19. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
  20. package/dist/src/tasks/cv/imageSegmentation.d.ts +3 -21
  21. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
  22. package/dist/src/tasks/cv/imageToImage.d.ts +3 -49
  23. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
  24. package/dist/src/tasks/cv/imageToText.d.ts +3 -12
  25. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
  26. package/dist/src/tasks/cv/objectDetection.d.ts +3 -26
  27. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
  28. package/dist/src/tasks/cv/textToImage.d.ts +3 -38
  29. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  30. package/dist/src/tasks/cv/textToVideo.d.ts +6 -0
  31. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -0
  32. package/dist/src/tasks/cv/utils.d.ts +11 -0
  33. package/dist/src/tasks/cv/utils.d.ts.map +1 -0
  34. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts +7 -15
  35. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
  36. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts +5 -28
  37. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  38. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts +5 -20
  39. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
  40. package/dist/src/tasks/nlp/fillMask.d.ts +2 -21
  41. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
  42. package/dist/src/tasks/nlp/questionAnswering.d.ts +3 -25
  43. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  44. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts +2 -13
  45. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  46. package/dist/src/tasks/nlp/summarization.d.ts +2 -42
  47. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
  48. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts +3 -31
  49. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  50. package/dist/src/tasks/nlp/textClassification.d.ts +2 -16
  51. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
  52. package/dist/src/tasks/nlp/tokenClassification.d.ts +2 -45
  53. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  54. package/dist/src/tasks/nlp/translation.d.ts +2 -13
  55. package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
  56. package/dist/src/tasks/nlp/zeroShotClassification.d.ts +2 -22
  57. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  58. package/dist/src/types.d.ts +4 -0
  59. package/dist/src/types.d.ts.map +1 -1
  60. package/package.json +2 -2
  61. package/src/lib/makeRequestOptions.ts +7 -5
  62. package/src/providers/fal-ai.ts +12 -0
  63. package/src/providers/replicate.ts +6 -3
  64. package/src/providers/together.ts +2 -0
  65. package/src/tasks/audio/audioClassification.ts +7 -22
  66. package/src/tasks/audio/audioToAudio.ts +43 -23
  67. package/src/tasks/audio/automaticSpeechRecognition.ts +35 -23
  68. package/src/tasks/audio/textToSpeech.ts +23 -14
  69. package/src/tasks/audio/utils.ts +18 -0
  70. package/src/tasks/cv/imageClassification.ts +5 -20
  71. package/src/tasks/cv/imageSegmentation.ts +5 -24
  72. package/src/tasks/cv/imageToImage.ts +4 -52
  73. package/src/tasks/cv/imageToText.ts +6 -15
  74. package/src/tasks/cv/objectDetection.ts +5 -30
  75. package/src/tasks/cv/textToImage.ts +14 -50
  76. package/src/tasks/cv/textToVideo.ts +67 -0
  77. package/src/tasks/cv/utils.ts +13 -0
  78. package/src/tasks/cv/zeroShotImageClassification.ts +32 -31
  79. package/src/tasks/multimodal/documentQuestionAnswering.ts +25 -43
  80. package/src/tasks/multimodal/visualQuestionAnswering.ts +20 -36
  81. package/src/tasks/nlp/fillMask.ts +2 -22
  82. package/src/tasks/nlp/questionAnswering.ts +22 -36
  83. package/src/tasks/nlp/sentenceSimilarity.ts +12 -15
  84. package/src/tasks/nlp/summarization.ts +2 -43
  85. package/src/tasks/nlp/tableQuestionAnswering.ts +25 -41
  86. package/src/tasks/nlp/textClassification.ts +3 -18
  87. package/src/tasks/nlp/tokenClassification.ts +2 -47
  88. package/src/tasks/nlp/translation.ts +3 -17
  89. package/src/tasks/nlp/zeroShotClassification.ts +2 -24
  90. package/src/types.ts +7 -1
@@ -1,64 +1,16 @@
1
+ import type { ImageToImageInput } from "@huggingface/tasks";
1
2
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
3
  import type { BaseArgs, Options, RequestArgs } from "../../types";
3
4
  import { base64FromBytes } from "../../utils/base64FromBytes";
4
5
  import { request } from "../custom/request";
5
6
 
6
- export type ImageToImageArgs = BaseArgs & {
7
- /**
8
- * The initial image condition
9
- *
10
- **/
11
- inputs: Blob | ArrayBuffer;
12
-
13
- parameters?: {
14
- /**
15
- * The text prompt to guide the image generation.
16
- */
17
- prompt?: string;
18
- /**
19
- * strength param only works for SD img2img and alt diffusion img2img models
20
- * Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
21
- * will be used as a starting point, adding more noise to it the larger the `strength`. The number of
22
- * denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
23
- * be maximum and the denoising process will run for the full number of iterations specified in
24
- * `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
25
- **/
26
- strength?: number;
27
- /**
28
- * An optional negative prompt for the image generation
29
- */
30
- negative_prompt?: string;
31
- /**
32
- * The height in pixels of the generated image
33
- */
34
- height?: number;
35
- /**
36
- * The width in pixels of the generated image
37
- */
38
- width?: number;
39
- /**
40
- * The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
41
- */
42
- num_inference_steps?: number;
43
- /**
44
- * Guidance scale: Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
45
- */
46
- guidance_scale?: number;
47
- /**
48
- * guess_mode only works for ControlNet models, defaults to False. In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
49
- * you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
50
- */
51
- guess_mode?: boolean;
52
- };
53
- };
54
-
55
- export type ImageToImageOutput = Blob;
7
+ export type ImageToImageArgs = BaseArgs & ImageToImageInput;
56
8
 
57
9
  /**
58
10
  * This task reads some image input and outputs an image.
59
11
  * Recommended model: lllyasviel/sd-controlnet-depth
60
12
  */
61
- export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<ImageToImageOutput> {
13
+ export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
62
14
  let reqArgs: RequestArgs;
63
15
  if (!args.parameters) {
64
16
  reqArgs = {
@@ -74,7 +26,7 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P
74
26
  ),
75
27
  };
76
28
  }
77
- const res = await request<ImageToImageOutput>(reqArgs, {
29
+ const res = await request<Blob>(reqArgs, {
78
30
  ...options,
79
31
  taskHint: "image-to-image",
80
32
  });
@@ -1,27 +1,18 @@
1
+ import type { ImageToTextInput, ImageToTextOutput } from "@huggingface/tasks";
1
2
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
3
  import type { BaseArgs, Options } from "../../types";
3
4
  import { request } from "../custom/request";
5
+ import type { LegacyImageInput } from "./utils";
6
+ import { preparePayload } from "./utils";
4
7
 
5
- export type ImageToTextArgs = BaseArgs & {
6
- /**
7
- * Binary image data
8
- */
9
- data: Blob | ArrayBuffer;
10
- };
11
-
12
- export interface ImageToTextOutput {
13
- /**
14
- * The generated caption
15
- */
16
- generated_text: string;
17
- }
18
-
8
+ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput);
19
9
  /**
20
10
  * This task reads some image input and outputs the text caption.
21
11
  */
22
12
  export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
13
+ const payload = preparePayload(args);
23
14
  const res = (
24
- await request<[ImageToTextOutput]>(args, {
15
+ await request<[ImageToTextOutput]>(payload, {
25
16
  ...options,
26
17
  taskHint: "image-to-text",
27
18
  })
@@ -1,43 +1,18 @@
1
1
  import { request } from "../custom/request";
2
2
  import type { BaseArgs, Options } from "../../types";
3
3
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
4
+ import type { ObjectDetectionInput, ObjectDetectionOutput } from "@huggingface/tasks";
5
+ import { preparePayload, type LegacyImageInput } from "./utils";
4
6
 
5
- export type ObjectDetectionArgs = BaseArgs & {
6
- /**
7
- * Binary image data
8
- */
9
- data: Blob | ArrayBuffer;
10
- };
11
-
12
- export interface ObjectDetectionOutputValue {
13
- /**
14
- * A dict (with keys [xmin,ymin,xmax,ymax]) representing the bounding box of a detected object.
15
- */
16
- box: {
17
- xmax: number;
18
- xmin: number;
19
- ymax: number;
20
- ymin: number;
21
- };
22
- /**
23
- * The label for the class (model specific) of a detected object.
24
- */
25
- label: string;
26
-
27
- /**
28
- * A float that represents how likely it is that the detected object belongs to the given class.
29
- */
30
- score: number;
31
- }
32
-
33
- export type ObjectDetectionOutput = ObjectDetectionOutputValue[];
7
+ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImageInput);
34
8
 
35
9
  /**
36
10
  * This task reads some image input and outputs the likelihood of classes & bounding boxes of detected objects.
37
11
  * Recommended model: facebook/detr-resnet-50
38
12
  */
39
13
  export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> {
40
- const res = await request<ObjectDetectionOutput>(args, {
14
+ const payload = preparePayload(args);
15
+ const res = await request<ObjectDetectionOutput>(payload, {
41
16
  ...options,
42
17
  taskHint: "object-detection",
43
18
  });
@@ -1,47 +1,10 @@
1
+ import type { TextToImageInput, TextToImageOutput } from "@huggingface/tasks";
1
2
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
3
  import type { BaseArgs, Options } from "../../types";
4
+ import { omit } from "../../utils/omit";
3
5
  import { request } from "../custom/request";
4
6
 
5
- export type TextToImageArgs = BaseArgs & {
6
- /**
7
- * The text to generate an image from
8
- */
9
- inputs: string;
10
-
11
- /**
12
- * Same param but for external providers like Together, Replicate
13
- */
14
- prompt?: string;
15
- response_format?: "base64";
16
- input?: {
17
- prompt: string;
18
- };
19
-
20
- parameters?: {
21
- /**
22
- * An optional negative prompt for the image generation
23
- */
24
- negative_prompt?: string;
25
- /**
26
- * The height in pixels of the generated image
27
- */
28
- height?: number;
29
- /**
30
- * The width in pixels of the generated image
31
- */
32
- width?: number;
33
- /**
34
- * The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
35
- */
36
- num_inference_steps?: number;
37
- /**
38
- * Guidance scale: Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
39
- */
40
- guidance_scale?: number;
41
- };
42
- };
43
-
44
- export type TextToImageOutput = Blob;
7
+ export type TextToImageArgs = BaseArgs & TextToImageInput;
45
8
 
46
9
  interface Base64ImageGeneration {
47
10
  data: Array<{
@@ -56,16 +19,17 @@ interface OutputUrlImageGeneration {
56
19
  * This task reads some text input and outputs an image.
57
20
  * Recommended model: stabilityai/stable-diffusion-2
58
21
  */
59
- export async function textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageOutput> {
60
- if (args.provider === "together" || args.provider === "fal-ai") {
61
- args.prompt = args.inputs;
62
- args.inputs = "";
63
- args.response_format = "base64";
64
- } else if (args.provider === "replicate") {
65
- args.input = { prompt: args.inputs };
66
- delete (args as unknown as { inputs: unknown }).inputs;
67
- }
68
- const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration>(args, {
22
+ export async function textToImage(args: TextToImageArgs, options?: Options): Promise<Blob> {
23
+ const payload =
24
+ args.provider === "together" || args.provider === "fal-ai" || args.provider === "replicate"
25
+ ? {
26
+ ...omit(args, ["inputs", "parameters"]),
27
+ ...args.parameters,
28
+ ...(args.provider !== "replicate" ? { response_format: "base64" } : undefined),
29
+ prompt: args.inputs,
30
+ }
31
+ : args;
32
+ const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration>(payload, {
69
33
  ...options,
70
34
  taskHint: "text-to-image",
71
35
  });
@@ -0,0 +1,67 @@
1
+ import type { BaseArgs, InferenceProvider, Options } from "../../types";
2
+ import type { TextToVideoInput } from "@huggingface/tasks";
3
+ import { request } from "../custom/request";
4
+ import { omit } from "../../utils/omit";
5
+ import { isUrl } from "../../lib/isUrl";
6
+ import { InferenceOutputError } from "../../lib/InferenceOutputError";
7
+ import { typedInclude } from "../../utils/typedInclude";
8
+
9
+ export type TextToVideoArgs = BaseArgs & TextToVideoInput;
10
+
11
+ export type TextToVideoOutput = Blob;
12
+
13
+ interface FalAiOutput {
14
+ video: {
15
+ url: string;
16
+ };
17
+ }
18
+
19
+ interface ReplicateOutput {
20
+ output: string;
21
+ }
22
+
23
+ const SUPPORTED_PROVIDERS = ["fal-ai", "replicate"] as const satisfies readonly InferenceProvider[];
24
+
25
+ export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise<TextToVideoOutput> {
26
+ if (!args.provider || !typedInclude(SUPPORTED_PROVIDERS, args.provider)) {
27
+ throw new Error(
28
+ `textToVideo inference is only supported for the following providers: ${SUPPORTED_PROVIDERS.join(", ")}`
29
+ );
30
+ }
31
+
32
+ const payload =
33
+ args.provider === "fal-ai" || args.provider === "replicate"
34
+ ? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs }
35
+ : args;
36
+ const res = await request<FalAiOutput | ReplicateOutput>(payload, {
37
+ ...options,
38
+ taskHint: "text-to-video",
39
+ });
40
+
41
+ if (args.provider === "fal-ai") {
42
+ const isValidOutput =
43
+ typeof res === "object" &&
44
+ !!res &&
45
+ "video" in res &&
46
+ typeof res.video === "object" &&
47
+ !!res.video &&
48
+ "url" in res.video &&
49
+ typeof res.video.url === "string" &&
50
+ isUrl(res.video.url);
51
+ if (!isValidOutput) {
52
+ throw new InferenceOutputError("Expected { video: { url: string } }");
53
+ }
54
+ const urlResponse = await fetch(res.video.url);
55
+ return await urlResponse.blob();
56
+ } else {
57
+ /// TODO: Replicate: handle the case where the generation request "times out" / is async (i.e. output is null)
58
+ /// https://replicate.com/docs/topics/predictions/create-a-prediction
59
+ const isValidOutput =
60
+ typeof res === "object" && !!res && "output" in res && typeof res.output === "string" && isUrl(res.output);
61
+ if (!isValidOutput) {
62
+ throw new InferenceOutputError("Expected { output: string }");
63
+ }
64
+ const urlResponse = await fetch(res.output);
65
+ return await urlResponse.blob();
66
+ }
67
+ }
@@ -0,0 +1,13 @@
1
+ import type { BaseArgs, RequestArgs } from "../../types";
2
+ import { omit } from "../../utils/omit";
3
+
4
+ /**
5
+ * @deprecated
6
+ */
7
+ export interface LegacyImageInput {
8
+ data: Blob | ArrayBuffer;
9
+ }
10
+
11
+ export function preparePayload(args: BaseArgs & ({ inputs: Blob } | LegacyImageInput)): RequestArgs {
12
+ return "data" in args ? args : { ...omit(args, "inputs"), data: args.inputs };
13
+ }
@@ -3,28 +3,39 @@ import type { BaseArgs, Options } from "../../types";
3
3
  import { request } from "../custom/request";
4
4
  import type { RequestArgs } from "../../types";
5
5
  import { base64FromBytes } from "../../utils/base64FromBytes";
6
+ import type { ZeroShotImageClassificationInput, ZeroShotImageClassificationOutput } from "@huggingface/tasks";
6
7
 
7
- export type ZeroShotImageClassificationArgs = BaseArgs & {
8
- inputs: {
9
- /**
10
- * Binary image data
11
- */
12
- image: Blob | ArrayBuffer;
13
- };
14
- parameters: {
15
- /**
16
- * A list of strings that are potential classes for inputs. (max 10)
17
- */
18
- candidate_labels: string[];
19
- };
20
- };
21
-
22
- export interface ZeroShotImageClassificationOutputValue {
23
- label: string;
24
- score: number;
8
+ /**
9
+ * @deprecated
10
+ */
11
+ interface LegacyZeroShotImageClassificationInput {
12
+ inputs: { image: Blob | ArrayBuffer };
25
13
  }
26
14
 
27
- export type ZeroShotImageClassificationOutput = ZeroShotImageClassificationOutputValue[];
15
+ export type ZeroShotImageClassificationArgs = BaseArgs &
16
+ (ZeroShotImageClassificationInput | LegacyZeroShotImageClassificationInput);
17
+
18
+ async function preparePayload(args: ZeroShotImageClassificationArgs): Promise<RequestArgs> {
19
+ if (args.inputs instanceof Blob) {
20
+ return {
21
+ ...args,
22
+ inputs: {
23
+ image: base64FromBytes(new Uint8Array(await args.inputs.arrayBuffer())),
24
+ },
25
+ };
26
+ } else {
27
+ return {
28
+ ...args,
29
+ inputs: {
30
+ image: base64FromBytes(
31
+ new Uint8Array(
32
+ args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
33
+ )
34
+ ),
35
+ },
36
+ };
37
+ }
38
+ }
28
39
 
29
40
  /**
30
41
  * Classify an image to specified classes.
@@ -34,18 +45,8 @@ export async function zeroShotImageClassification(
34
45
  args: ZeroShotImageClassificationArgs,
35
46
  options?: Options
36
47
  ): Promise<ZeroShotImageClassificationOutput> {
37
- const reqArgs: RequestArgs = {
38
- ...args,
39
- inputs: {
40
- image: base64FromBytes(
41
- new Uint8Array(
42
- args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
43
- )
44
- ),
45
- },
46
- } as RequestArgs;
47
-
48
- const res = await request<ZeroShotImageClassificationOutput>(reqArgs, {
48
+ const payload = await preparePayload(args);
49
+ const res = await request<ZeroShotImageClassificationOutput>(payload, {
49
50
  ...options,
50
51
  taskHint: "zero-shot-image-classification",
51
52
  });
@@ -4,37 +4,15 @@ import { request } from "../custom/request";
4
4
  import type { RequestArgs } from "../../types";
5
5
  import { toArray } from "../../utils/toArray";
6
6
  import { base64FromBytes } from "../../utils/base64FromBytes";
7
+ import type {
8
+ DocumentQuestionAnsweringInput,
9
+ DocumentQuestionAnsweringInputData,
10
+ DocumentQuestionAnsweringOutput,
11
+ } from "@huggingface/tasks";
7
12
 
8
- export type DocumentQuestionAnsweringArgs = BaseArgs & {
9
- inputs: {
10
- /**
11
- * Raw image
12
- *
13
- * You can use native `File` in browsers, or `new Blob([buffer])` in node, or for a base64 image `new Blob([btoa(base64String)])`, or even `await (await fetch('...)).blob()`
14
- **/
15
- image: Blob | ArrayBuffer;
16
- question: string;
17
- };
18
- };
19
-
20
- export interface DocumentQuestionAnsweringOutput {
21
- /**
22
- * A string that’s the answer within the document.
23
- */
24
- answer: string;
25
- /**
26
- * ?
27
- */
28
- end?: number;
29
- /**
30
- * A float that represents how likely that the answer is correct
31
- */
32
- score?: number;
33
- /**
34
- * ?
35
- */
36
- start?: number;
37
- }
13
+ /// Override the type to properly set inputs.image as Blob
14
+ export type DocumentQuestionAnsweringArgs = BaseArgs &
15
+ DocumentQuestionAnsweringInput & { inputs: DocumentQuestionAnsweringInputData & { image: Blob } };
38
16
 
39
17
  /**
40
18
  * Answers a question on a document image. Recommended model: impira/layoutlm-document-qa.
@@ -42,32 +20,36 @@ export interface DocumentQuestionAnsweringOutput {
42
20
  export async function documentQuestionAnswering(
43
21
  args: DocumentQuestionAnsweringArgs,
44
22
  options?: Options
45
- ): Promise<DocumentQuestionAnsweringOutput> {
23
+ ): Promise<DocumentQuestionAnsweringOutput[number]> {
46
24
  const reqArgs: RequestArgs = {
47
25
  ...args,
48
26
  inputs: {
49
27
  question: args.inputs.question,
50
28
  // convert Blob to base64
51
- image: base64FromBytes(
52
- new Uint8Array(
53
- args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
54
- )
55
- ),
29
+ image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer())),
56
30
  },
57
31
  } as RequestArgs;
58
32
  const res = toArray(
59
- await request<[DocumentQuestionAnsweringOutput] | DocumentQuestionAnsweringOutput>(reqArgs, {
33
+ await request<DocumentQuestionAnsweringOutput | DocumentQuestionAnsweringOutput[number]>(reqArgs, {
60
34
  ...options,
61
35
  taskHint: "document-question-answering",
62
36
  })
63
- )?.[0];
37
+ );
38
+
64
39
  const isValidOutput =
65
- typeof res?.answer === "string" &&
66
- (typeof res.end === "number" || typeof res.end === "undefined") &&
67
- (typeof res.score === "number" || typeof res.score === "undefined") &&
68
- (typeof res.start === "number" || typeof res.start === "undefined");
40
+ Array.isArray(res) &&
41
+ res.every(
42
+ (elem) =>
43
+ typeof elem === "object" &&
44
+ !!elem &&
45
+ typeof elem?.answer === "string" &&
46
+ (typeof elem.end === "number" || typeof elem.end === "undefined") &&
47
+ (typeof elem.score === "number" || typeof elem.score === "undefined") &&
48
+ (typeof elem.start === "number" || typeof elem.start === "undefined")
49
+ );
69
50
  if (!isValidOutput) {
70
51
  throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>");
71
52
  }
72
- return res;
53
+
54
+ return res[0];
73
55
  }
@@ -1,30 +1,16 @@
1
+ import type {
2
+ VisualQuestionAnsweringInput,
3
+ VisualQuestionAnsweringInputData,
4
+ VisualQuestionAnsweringOutput,
5
+ } from "@huggingface/tasks";
1
6
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
7
  import type { BaseArgs, Options, RequestArgs } from "../../types";
3
8
  import { base64FromBytes } from "../../utils/base64FromBytes";
4
9
  import { request } from "../custom/request";
5
10
 
6
- export type VisualQuestionAnsweringArgs = BaseArgs & {
7
- inputs: {
8
- /**
9
- * Raw image
10
- *
11
- * You can use native `File` in browsers, or `new Blob([buffer])` in node, or for a base64 image `new Blob([btoa(base64String)])`, or even `await (await fetch('...)).blob()`
12
- **/
13
- image: Blob | ArrayBuffer;
14
- question: string;
15
- };
16
- };
17
-
18
- export interface VisualQuestionAnsweringOutput {
19
- /**
20
- * A string that’s the answer to a visual question.
21
- */
22
- answer: string;
23
- /**
24
- * Answer correctness score.
25
- */
26
- score: number;
27
- }
11
+ /// Override the type to properly set inputs.image as Blob
12
+ export type VisualQuestionAnsweringArgs = BaseArgs &
13
+ VisualQuestionAnsweringInput & { inputs: VisualQuestionAnsweringInputData & { image: Blob } };
28
14
 
29
15
  /**
30
16
  * Answers a question on an image. Recommended model: dandelin/vilt-b32-finetuned-vqa.
@@ -32,28 +18,26 @@ export interface VisualQuestionAnsweringOutput {
32
18
  export async function visualQuestionAnswering(
33
19
  args: VisualQuestionAnsweringArgs,
34
20
  options?: Options
35
- ): Promise<VisualQuestionAnsweringOutput> {
21
+ ): Promise<VisualQuestionAnsweringOutput[number]> {
36
22
  const reqArgs: RequestArgs = {
37
23
  ...args,
38
24
  inputs: {
39
25
  question: args.inputs.question,
40
26
  // convert Blob to base64
41
- image: base64FromBytes(
42
- new Uint8Array(
43
- args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
44
- )
45
- ),
27
+ image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer())),
46
28
  },
47
29
  } as RequestArgs;
48
- const res = (
49
- await request<[VisualQuestionAnsweringOutput]>(reqArgs, {
50
- ...options,
51
- taskHint: "visual-question-answering",
52
- })
53
- )?.[0];
54
- const isValidOutput = typeof res?.answer === "string" && typeof res.score === "number";
30
+ const res = await request<VisualQuestionAnsweringOutput>(reqArgs, {
31
+ ...options,
32
+ taskHint: "visual-question-answering",
33
+ });
34
+ const isValidOutput =
35
+ Array.isArray(res) &&
36
+ res.every(
37
+ (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
38
+ );
55
39
  if (!isValidOutput) {
56
40
  throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
57
41
  }
58
- return res;
42
+ return res[0];
59
43
  }
@@ -1,29 +1,9 @@
1
+ import type { FillMaskInput, FillMaskOutput } from "@huggingface/tasks";
1
2
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
3
  import type { BaseArgs, Options } from "../../types";
3
4
  import { request } from "../custom/request";
4
5
 
5
- export type FillMaskArgs = BaseArgs & {
6
- inputs: string;
7
- };
8
-
9
- export type FillMaskOutput = {
10
- /**
11
- * The probability for this token.
12
- */
13
- score: number;
14
- /**
15
- * The actual sequence of tokens that ran against the model (may contain special tokens)
16
- */
17
- sequence: string;
18
- /**
19
- * The id of the token
20
- */
21
- token: number;
22
- /**
23
- * The string representation of the token
24
- */
25
- token_str: string;
26
- }[];
6
+ export type FillMaskArgs = BaseArgs & FillMaskInput;
27
7
 
28
8
  /**
29
9
  * Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.