@huggingface/inference 2.5.2 → 2.6.0

This diff reflects the changes between the two package versions as published to their public registry, and is provided for informational purposes only.
Files changed (37)
  1. package/dist/index.d.ts +48 -2
  2. package/dist/index.js +168 -78
  3. package/dist/index.mjs +168 -78
  4. package/package.json +1 -1
  5. package/src/lib/getDefaultTask.ts +1 -1
  6. package/src/lib/makeRequestOptions.ts +34 -5
  7. package/src/tasks/audio/audioClassification.ts +4 -1
  8. package/src/tasks/audio/audioToAudio.ts +4 -1
  9. package/src/tasks/audio/automaticSpeechRecognition.ts +4 -1
  10. package/src/tasks/audio/textToSpeech.ts +4 -1
  11. package/src/tasks/custom/request.ts +3 -1
  12. package/src/tasks/custom/streamingRequest.ts +3 -1
  13. package/src/tasks/cv/imageClassification.ts +4 -1
  14. package/src/tasks/cv/imageSegmentation.ts +4 -1
  15. package/src/tasks/cv/imageToImage.ts +4 -1
  16. package/src/tasks/cv/imageToText.ts +6 -1
  17. package/src/tasks/cv/objectDetection.ts +4 -1
  18. package/src/tasks/cv/textToImage.ts +4 -1
  19. package/src/tasks/cv/zeroShotImageClassification.ts +4 -1
  20. package/src/tasks/multimodal/documentQuestionAnswering.ts +4 -1
  21. package/src/tasks/multimodal/visualQuestionAnswering.ts +6 -1
  22. package/src/tasks/nlp/conversational.ts +1 -1
  23. package/src/tasks/nlp/featureExtraction.ts +7 -10
  24. package/src/tasks/nlp/fillMask.ts +4 -1
  25. package/src/tasks/nlp/questionAnswering.ts +4 -1
  26. package/src/tasks/nlp/sentenceSimilarity.ts +6 -10
  27. package/src/tasks/nlp/summarization.ts +4 -1
  28. package/src/tasks/nlp/tableQuestionAnswering.ts +4 -1
  29. package/src/tasks/nlp/textClassification.ts +6 -1
  30. package/src/tasks/nlp/textGeneration.ts +4 -1
  31. package/src/tasks/nlp/textGenerationStream.ts +4 -1
  32. package/src/tasks/nlp/tokenClassification.ts +6 -1
  33. package/src/tasks/nlp/translation.ts +4 -1
  34. package/src/tasks/nlp/zeroShotClassification.ts +4 -1
  35. package/src/tasks/tabular/tabularClassification.ts +4 -1
  36. package/src/tasks/tabular/tabularRegression.ts +4 -1
  37. package/src/types.ts +36 -2
package/src/tasks/cv/imageToText.ts CHANGED
@@ -20,7 +20,12 @@ export interface ImageToTextOutput {
  * This task reads some image input and outputs the text caption.
  */
 export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
-  const res = (await request<[ImageToTextOutput]>(args, options))?.[0];
+  const res = (
+    await request<[ImageToTextOutput]>(args, {
+      ...options,
+      taskHint: "image-to-text",
+    })
+  )?.[0];

   if (typeof res?.generated_text !== "string") {
     throw new InferenceOutputError("Expected {generated_text: string}");
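The same spread-merge appears in every helper below: caller options are preserved and the helper's task hint, spread last, always wins. A standalone illustration of just that merge (the names here are placeholders, not the library's internals):

// Placeholder types, only to show the merge semantics used in the diff.
type CallerOptions = { wait_for_model?: boolean; taskHint?: string };

function withTaskHint(options: CallerOptions | undefined, taskHint: string): CallerOptions {
  // Spread the caller's options first, then pin the hint so it takes precedence.
  return { ...options, taskHint };
}

// -> { wait_for_model: true, taskHint: "image-to-text" }
console.log(withTaskHint({ wait_for_model: true, taskHint: "other" }, "image-to-text"));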
package/src/tasks/cv/objectDetection.ts CHANGED
@@ -37,7 +37,10 @@ export type ObjectDetectionOutput = ObjectDetectionOutputValue[];
  * Recommended model: facebook/detr-resnet-50
  */
 export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> {
-  const res = await request<ObjectDetectionOutput>(args, options);
+  const res = await request<ObjectDetectionOutput>(args, {
+    ...options,
+    taskHint: "object-detection",
+  });
   const isValidOutput =
     Array.isArray(res) &&
     res.every(
package/src/tasks/cv/textToImage.ts CHANGED
@@ -39,7 +39,10 @@ export type TextToImageOutput = Blob;
  * Recommended model: stabilityai/stable-diffusion-2
  */
 export async function textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageOutput> {
-  const res = await request<TextToImageOutput>(args, options);
+  const res = await request<TextToImageOutput>(args, {
+    ...options,
+    taskHint: "text-to-image",
+  });
   const isValidOutput = res && res instanceof Blob;
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Blob");
package/src/tasks/cv/zeroShotImageClassification.ts CHANGED
@@ -45,7 +45,10 @@ export async function zeroShotImageClassification(
     },
   } as RequestArgs;

-  const res = await request<ZeroShotImageClassificationOutput>(reqArgs, options);
+  const res = await request<ZeroShotImageClassificationOutput>(reqArgs, {
+    ...options,
+    taskHint: "zero-shot-image-classification",
+  });
   const isValidOutput =
     Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
   if (!isValidOutput) {
package/src/tasks/multimodal/documentQuestionAnswering.ts CHANGED
@@ -56,7 +56,10 @@ export async function documentQuestionAnswering(
     },
   } as RequestArgs;
   const res = toArray(
-    await request<[DocumentQuestionAnsweringOutput] | DocumentQuestionAnsweringOutput>(reqArgs, options)
+    await request<[DocumentQuestionAnsweringOutput] | DocumentQuestionAnsweringOutput>(reqArgs, {
+      ...options,
+      taskHint: "document-question-answering",
+    })
   )?.[0];
   const isValidOutput =
     typeof res?.answer === "string" &&
package/src/tasks/multimodal/visualQuestionAnswering.ts CHANGED
@@ -45,7 +45,12 @@ export async function visualQuestionAnswering(
       ),
     },
   } as RequestArgs;
-  const res = (await request<[VisualQuestionAnsweringOutput]>(reqArgs, options))?.[0];
+  const res = (
+    await request<[VisualQuestionAnsweringOutput]>(reqArgs, {
+      ...options,
+      taskHint: "visual-question-answering",
+    })
+  )?.[0];
   const isValidOutput = typeof res?.answer === "string" && typeof res.score === "number";
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
package/src/tasks/nlp/conversational.ts CHANGED
@@ -63,7 +63,7 @@ export interface ConversationalOutput {
  *
  */
 export async function conversational(args: ConversationalArgs, options?: Options): Promise<ConversationalOutput> {
-  const res = await request<ConversationalOutput>(args, options);
+  const res = await request<ConversationalOutput>(args, { ...options, taskHint: "conversational" });
   const isValidOutput =
     Array.isArray(res.conversation.generated_responses) &&
     res.conversation.generated_responses.every((x) => typeof x === "string") &&
package/src/tasks/nlp/featureExtraction.ts CHANGED
@@ -25,16 +25,13 @@ export async function featureExtraction(
   args: FeatureExtractionArgs,
   options?: Options
 ): Promise<FeatureExtractionOutput> {
-  const defaultTask = await getDefaultTask(args.model, args.accessToken);
-  const res = await request<FeatureExtractionOutput>(
-    args,
-    defaultTask === "sentence-similarity"
-      ? {
-          ...options,
-          task: "feature-extraction",
-        }
-      : options
-  );
+  const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : undefined;
+
+  const res = await request<FeatureExtractionOutput>(args, {
+    ...options,
+    taskHint: "feature-extraction",
+    ...(defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" }),
+  });
   let isValidOutput = true;

   const isNumArrayRec = (arr: unknown[], maxDepth: number, curDepth = 0): boolean => {
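As the hunk above shows, the default-task lookup now only happens when a model is given, and a checkpoint whose default pipeline is sentence-similarity is forced onto the feature-extraction task. A hypothetical usage sketch; the model name is an example and is not taken from this diff:

import { HfInference } from "@huggingface/inference";

const hf = new HfInference("hf_...placeholder token...");

// A sentence-similarity checkpoint can still return raw embeddings,
// because the helper forces the feature-extraction task for it.
const embeddings = await hf.featureExtraction({
  model: "sentence-transformers/all-MiniLM-L6-v2", // example model
  inputs: "That is a happy person",
});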
package/src/tasks/nlp/fillMask.ts CHANGED
@@ -29,7 +29,10 @@ export type FillMaskOutput = {
  * Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
  */
 export async function fillMask(args: FillMaskArgs, options?: Options): Promise<FillMaskOutput> {
-  const res = await request<FillMaskOutput>(args, options);
+  const res = await request<FillMaskOutput>(args, {
+    ...options,
+    taskHint: "fill-mask",
+  });
   const isValidOutput =
     Array.isArray(res) &&
     res.every(
package/src/tasks/nlp/questionAnswering.ts CHANGED
@@ -35,7 +35,10 @@ export async function questionAnswering(
   args: QuestionAnsweringArgs,
   options?: Options
 ): Promise<QuestionAnsweringOutput> {
-  const res = await request<QuestionAnsweringOutput>(args, options);
+  const res = await request<QuestionAnsweringOutput>(args, {
+    ...options,
+    taskHint: "question-answering",
+  });
   const isValidOutput =
     typeof res === "object" &&
     !!res &&
package/src/tasks/nlp/sentenceSimilarity.ts CHANGED
@@ -25,16 +25,12 @@ export async function sentenceSimilarity(
   args: SentenceSimilarityArgs,
   options?: Options
 ): Promise<SentenceSimilarityOutput> {
-  const defaultTask = await getDefaultTask(args.model, args.accessToken);
-  const res = await request<SentenceSimilarityOutput>(
-    args,
-    defaultTask === "feature-extraction"
-      ? {
-          ...options,
-          task: "sentence-similarity",
-        }
-      : options
-  );
+  const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : undefined;
+  const res = await request<SentenceSimilarityOutput>(args, {
+    ...options,
+    taskHint: "sentence-similarity",
+    ...(defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }),
+  });

   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
   if (!isValidOutput) {
package/src/tasks/nlp/summarization.ts CHANGED
@@ -50,7 +50,10 @@ export interface SummarizationOutput {
  * This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.
  */
 export async function summarization(args: SummarizationArgs, options?: Options): Promise<SummarizationOutput> {
-  const res = await request<SummarizationOutput[]>(args, options);
+  const res = await request<SummarizationOutput[]>(args, {
+    ...options,
+    taskHint: "summarization",
+  });
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string");
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Array<{summary_text: string}>");
package/src/tasks/nlp/tableQuestionAnswering.ts CHANGED
@@ -41,7 +41,10 @@ export async function tableQuestionAnswering(
   args: TableQuestionAnsweringArgs,
   options?: Options
 ): Promise<TableQuestionAnsweringOutput> {
-  const res = await request<TableQuestionAnsweringOutput>(args, options);
+  const res = await request<TableQuestionAnsweringOutput>(args, {
+    ...options,
+    taskHint: "table-question-answering",
+  });
   const isValidOutput =
     typeof res?.aggregator === "string" &&
     typeof res.answer === "string" &&
package/src/tasks/nlp/textClassification.ts CHANGED
@@ -27,7 +27,12 @@ export async function textClassification(
   args: TextClassificationArgs,
   options?: Options
 ): Promise<TextClassificationOutput> {
-  const res = (await request<TextClassificationOutput[]>(args, options))?.[0];
+  const res = (
+    await request<TextClassificationOutput[]>(args, {
+      ...options,
+      taskHint: "text-classification",
+    })
+  )?.[0];
   const isValidOutput =
     Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number");
   if (!isValidOutput) {
package/src/tasks/nlp/textGeneration.ts CHANGED
@@ -62,7 +62,10 @@ export interface TextGenerationOutput {
  * Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
  */
 export async function textGeneration(args: TextGenerationArgs, options?: Options): Promise<TextGenerationOutput> {
-  const res = await request<TextGenerationOutput[]>(args, options);
+  const res = await request<TextGenerationOutput[]>(args, {
+    ...options,
+    taskHint: "text-generation",
+  });
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string");
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Array<{generated_text: string}>");
package/src/tasks/nlp/textGenerationStream.ts CHANGED
@@ -88,5 +88,8 @@ export async function* textGenerationStream(
   args: TextGenerationArgs,
   options?: Options
 ): AsyncGenerator<TextGenerationStreamOutput> {
-  yield* streamingRequest<TextGenerationStreamOutput>(args, options);
+  yield* streamingRequest<TextGenerationStreamOutput>(args, {
+    ...options,
+    taskHint: "text-generation",
+  });
 }
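The streaming generator now carries the same "text-generation" hint as the non-streaming helper. A hypothetical usage sketch of consuming the stream; gpt2 is the model the package's own docstring recommends for this task, and the token is a placeholder:

import { HfInference } from "@huggingface/inference";

const hf = new HfInference("hf_...placeholder token...");

// Each streamed chunk carries the next token; the final chunk also carries
// the full generated_text.
for await (const output of hf.textGenerationStream({
  model: "gpt2",
  inputs: "The answer to the universe is",
  parameters: { max_new_tokens: 20 },
})) {
  console.log(output.token.text);
}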
package/src/tasks/nlp/tokenClassification.ts CHANGED
@@ -58,7 +58,12 @@ export async function tokenClassification(
   args: TokenClassificationArgs,
   options?: Options
 ): Promise<TokenClassificationOutput> {
-  const res = toArray(await request<TokenClassificationOutput[number] | TokenClassificationOutput>(args, options));
+  const res = toArray(
+    await request<TokenClassificationOutput[number] | TokenClassificationOutput>(args, {
+      ...options,
+      taskHint: "token-classification",
+    })
+  );
   const isValidOutput =
     Array.isArray(res) &&
     res.every(
package/src/tasks/nlp/translation.ts CHANGED
@@ -20,7 +20,10 @@ export interface TranslationOutput {
  * This task is well known to translate text from one language to another. Recommended model: Helsinki-NLP/opus-mt-ru-en.
  */
 export async function translation(args: TranslationArgs, options?: Options): Promise<TranslationOutput> {
-  const res = await request<TranslationOutput[]>(args, options);
+  const res = await request<TranslationOutput[]>(args, {
+    ...options,
+    taskHint: "translation",
+  });
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string");
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected type Array<{translation_text: string}>");
package/src/tasks/nlp/zeroShotClassification.ts CHANGED
@@ -36,7 +36,10 @@ export async function zeroShotClassification(
   options?: Options
 ): Promise<ZeroShotClassificationOutput> {
   const res = toArray(
-    await request<ZeroShotClassificationOutput[number] | ZeroShotClassificationOutput>(args, options)
+    await request<ZeroShotClassificationOutput[number] | ZeroShotClassificationOutput>(args, {
+      ...options,
+      taskHint: "zero-shot-classification",
+    })
   );
   const isValidOutput =
     Array.isArray(res) &&
package/src/tasks/tabular/tabularClassification.ts CHANGED
@@ -25,7 +25,10 @@ export async function tabularClassification(
   args: TabularClassificationArgs,
   options?: Options
 ): Promise<TabularClassificationOutput> {
-  const res = await request<TabularClassificationOutput>(args, options);
+  const res = await request<TabularClassificationOutput>(args, {
+    ...options,
+    taskHint: "tabular-classification",
+  });
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected number[]");
package/src/tasks/tabular/tabularRegression.ts CHANGED
@@ -25,7 +25,10 @@ export async function tabularRegression(
   args: TabularRegressionArgs,
   options?: Options
 ): Promise<TabularRegressionOutput> {
-  const res = await request<TabularRegressionOutput>(args, options);
+  const res = await request<TabularRegressionOutput>(args, {
+    ...options,
+    taskHint: "tabular-regression",
+  });
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected number[]");
package/src/types.ts CHANGED
@@ -26,7 +26,39 @@ export interface Options {
   fetch?: typeof fetch;
 }

-export type InferenceTask = "text-classification" | "feature-extraction" | "sentence-similarity";
+export type InferenceTask =
+  | "audio-classification"
+  | "audio-to-audio"
+  | "automatic-speech-recognition"
+  | "conversational"
+  | "depth-estimation"
+  | "document-question-answering"
+  | "feature-extraction"
+  | "fill-mask"
+  | "image-classification"
+  | "image-segmentation"
+  | "image-to-image"
+  | "image-to-text"
+  | "object-detection"
+  | "video-classification"
+  | "question-answering"
+  | "reinforcement-learning"
+  | "sentence-similarity"
+  | "summarization"
+  | "table-question-answering"
+  | "tabular-classification"
+  | "tabular-regression"
+  | "text-classification"
+  | "text-generation"
+  | "text-to-image"
+  | "text-to-speech"
+  | "text-to-video"
+  | "token-classification"
+  | "translation"
+  | "unconditional-image-generation"
+  | "visual-question-answering"
+  | "zero-shot-classification"
+  | "zero-shot-image-classification";

 export interface BaseArgs {
   /**
@@ -37,8 +69,10 @@ export interface BaseArgs {
   accessToken?: string;
   /**
    * The model to use. Can be a full URL for HF inference endpoints.
+   *
+   * If not specified, will call huggingface.co/api/tasks to get the default model for the task.
    */
-  model: string;
+  model?: string;
 }

 export type RequestArgs = BaseArgs &
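With `model` now optional, the taskHint each helper sets is what lets 2.6.0 resolve a recommended model from huggingface.co/api/tasks when none is given. A hypothetical usage sketch; the token is a placeholder:

import { HfInference } from "@huggingface/inference";

const hf = new HfInference("hf_...placeholder token...");

const { generated_text } = await hf.textGeneration({
  // no `model` here: the library looks up a default model for the
  // "text-generation" task before sending the request
  inputs: "The answer to the universe is",
});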