@huggingface/inference 3.0.1 → 3.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85):
  1. package/dist/index.cjs +162 -69
  2. package/dist/index.js +162 -69
  3. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  4. package/dist/src/providers/replicate.d.ts.map +1 -1
  5. package/dist/src/tasks/audio/audioClassification.d.ts +4 -18
  6. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
  7. package/dist/src/tasks/audio/audioToAudio.d.ts +10 -9
  8. package/dist/src/tasks/audio/audioToAudio.d.ts.map +1 -1
  9. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts +3 -12
  10. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  11. package/dist/src/tasks/audio/textToSpeech.d.ts +4 -8
  12. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  13. package/dist/src/tasks/audio/utils.d.ts +11 -0
  14. package/dist/src/tasks/audio/utils.d.ts.map +1 -0
  15. package/dist/src/tasks/cv/imageClassification.d.ts +3 -17
  16. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
  17. package/dist/src/tasks/cv/imageSegmentation.d.ts +3 -21
  18. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
  19. package/dist/src/tasks/cv/imageToImage.d.ts +3 -49
  20. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
  21. package/dist/src/tasks/cv/imageToText.d.ts +3 -12
  22. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
  23. package/dist/src/tasks/cv/objectDetection.d.ts +3 -26
  24. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
  25. package/dist/src/tasks/cv/textToImage.d.ts +3 -38
  26. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  27. package/dist/src/tasks/cv/textToVideo.d.ts +6 -0
  28. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -0
  29. package/dist/src/tasks/cv/utils.d.ts +11 -0
  30. package/dist/src/tasks/cv/utils.d.ts.map +1 -0
  31. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts +7 -15
  32. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
  33. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts +5 -28
  34. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  35. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts +5 -20
  36. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
  37. package/dist/src/tasks/nlp/fillMask.d.ts +2 -21
  38. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
  39. package/dist/src/tasks/nlp/questionAnswering.d.ts +3 -25
  40. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  41. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts +2 -13
  42. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  43. package/dist/src/tasks/nlp/summarization.d.ts +2 -42
  44. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
  45. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts +3 -31
  46. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  47. package/dist/src/tasks/nlp/textClassification.d.ts +2 -16
  48. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
  49. package/dist/src/tasks/nlp/tokenClassification.d.ts +2 -45
  50. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  51. package/dist/src/tasks/nlp/translation.d.ts +2 -13
  52. package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
  53. package/dist/src/tasks/nlp/zeroShotClassification.d.ts +2 -22
  54. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  55. package/dist/src/types.d.ts +4 -0
  56. package/dist/src/types.d.ts.map +1 -1
  57. package/package.json +2 -2
  58. package/src/providers/fal-ai.ts +4 -0
  59. package/src/providers/replicate.ts +3 -0
  60. package/src/tasks/audio/audioClassification.ts +7 -22
  61. package/src/tasks/audio/audioToAudio.ts +43 -23
  62. package/src/tasks/audio/automaticSpeechRecognition.ts +35 -23
  63. package/src/tasks/audio/textToSpeech.ts +8 -14
  64. package/src/tasks/audio/utils.ts +18 -0
  65. package/src/tasks/cv/imageClassification.ts +5 -20
  66. package/src/tasks/cv/imageSegmentation.ts +5 -24
  67. package/src/tasks/cv/imageToImage.ts +4 -52
  68. package/src/tasks/cv/imageToText.ts +6 -15
  69. package/src/tasks/cv/objectDetection.ts +5 -30
  70. package/src/tasks/cv/textToImage.ts +14 -50
  71. package/src/tasks/cv/textToVideo.ts +67 -0
  72. package/src/tasks/cv/utils.ts +13 -0
  73. package/src/tasks/cv/zeroShotImageClassification.ts +32 -31
  74. package/src/tasks/multimodal/documentQuestionAnswering.ts +25 -43
  75. package/src/tasks/multimodal/visualQuestionAnswering.ts +20 -36
  76. package/src/tasks/nlp/fillMask.ts +2 -22
  77. package/src/tasks/nlp/questionAnswering.ts +22 -36
  78. package/src/tasks/nlp/sentenceSimilarity.ts +12 -15
  79. package/src/tasks/nlp/summarization.ts +2 -43
  80. package/src/tasks/nlp/tableQuestionAnswering.ts +25 -41
  81. package/src/tasks/nlp/textClassification.ts +3 -18
  82. package/src/tasks/nlp/tokenClassification.ts +2 -47
  83. package/src/tasks/nlp/translation.ts +3 -17
  84. package/src/tasks/nlp/zeroShotClassification.ts +2 -24
  85. package/src/types.ts +7 -1
@@ -1,43 +1,18 @@
1
1
  import { request } from "../custom/request";
2
2
  import type { BaseArgs, Options } from "../../types";
3
3
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
4
+ import type { ObjectDetectionInput, ObjectDetectionOutput } from "@huggingface/tasks";
5
+ import { preparePayload, type LegacyImageInput } from "./utils";
4
6
 
5
- export type ObjectDetectionArgs = BaseArgs & {
6
- /**
7
- * Binary image data
8
- */
9
- data: Blob | ArrayBuffer;
10
- };
11
-
12
- export interface ObjectDetectionOutputValue {
13
- /**
14
- * A dict (with keys [xmin,ymin,xmax,ymax]) representing the bounding box of a detected object.
15
- */
16
- box: {
17
- xmax: number;
18
- xmin: number;
19
- ymax: number;
20
- ymin: number;
21
- };
22
- /**
23
- * The label for the class (model specific) of a detected object.
24
- */
25
- label: string;
26
-
27
- /**
28
- * A float that represents how likely it is that the detected object belongs to the given class.
29
- */
30
- score: number;
31
- }
32
-
33
- export type ObjectDetectionOutput = ObjectDetectionOutputValue[];
7
+ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImageInput);
34
8
 
35
9
  /**
36
10
  * This task reads some image input and outputs the likelihood of classes & bounding boxes of detected objects.
37
11
  * Recommended model: facebook/detr-resnet-50
38
12
  */
39
13
  export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> {
40
- const res = await request<ObjectDetectionOutput>(args, {
14
+ const payload = preparePayload(args);
15
+ const res = await request<ObjectDetectionOutput>(payload, {
41
16
  ...options,
42
17
  taskHint: "object-detection",
43
18
  });
@@ -1,47 +1,10 @@
1
+ import type { TextToImageInput, TextToImageOutput } from "@huggingface/tasks";
1
2
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
3
  import type { BaseArgs, Options } from "../../types";
4
+ import { omit } from "../../utils/omit";
3
5
  import { request } from "../custom/request";
4
6
 
5
- export type TextToImageArgs = BaseArgs & {
6
- /**
7
- * The text to generate an image from
8
- */
9
- inputs: string;
10
-
11
- /**
12
- * Same param but for external providers like Together, Replicate
13
- */
14
- prompt?: string;
15
- response_format?: "base64";
16
- input?: {
17
- prompt: string;
18
- };
19
-
20
- parameters?: {
21
- /**
22
- * An optional negative prompt for the image generation
23
- */
24
- negative_prompt?: string;
25
- /**
26
- * The height in pixels of the generated image
27
- */
28
- height?: number;
29
- /**
30
- * The width in pixels of the generated image
31
- */
32
- width?: number;
33
- /**
34
- * The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
35
- */
36
- num_inference_steps?: number;
37
- /**
38
- * Guidance scale: Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
39
- */
40
- guidance_scale?: number;
41
- };
42
- };
43
-
44
- export type TextToImageOutput = Blob;
7
+ export type TextToImageArgs = BaseArgs & TextToImageInput;
45
8
 
46
9
  interface Base64ImageGeneration {
47
10
  data: Array<{
@@ -56,16 +19,17 @@ interface OutputUrlImageGeneration {
56
19
  * This task reads some text input and outputs an image.
57
20
  * Recommended model: stabilityai/stable-diffusion-2
58
21
  */
59
- export async function textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageOutput> {
60
- if (args.provider === "together" || args.provider === "fal-ai") {
61
- args.prompt = args.inputs;
62
- delete (args as unknown as { inputs: unknown }).inputs;
63
- args.response_format = "base64";
64
- } else if (args.provider === "replicate") {
65
- args.prompt = args.inputs;
66
- delete (args as unknown as { inputs: unknown }).inputs;
67
- }
68
- const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration>(args, {
22
+ export async function textToImage(args: TextToImageArgs, options?: Options): Promise<Blob> {
23
+ const payload =
24
+ args.provider === "together" || args.provider === "fal-ai" || args.provider === "replicate"
25
+ ? {
26
+ ...omit(args, ["inputs", "parameters"]),
27
+ ...args.parameters,
28
+ ...(args.provider !== "replicate" ? { response_format: "base64" } : undefined),
29
+ prompt: args.inputs,
30
+ }
31
+ : args;
32
+ const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration>(payload, {
69
33
  ...options,
70
34
  taskHint: "text-to-image",
71
35
  });
@@ -0,0 +1,67 @@
1
+ import type { BaseArgs, InferenceProvider, Options } from "../../types";
2
+ import type { TextToVideoInput } from "@huggingface/tasks";
3
+ import { request } from "../custom/request";
4
+ import { omit } from "../../utils/omit";
5
+ import { isUrl } from "../../lib/isUrl";
6
+ import { InferenceOutputError } from "../../lib/InferenceOutputError";
7
+ import { typedInclude } from "../../utils/typedInclude";
8
+
9
+ export type TextToVideoArgs = BaseArgs & TextToVideoInput;
10
+
11
+ export type TextToVideoOutput = Blob;
12
+
13
+ interface FalAiOutput {
14
+ video: {
15
+ url: string;
16
+ };
17
+ }
18
+
19
+ interface ReplicateOutput {
20
+ output: string;
21
+ }
22
+
23
+ const SUPPORTED_PROVIDERS = ["fal-ai", "replicate"] as const satisfies readonly InferenceProvider[];
24
+
25
+ export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise<TextToVideoOutput> {
26
+ if (!args.provider || !typedInclude(SUPPORTED_PROVIDERS, args.provider)) {
27
+ throw new Error(
28
+ `textToVideo inference is only supported for the following providers: ${SUPPORTED_PROVIDERS.join(", ")}`
29
+ );
30
+ }
31
+
32
+ const payload =
33
+ args.provider === "fal-ai" || args.provider === "replicate"
34
+ ? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs }
35
+ : args;
36
+ const res = await request<FalAiOutput | ReplicateOutput>(payload, {
37
+ ...options,
38
+ taskHint: "text-to-video",
39
+ });
40
+
41
+ if (args.provider === "fal-ai") {
42
+ const isValidOutput =
43
+ typeof res === "object" &&
44
+ !!res &&
45
+ "video" in res &&
46
+ typeof res.video === "object" &&
47
+ !!res.video &&
48
+ "url" in res.video &&
49
+ typeof res.video.url === "string" &&
50
+ isUrl(res.video.url);
51
+ if (!isValidOutput) {
52
+ throw new InferenceOutputError("Expected { video: { url: string } }");
53
+ }
54
+ const urlResponse = await fetch(res.video.url);
55
+ return await urlResponse.blob();
56
+ } else {
57
+ /// TODO: Replicate: handle the case where the generation request "times out" / is async (ie output is null)
58
+ /// https://replicate.com/docs/topics/predictions/create-a-prediction
59
+ const isValidOutput =
60
+ typeof res === "object" && !!res && "output" in res && typeof res.output === "string" && isUrl(res.output);
61
+ if (!isValidOutput) {
62
+ throw new InferenceOutputError("Expected { output: string }");
63
+ }
64
+ const urlResponse = await fetch(res.output);
65
+ return await urlResponse.blob();
66
+ }
67
+ }
@@ -0,0 +1,13 @@
1
+ import type { BaseArgs, RequestArgs } from "../../types";
2
+ import { omit } from "../../utils/omit";
3
+
4
+ /**
5
+ * @deprecated
6
+ */
7
+ export interface LegacyImageInput {
8
+ data: Blob | ArrayBuffer;
9
+ }
10
+
11
+ export function preparePayload(args: BaseArgs & ({ inputs: Blob } | LegacyImageInput)): RequestArgs {
12
+ return "data" in args ? args : { ...omit(args, "inputs"), data: args.inputs };
13
+ }
@@ -3,28 +3,39 @@ import type { BaseArgs, Options } from "../../types";
3
3
  import { request } from "../custom/request";
4
4
  import type { RequestArgs } from "../../types";
5
5
  import { base64FromBytes } from "../../utils/base64FromBytes";
6
+ import type { ZeroShotImageClassificationInput, ZeroShotImageClassificationOutput } from "@huggingface/tasks";
6
7
 
7
- export type ZeroShotImageClassificationArgs = BaseArgs & {
8
- inputs: {
9
- /**
10
- * Binary image data
11
- */
12
- image: Blob | ArrayBuffer;
13
- };
14
- parameters: {
15
- /**
16
- * A list of strings that are potential classes for inputs. (max 10)
17
- */
18
- candidate_labels: string[];
19
- };
20
- };
21
-
22
- export interface ZeroShotImageClassificationOutputValue {
23
- label: string;
24
- score: number;
8
+ /**
9
+ * @deprecated
10
+ */
11
+ interface LegacyZeroShotImageClassificationInput {
12
+ inputs: { image: Blob | ArrayBuffer };
25
13
  }
26
14
 
27
- export type ZeroShotImageClassificationOutput = ZeroShotImageClassificationOutputValue[];
15
+ export type ZeroShotImageClassificationArgs = BaseArgs &
16
+ (ZeroShotImageClassificationInput | LegacyZeroShotImageClassificationInput);
17
+
18
+ async function preparePayload(args: ZeroShotImageClassificationArgs): Promise<RequestArgs> {
19
+ if (args.inputs instanceof Blob) {
20
+ return {
21
+ ...args,
22
+ inputs: {
23
+ image: base64FromBytes(new Uint8Array(await args.inputs.arrayBuffer())),
24
+ },
25
+ };
26
+ } else {
27
+ return {
28
+ ...args,
29
+ inputs: {
30
+ image: base64FromBytes(
31
+ new Uint8Array(
32
+ args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
33
+ )
34
+ ),
35
+ },
36
+ };
37
+ }
38
+ }
28
39
 
29
40
  /**
30
41
  * Classify an image to specified classes.
@@ -34,18 +45,8 @@ export async function zeroShotImageClassification(
34
45
  args: ZeroShotImageClassificationArgs,
35
46
  options?: Options
36
47
  ): Promise<ZeroShotImageClassificationOutput> {
37
- const reqArgs: RequestArgs = {
38
- ...args,
39
- inputs: {
40
- image: base64FromBytes(
41
- new Uint8Array(
42
- args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
43
- )
44
- ),
45
- },
46
- } as RequestArgs;
47
-
48
- const res = await request<ZeroShotImageClassificationOutput>(reqArgs, {
48
+ const payload = await preparePayload(args);
49
+ const res = await request<ZeroShotImageClassificationOutput>(payload, {
49
50
  ...options,
50
51
  taskHint: "zero-shot-image-classification",
51
52
  });
@@ -4,37 +4,15 @@ import { request } from "../custom/request";
4
4
  import type { RequestArgs } from "../../types";
5
5
  import { toArray } from "../../utils/toArray";
6
6
  import { base64FromBytes } from "../../utils/base64FromBytes";
7
+ import type {
8
+ DocumentQuestionAnsweringInput,
9
+ DocumentQuestionAnsweringInputData,
10
+ DocumentQuestionAnsweringOutput,
11
+ } from "@huggingface/tasks";
7
12
 
8
- export type DocumentQuestionAnsweringArgs = BaseArgs & {
9
- inputs: {
10
- /**
11
- * Raw image
12
- *
13
- * You can use native `File` in browsers, or `new Blob([buffer])` in node, or for a base64 image `new Blob([btoa(base64String)])`, or even `await (await fetch('...)).blob()`
14
- **/
15
- image: Blob | ArrayBuffer;
16
- question: string;
17
- };
18
- };
19
-
20
- export interface DocumentQuestionAnsweringOutput {
21
- /**
22
- * A string that’s the answer within the document.
23
- */
24
- answer: string;
25
- /**
26
- * ?
27
- */
28
- end?: number;
29
- /**
30
- * A float that represents how likely that the answer is correct
31
- */
32
- score?: number;
33
- /**
34
- * ?
35
- */
36
- start?: number;
37
- }
13
+ /// Override the type to properly set inputs.image as Blob
14
+ export type DocumentQuestionAnsweringArgs = BaseArgs &
15
+ DocumentQuestionAnsweringInput & { inputs: DocumentQuestionAnsweringInputData & { image: Blob } };
38
16
 
39
17
  /**
40
18
  * Answers a question on a document image. Recommended model: impira/layoutlm-document-qa.
@@ -42,32 +20,36 @@ export interface DocumentQuestionAnsweringOutput {
42
20
  export async function documentQuestionAnswering(
43
21
  args: DocumentQuestionAnsweringArgs,
44
22
  options?: Options
45
- ): Promise<DocumentQuestionAnsweringOutput> {
23
+ ): Promise<DocumentQuestionAnsweringOutput[number]> {
46
24
  const reqArgs: RequestArgs = {
47
25
  ...args,
48
26
  inputs: {
49
27
  question: args.inputs.question,
50
28
  // convert Blob or ArrayBuffer to base64
51
- image: base64FromBytes(
52
- new Uint8Array(
53
- args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
54
- )
55
- ),
29
+ image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer())),
56
30
  },
57
31
  } as RequestArgs;
58
32
  const res = toArray(
59
- await request<[DocumentQuestionAnsweringOutput] | DocumentQuestionAnsweringOutput>(reqArgs, {
33
+ await request<DocumentQuestionAnsweringOutput | DocumentQuestionAnsweringOutput[number]>(reqArgs, {
60
34
  ...options,
61
35
  taskHint: "document-question-answering",
62
36
  })
63
- )?.[0];
37
+ );
38
+
64
39
  const isValidOutput =
65
- typeof res?.answer === "string" &&
66
- (typeof res.end === "number" || typeof res.end === "undefined") &&
67
- (typeof res.score === "number" || typeof res.score === "undefined") &&
68
- (typeof res.start === "number" || typeof res.start === "undefined");
40
+ Array.isArray(res) &&
41
+ res.every(
42
+ (elem) =>
43
+ typeof elem === "object" &&
44
+ !!elem &&
45
+ typeof elem?.answer === "string" &&
46
+ (typeof elem.end === "number" || typeof elem.end === "undefined") &&
47
+ (typeof elem.score === "number" || typeof elem.score === "undefined") &&
48
+ (typeof elem.start === "number" || typeof elem.start === "undefined")
49
+ );
69
50
  if (!isValidOutput) {
70
51
  throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>");
71
52
  }
72
- return res;
53
+
54
+ return res[0];
73
55
  }
@@ -1,30 +1,16 @@
1
+ import type {
2
+ VisualQuestionAnsweringInput,
3
+ VisualQuestionAnsweringInputData,
4
+ VisualQuestionAnsweringOutput,
5
+ } from "@huggingface/tasks";
1
6
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
7
  import type { BaseArgs, Options, RequestArgs } from "../../types";
3
8
  import { base64FromBytes } from "../../utils/base64FromBytes";
4
9
  import { request } from "../custom/request";
5
10
 
6
- export type VisualQuestionAnsweringArgs = BaseArgs & {
7
- inputs: {
8
- /**
9
- * Raw image
10
- *
11
- * You can use native `File` in browsers, or `new Blob([buffer])` in node, or for a base64 image `new Blob([btoa(base64String)])`, or even `await (await fetch('...)).blob()`
12
- **/
13
- image: Blob | ArrayBuffer;
14
- question: string;
15
- };
16
- };
17
-
18
- export interface VisualQuestionAnsweringOutput {
19
- /**
20
- * A string that’s the answer to a visual question.
21
- */
22
- answer: string;
23
- /**
24
- * Answer correctness score.
25
- */
26
- score: number;
27
- }
11
+ /// Override the type to properly set inputs.image as Blob
12
+ export type VisualQuestionAnsweringArgs = BaseArgs &
13
+ VisualQuestionAnsweringInput & { inputs: VisualQuestionAnsweringInputData & { image: Blob } };
28
14
 
29
15
  /**
30
16
  * Answers a question on an image. Recommended model: dandelin/vilt-b32-finetuned-vqa.
@@ -32,28 +18,26 @@ export interface VisualQuestionAnsweringOutput {
32
18
  export async function visualQuestionAnswering(
33
19
  args: VisualQuestionAnsweringArgs,
34
20
  options?: Options
35
- ): Promise<VisualQuestionAnsweringOutput> {
21
+ ): Promise<VisualQuestionAnsweringOutput[number]> {
36
22
  const reqArgs: RequestArgs = {
37
23
  ...args,
38
24
  inputs: {
39
25
  question: args.inputs.question,
40
26
  // convert Blob or ArrayBuffer to base64
41
- image: base64FromBytes(
42
- new Uint8Array(
43
- args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
44
- )
45
- ),
27
+ image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer())),
46
28
  },
47
29
  } as RequestArgs;
48
- const res = (
49
- await request<[VisualQuestionAnsweringOutput]>(reqArgs, {
50
- ...options,
51
- taskHint: "visual-question-answering",
52
- })
53
- )?.[0];
54
- const isValidOutput = typeof res?.answer === "string" && typeof res.score === "number";
30
+ const res = await request<VisualQuestionAnsweringOutput>(reqArgs, {
31
+ ...options,
32
+ taskHint: "visual-question-answering",
33
+ });
34
+ const isValidOutput =
35
+ Array.isArray(res) &&
36
+ res.every(
37
+ (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
38
+ );
55
39
  if (!isValidOutput) {
56
40
  throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
57
41
  }
58
- return res;
42
+ return res[0];
59
43
  }
@@ -1,29 +1,9 @@
1
+ import type { FillMaskInput, FillMaskOutput } from "@huggingface/tasks";
1
2
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
3
  import type { BaseArgs, Options } from "../../types";
3
4
  import { request } from "../custom/request";
4
5
 
5
- export type FillMaskArgs = BaseArgs & {
6
- inputs: string;
7
- };
8
-
9
- export type FillMaskOutput = {
10
- /**
11
- * The probability for this token.
12
- */
13
- score: number;
14
- /**
15
- * The actual sequence of tokens that ran against the model (may contain special tokens)
16
- */
17
- sequence: string;
18
- /**
19
- * The id of the token
20
- */
21
- token: number;
22
- /**
23
- * The string representation of the token
24
- */
25
- token_str: string;
26
- }[];
6
+ export type FillMaskArgs = BaseArgs & FillMaskInput;
27
7
 
28
8
  /**
29
9
  * Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
@@ -1,32 +1,9 @@
1
+ import type { QuestionAnsweringInput, QuestionAnsweringOutput } from "@huggingface/tasks";
1
2
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
3
  import type { BaseArgs, Options } from "../../types";
3
4
  import { request } from "../custom/request";
4
5
 
5
- export type QuestionAnsweringArgs = BaseArgs & {
6
- inputs: {
7
- context: string;
8
- question: string;
9
- };
10
- };
11
-
12
- export interface QuestionAnsweringOutput {
13
- /**
14
- * A string that’s the answer within the text.
15
- */
16
- answer: string;
17
- /**
18
- * The index (string wise) of the stop of the answer within context.
19
- */
20
- end: number;
21
- /**
22
- * A float that represents how likely that the answer is correct
23
- */
24
- score: number;
25
- /**
26
- * The index (string wise) of the start of the answer within context.
27
- */
28
- start: number;
29
- }
6
+ export type QuestionAnsweringArgs = BaseArgs & QuestionAnsweringInput;
30
7
 
31
8
  /**
32
9
  * Want to have a nice know-it-all bot that can answer any question?. Recommended model: deepset/roberta-base-squad2
@@ -34,20 +11,29 @@ export interface QuestionAnsweringOutput {
34
11
  export async function questionAnswering(
35
12
  args: QuestionAnsweringArgs,
36
13
  options?: Options
37
- ): Promise<QuestionAnsweringOutput> {
38
- const res = await request<QuestionAnsweringOutput>(args, {
14
+ ): Promise<QuestionAnsweringOutput[number]> {
15
+ const res = await request<QuestionAnsweringOutput | QuestionAnsweringOutput[number]>(args, {
39
16
  ...options,
40
17
  taskHint: "question-answering",
41
18
  });
42
- const isValidOutput =
43
- typeof res === "object" &&
44
- !!res &&
45
- typeof res.answer === "string" &&
46
- typeof res.end === "number" &&
47
- typeof res.score === "number" &&
48
- typeof res.start === "number";
19
+ const isValidOutput = Array.isArray(res)
20
+ ? res.every(
21
+ (elem) =>
22
+ typeof elem === "object" &&
23
+ !!elem &&
24
+ typeof elem.answer === "string" &&
25
+ typeof elem.end === "number" &&
26
+ typeof elem.score === "number" &&
27
+ typeof elem.start === "number"
28
+ )
29
+ : typeof res === "object" &&
30
+ !!res &&
31
+ typeof res.answer === "string" &&
32
+ typeof res.end === "number" &&
33
+ typeof res.score === "number" &&
34
+ typeof res.start === "number";
49
35
  if (!isValidOutput) {
50
- throw new InferenceOutputError("Expected {answer: string, end: number, score: number, start: number}");
36
+ throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
51
37
  }
52
- return res;
38
+ return Array.isArray(res) ? res[0] : res;
53
39
  }
@@ -1,22 +1,11 @@
1
+ import type { SentenceSimilarityInput, SentenceSimilarityOutput } from "@huggingface/tasks";
1
2
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
3
  import { getDefaultTask } from "../../lib/getDefaultTask";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { request } from "../custom/request";
6
+ import { omit } from "../../utils/omit";
5
7
 
6
- export type SentenceSimilarityArgs = BaseArgs & {
7
- /**
8
- * The inputs vary based on the model.
9
- *
10
- * For example when using sentence-transformers/paraphrase-xlm-r-multilingual-v1 the inputs will have a `source_sentence` string and
11
- * a `sentences` array of strings
12
- */
13
- inputs: Record<string, unknown> | Record<string, unknown>[];
14
- };
15
-
16
- /**
17
- * Returned values are a list of floats
18
- */
19
- export type SentenceSimilarityOutput = number[];
8
+ export type SentenceSimilarityArgs = BaseArgs & SentenceSimilarityInput;
20
9
 
21
10
  /**
22
11
  * Calculate the semantic similarity between one text and a list of other sentences by comparing their embeddings.
@@ -26,7 +15,7 @@ export async function sentenceSimilarity(
26
15
  options?: Options
27
16
  ): Promise<SentenceSimilarityOutput> {
28
17
  const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : undefined;
29
- const res = await request<SentenceSimilarityOutput>(args, {
18
+ const res = await request<SentenceSimilarityOutput>(prepareInput(args), {
30
19
  ...options,
31
20
  taskHint: "sentence-similarity",
32
21
  ...(defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }),
@@ -38,3 +27,11 @@ export async function sentenceSimilarity(
38
27
  }
39
28
  return res;
40
29
  }
30
+
31
+ function prepareInput(args: SentenceSimilarityArgs) {
32
+ return {
33
+ ...omit(args, ["inputs", "parameters"]),
34
+ inputs: { ...omit(args.inputs, "sourceSentence") },
35
+ parameters: { source_sentence: args.inputs.sourceSentence, ...args.parameters },
36
+ };
37
+ }
@@ -1,50 +1,9 @@
1
+ import type { SummarizationInput, SummarizationOutput } from "@huggingface/tasks";
1
2
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
3
  import type { BaseArgs, Options } from "../../types";
3
4
  import { request } from "../custom/request";
4
5
 
5
- export type SummarizationArgs = BaseArgs & {
6
- /**
7
- * A string to be summarized
8
- */
9
- inputs: string;
10
- parameters?: {
11
- /**
12
- * (Default: None). Integer to define the maximum length in tokens of the output summary.
13
- */
14
- max_length?: number;
15
- /**
16
- * (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum. Network can cause some overhead so it will be a soft limit.
17
- */
18
- max_time?: number;
19
- /**
20
- * (Default: None). Integer to define the minimum length in tokens of the output summary.
21
- */
22
- min_length?: number;
23
- /**
24
- * (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized to not be picked in successive generation passes.
25
- */
26
- repetition_penalty?: number;
27
- /**
28
- * (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling, 0 means always take the highest score, 100.0 is getting closer to uniform probability.
29
- */
30
- temperature?: number;
31
- /**
32
- * (Default: None). Integer to define the top tokens considered within the sample operation to create new text.
33
- */
34
- top_k?: number;
35
- /**
36
- * (Default: None). Float to define the tokens that are within the sample operation of text generation. Add tokens in the sample for more probable to least probable until the sum of the probabilities is greater than top_p.
37
- */
38
- top_p?: number;
39
- };
40
- };
41
-
42
- export interface SummarizationOutput {
43
- /**
44
- * The string after translation
45
- */
46
- summary_text: string;
47
- }
6
+ export type SummarizationArgs = BaseArgs & SummarizationInput;
48
7
 
49
8
  /**
50
9
  * This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.