@huggingface/inference 3.0.0 → 3.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. package/README.md +11 -6
  2. package/dist/index.cjs +193 -76
  3. package/dist/index.js +193 -76
  4. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  5. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  6. package/dist/src/providers/replicate.d.ts.map +1 -1
  7. package/dist/src/providers/together.d.ts.map +1 -1
  8. package/dist/src/tasks/audio/audioClassification.d.ts +4 -18
  9. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
  10. package/dist/src/tasks/audio/audioToAudio.d.ts +10 -9
  11. package/dist/src/tasks/audio/audioToAudio.d.ts.map +1 -1
  12. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts +3 -12
  13. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  14. package/dist/src/tasks/audio/textToSpeech.d.ts +4 -8
  15. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  16. package/dist/src/tasks/audio/utils.d.ts +11 -0
  17. package/dist/src/tasks/audio/utils.d.ts.map +1 -0
  18. package/dist/src/tasks/cv/imageClassification.d.ts +3 -17
  19. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
  20. package/dist/src/tasks/cv/imageSegmentation.d.ts +3 -21
  21. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
  22. package/dist/src/tasks/cv/imageToImage.d.ts +3 -49
  23. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
  24. package/dist/src/tasks/cv/imageToText.d.ts +3 -12
  25. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
  26. package/dist/src/tasks/cv/objectDetection.d.ts +3 -26
  27. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
  28. package/dist/src/tasks/cv/textToImage.d.ts +3 -38
  29. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  30. package/dist/src/tasks/cv/textToVideo.d.ts +6 -0
  31. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -0
  32. package/dist/src/tasks/cv/utils.d.ts +11 -0
  33. package/dist/src/tasks/cv/utils.d.ts.map +1 -0
  34. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts +7 -15
  35. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
  36. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts +5 -28
  37. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  38. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts +5 -20
  39. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
  40. package/dist/src/tasks/nlp/fillMask.d.ts +2 -21
  41. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
  42. package/dist/src/tasks/nlp/questionAnswering.d.ts +3 -25
  43. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  44. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts +2 -13
  45. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  46. package/dist/src/tasks/nlp/summarization.d.ts +2 -42
  47. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
  48. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts +3 -31
  49. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  50. package/dist/src/tasks/nlp/textClassification.d.ts +2 -16
  51. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
  52. package/dist/src/tasks/nlp/tokenClassification.d.ts +2 -45
  53. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  54. package/dist/src/tasks/nlp/translation.d.ts +2 -13
  55. package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
  56. package/dist/src/tasks/nlp/zeroShotClassification.d.ts +2 -22
  57. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  58. package/dist/src/types.d.ts +4 -0
  59. package/dist/src/types.d.ts.map +1 -1
  60. package/package.json +2 -2
  61. package/src/lib/makeRequestOptions.ts +7 -5
  62. package/src/providers/fal-ai.ts +12 -0
  63. package/src/providers/replicate.ts +6 -3
  64. package/src/providers/together.ts +2 -0
  65. package/src/tasks/audio/audioClassification.ts +7 -22
  66. package/src/tasks/audio/audioToAudio.ts +43 -23
  67. package/src/tasks/audio/automaticSpeechRecognition.ts +35 -23
  68. package/src/tasks/audio/textToSpeech.ts +23 -14
  69. package/src/tasks/audio/utils.ts +18 -0
  70. package/src/tasks/cv/imageClassification.ts +5 -20
  71. package/src/tasks/cv/imageSegmentation.ts +5 -24
  72. package/src/tasks/cv/imageToImage.ts +4 -52
  73. package/src/tasks/cv/imageToText.ts +6 -15
  74. package/src/tasks/cv/objectDetection.ts +5 -30
  75. package/src/tasks/cv/textToImage.ts +14 -50
  76. package/src/tasks/cv/textToVideo.ts +67 -0
  77. package/src/tasks/cv/utils.ts +13 -0
  78. package/src/tasks/cv/zeroShotImageClassification.ts +32 -31
  79. package/src/tasks/multimodal/documentQuestionAnswering.ts +25 -43
  80. package/src/tasks/multimodal/visualQuestionAnswering.ts +20 -36
  81. package/src/tasks/nlp/fillMask.ts +2 -22
  82. package/src/tasks/nlp/questionAnswering.ts +22 -36
  83. package/src/tasks/nlp/sentenceSimilarity.ts +12 -15
  84. package/src/tasks/nlp/summarization.ts +2 -43
  85. package/src/tasks/nlp/tableQuestionAnswering.ts +25 -41
  86. package/src/tasks/nlp/textClassification.ts +3 -18
  87. package/src/tasks/nlp/tokenClassification.ts +2 -47
  88. package/src/tasks/nlp/translation.ts +3 -17
  89. package/src/tasks/nlp/zeroShotClassification.ts +2 -24
  90. package/src/types.ts +7 -1
package/README.md CHANGED
@@ -42,15 +42,15 @@ const hf = new HfInference('your access token')
 
 Your access token should be kept private. If you need to protect it in front-end applications, we suggest setting up a proxy server that stores the access token.
 
-### Requesting third-party inference providers
+### Third-party inference providers
 
-You can request inference from third-party providers with the inference client.
+You can send inference requests to third-party providers with the inference client.
 
 Currently, we support the following providers: [Fal.ai](https://fal.ai), [Replicate](https://replicate.com), [Together](https://together.xyz) and [Sambanova](https://sambanova.ai).
 
-To make request to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
+To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
 ```ts
-const accessToken = "hf_..."; // Either a HF access token, or an API key from the 3rd party provider (Replicate in this example)
+const accessToken = "hf_..."; // Either a HF access token, or an API key from the third-party provider (Replicate in this example)
 
 const client = new HfInference(accessToken);
 await client.textToImage({
@@ -63,14 +63,19 @@ await client.textToImage({
 When authenticated with a Hugging Face access token, the request is routed through https://huggingface.co.
 When authenticated with a third-party provider key, the request is made directly against that provider's inference API.
 
-Only a subset of models are supported when requesting 3rd party providers. You can check the list of supported models per pipeline tasks here:
+Only a subset of models are supported when requesting third-party providers. You can check the list of supported models per pipeline tasks here:
 - [Fal.ai supported models](./src/providers/fal-ai.ts)
 - [Replicate supported models](./src/providers/replicate.ts)
 - [Sambanova supported models](./src/providers/sambanova.ts)
 - [Together supported models](./src/providers/together.ts)
 - [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
 
-#### Tree-shaking
+❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
+This is not an issue for LLMs as everyone converged on the OpenAI API anyways, but can be more tricky for other tasks like "text-to-image" or "automatic-speech-recognition" where there exists no standard API. Let us know if any help is needed or if we can make things easier for you!
+
+👋**Want to add another provider?** Get in touch if you'd like to add support for another Inference provider, and/or request it on https://huggingface.co/spaces/huggingface/HuggingDiscussions/discussions/49
+
+### Tree-shaking
 
 You can import the functions you need directly from the module instead of using the `HfInference` class.
 
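For context, the snippet this README hunk edits continues into a full provider call. A sketch of the complete flow (model id and prompt are illustrative):

```ts
import { HfInference } from "@huggingface/inference";

// Either a HF access token, or an API key from the third-party provider (Replicate here).
const accessToken = "hf_...";

const client = new HfInference(accessToken);

// The model must be one of the ids mapped in src/providers/replicate.ts.
const image = await client.textToImage({
  provider: "replicate",
  model: "black-forest-labs/FLUX.1-schnell",
  inputs: "An astronaut riding a horse",
});
// `image` is a Blob that can be written to disk or rendered.
```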
package/dist/index.cjs CHANGED
@@ -107,10 +107,22 @@ var FAL_AI_API_BASE_URL = "https://fal.run";
 var FAL_AI_SUPPORTED_MODEL_IDS = {
   "text-to-image": {
     "black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
-    "black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev"
+    "black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
+    "playgroundai/playground-v2.5-1024px-aesthetic": "fal-ai/playground-v25",
+    "ByteDance/SDXL-Lightning": "fal-ai/lightning-models",
+    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "fal-ai/pixart-sigma",
+    "stabilityai/stable-diffusion-3-medium": "fal-ai/stable-diffusion-v3-medium",
+    "Warlord-K/Sana-1024": "fal-ai/sana",
+    "fal/AuraFlow-v0.2": "fal-ai/aura-flow",
+    "stabilityai/stable-diffusion-3.5-large": "fal-ai/stable-diffusion-v35-large",
+    "Kwai-Kolors/Kolors": "fal-ai/kolors"
   },
   "automatic-speech-recognition": {
     "openai/whisper-large-v3": "fal-ai/whisper"
+  },
+  "text-to-video": {
+    "genmo/mochi-1-preview": "fal-ai/mochi-v1",
+    "tencent/HunyuanVideo": "fal-ai/hunyuan-video"
   }
 };
 
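The new `"text-to-video"` block pairs with the textToVideo task added in this release (src/tasks/cv/textToVideo.ts, a new file per the list above). A hedged usage sketch, assuming the task mirrors textToImage's call shape:

```ts
import { textToVideo } from "@huggingface/inference";

// Sketch only: text-to-video is only mapped for fal-ai and replicate,
// so a provider must be specified.
const video = await textToVideo({
  accessToken: "hf_...",
  provider: "fal-ai",
  model: "genmo/mochi-1-preview", // resolved to fal-ai/mochi-v1 via the map above
  inputs: "A timelapse of a city skyline at dusk",
});
// Expected to resolve to a Blob containing the generated video.
```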
@@ -120,10 +132,13 @@ var REPLICATE_SUPPORTED_MODEL_IDS = {
   "text-to-image": {
     "black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
     "ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637"
+  },
+  "text-to-speech": {
+    "OuteAI/OuteTTS-0.3-500M": "jbilcke/oute-tts:39a59319327b27327fa3095149c5a746e7f2aee18c75055c3368237a6503cd26"
+  },
+  "text-to-video": {
+    "genmo/mochi-1-preview": "genmoai/mochi-1:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460"
   }
-  // "text-to-speech": {
-  //   "SWivid/F5-TTS": "x-lance/f5-tts:87faf6dd7a692dd82043f662e76369cab126a2cf1937e25a9d41e0b834fd230e"
-  // },
 };
 
 // src/providers/sambanova.ts
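Replicate ids can carry a `:<version-hash>` suffix pinning a model version; `makeRequestOptions` (see the hunk at `-263` below) splits it off and sends it as a top-level `version` field. A hypothetical helper to illustrate the id format, not part of the package:

```ts
// Illustration only: the real lookup and split live inside makeRequestOptions.
function resolveReplicateModel(mapped: string): { model: string; version?: string } {
  // "owner/model:versionhash" -> pinned version; "owner/model" -> latest
  if (mapped.includes(":")) {
    const [model, version] = mapped.split(":");
    return { model, version };
  }
  return { model: mapped };
}

resolveReplicateModel("genmoai/mochi-1:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460");
// -> { model: "genmoai/mochi-1", version: "1944af04..." }
```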
@@ -159,6 +174,8 @@ var TOGETHER_SUPPORTED_MODEL_IDS = {
   },
   conversational: {
     "databricks/dbrx-instruct": "databricks/dbrx-instruct",
+    "deepseek-ai/DeepSeek-R1": "deepseek-ai/DeepSeek-R1",
+    "deepseek-ai/DeepSeek-V3": "deepseek-ai/DeepSeek-V3",
     "deepseek-ai/deepseek-llm-67b-chat": "deepseek-ai/deepseek-llm-67b-chat",
     "google/gemma-2-9b-it": "google/gemma-2-9b-it",
     "google/gemma-2b-it": "google/gemma-2-27b-it",
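DeepSeek-R1 and DeepSeek-V3 thus become routable through Together's OpenAI-compatible chat API. A hedged sketch (token and prompt are placeholders):

```ts
import { HfInference } from "@huggingface/inference";

const client = new HfInference("hf_...");

const out = await client.chatCompletion({
  provider: "together",
  model: "deepseek-ai/DeepSeek-R1", // newly mapped above
  messages: [{ role: "user", content: "Explain tree-shaking in one sentence." }],
  max_tokens: 256,
});
console.log(out.choices[0].message.content);
```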
@@ -204,7 +221,8 @@ function isUrl(modelOrUrl) {
 var HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_HUB_URL}/api/inference-proxy/{{PROVIDER}}`;
 var tasks = null;
 async function makeRequestOptions(args, options) {
-  const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...otherArgs } = args;
+  const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
+  let otherArgs = remainingArgs;
   const provider = maybeProvider ?? "hf-inference";
   const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion: chatCompletion2 } = options ?? {};
   if (endpointUrl && provider !== "hf-inference") {
@@ -263,9 +281,9 @@ async function makeRequestOptions(args, options) {
   } else if (includeCredentials === true) {
     credentials = "include";
   }
-  if (provider === "replicate" && model.includes(":")) {
-    const version = model.split(":")[1];
-    otherArgs.version = version;
+  if (provider === "replicate") {
+    const version = model.includes(":") ? model.split(":")[1] : void 0;
+    otherArgs = { input: otherArgs, version };
   }
   const info = {
     headers,
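The practical effect of this hunk: every Replicate request body is now wrapped in Replicate's prediction format, `{ input, version }`, instead of only gaining a `version` key when the model id was pinned. Roughly, for a text-to-image call:

```ts
// What gets POSTed to Replicate in 3.1.0 (values illustrative):
const body = {
  input: { prompt: "An astronaut riding a horse" },
  // undefined when the mapped id carries no ":version" suffix
  version: "5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637",
};
```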
@@ -585,9 +603,42 @@ var InferenceOutputError = class extends TypeError {
   }
 };
 
+// src/utils/pick.ts
+function pick(o, props) {
+  return Object.assign(
+    {},
+    ...props.map((prop) => {
+      if (o[prop] !== void 0) {
+        return { [prop]: o[prop] };
+      }
+    })
+  );
+}
+
+// src/utils/typedInclude.ts
+function typedInclude(arr, v) {
+  return arr.includes(v);
+}
+
+// src/utils/omit.ts
+function omit(o, props) {
+  const propsArr = Array.isArray(props) ? props : [props];
+  const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
+  return pick(o, letsKeep);
+}
+
+// src/tasks/audio/utils.ts
+function preparePayload(args) {
+  return "data" in args ? args : {
+    ...omit(args, "inputs"),
+    data: args.inputs
+  };
+}
+
 // src/tasks/audio/audioClassification.ts
 async function audioClassification(args, options) {
-  const res = await request(args, {
+  const payload = preparePayload(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "audio-classification"
   });
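These helpers let the audio (and, below, CV) tasks accept `inputs` as the new name for the legacy `data` field without mutating the caller's object. What preparePayload does, spelled out (model id illustrative; `omit` is the helper defined above):

```ts
declare const audioBlob: Blob;

const args = { model: "superb/hubert-large-superb-er", inputs: audioBlob };
const payload = "data" in args ? args : { ...omit(args, "inputs"), data: args.inputs };
// -> { model: "superb/hubert-large-superb-er", data: audioBlob }
```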
@@ -613,15 +664,8 @@ function base64FromBytes(arr) {
 
 // src/tasks/audio/automaticSpeechRecognition.ts
 async function automaticSpeechRecognition(args, options) {
-  if (args.provider === "fal-ai") {
-    const contentType = args.data instanceof Blob ? args.data.type : "audio/mpeg";
-    const base64audio = base64FromBytes(
-      new Uint8Array(args.data instanceof ArrayBuffer ? args.data : await args.data.arrayBuffer())
-    );
-    args.audio_url = `data:${contentType};base64,${base64audio}`;
-    delete args.data;
-  }
-  const res = await request(args, {
+  const payload = await buildPayload(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "automatic-speech-recognition"
   });
@@ -631,6 +675,32 @@ async function automaticSpeechRecognition(args, options) {
   }
   return res;
 }
+var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
+async function buildPayload(args) {
+  if (args.provider === "fal-ai") {
+    const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : void 0;
+    const contentType = blob?.type;
+    if (!contentType) {
+      throw new Error(
+        `Unable to determine the input's content-type. Make sure your are passing a Blob when using provider fal-ai.`
+      );
+    }
+    if (!FAL_AI_SUPPORTED_BLOB_TYPES.includes(contentType)) {
+      throw new Error(
+        `Provider fal-ai does not support blob type ${contentType} - supported content types are: ${FAL_AI_SUPPORTED_BLOB_TYPES.join(
+          ", "
+        )}`
+      );
+    }
+    const base64audio = base64FromBytes(new Uint8Array(await blob.arrayBuffer()));
+    return {
+      ..."data" in args ? omit(args, "data") : omit(args, "inputs"),
+      audio_url: `data:${contentType};base64,${base64audio}`
+    };
+  } else {
+    return preparePayload(args);
+  }
+}
 
 // src/tasks/audio/textToSpeech.ts
 async function textToSpeech(args, options) {
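With buildPayload, fal-ai speech recognition now requires a Blob whose MIME type is one of FAL_AI_SUPPORTED_BLOB_TYPES; an ArrayBuffer no longer silently defaults to audio/mpeg. A hedged usage sketch:

```ts
import { automaticSpeechRecognition } from "@huggingface/inference";

// The Blob's type must be audio/mpeg, audio/mp4, audio/wav or audio/x-wav,
// otherwise buildPayload throws before any request is sent.
declare const wavBytes: ArrayBuffer;
const audio = new Blob([wavBytes], { type: "audio/wav" });

const { text } = await automaticSpeechRecognition({
  accessToken: "hf_...",
  provider: "fal-ai",
  model: "openai/whisper-large-v3", // mapped to fal-ai/whisper
  inputs: audio, // the legacy { data: audio } shape is also accepted
});
```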
@@ -638,31 +708,55 @@ async function textToSpeech(args, options) {
     ...options,
     taskHint: "text-to-speech"
   });
-  const isValidOutput = res && res instanceof Blob;
-  if (!isValidOutput) {
-    throw new InferenceOutputError("Expected Blob");
+  if (res instanceof Blob) {
+    return res;
   }
-  return res;
+  if (res && typeof res === "object") {
+    if ("output" in res) {
+      if (typeof res.output === "string") {
+        const urlResponse = await fetch(res.output);
+        const blob = await urlResponse.blob();
+        return blob;
+      } else if (Array.isArray(res.output)) {
+        const urlResponse = await fetch(res.output[0]);
+        const blob = await urlResponse.blob();
+        return blob;
+      }
+    }
+  }
+  throw new InferenceOutputError("Expected Blob or object with output");
 }
 
 // src/tasks/audio/audioToAudio.ts
 async function audioToAudio(args, options) {
-  const res = await request(args, {
+  const payload = preparePayload(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "audio-to-audio"
   });
-  const isValidOutput = Array.isArray(res) && res.every(
-    (x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string"
-  );
-  if (!isValidOutput) {
-    throw new InferenceOutputError("Expected Array<{label: string, blob: string, content-type: string}>");
+  return validateOutput(res);
+}
+function validateOutput(output) {
+  if (!Array.isArray(output)) {
+    throw new InferenceOutputError("Expected Array");
   }
-  return res;
+  if (!output.every((elem) => {
+    return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
+  })) {
+    throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
+  }
+  return output;
+}
+
+// src/tasks/cv/utils.ts
+function preparePayload2(args) {
+  return "data" in args ? args : { ...omit(args, "inputs"), data: args.inputs };
 }
 
 // src/tasks/cv/imageClassification.ts
 async function imageClassification(args, options) {
-  const res = await request(args, {
+  const payload = preparePayload2(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "image-classification"
   });
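textToSpeech now normalizes Replicate-style responses (an `output` URL, or an array of URLs) into a Blob by fetching the URL, so callers receive a Blob from every provider. A hedged sketch:

```ts
import { textToSpeech } from "@huggingface/inference";

const speech: Blob = await textToSpeech({
  accessToken: "hf_...",
  provider: "replicate",
  model: "OuteAI/OuteTTS-0.3-500M", // mapped to jbilcke/oute-tts:<version> above
  inputs: "Hello world!",
});
```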
@@ -675,7 +769,8 @@ async function imageClassification(args, options) {
 
 // src/tasks/cv/imageSegmentation.ts
 async function imageSegmentation(args, options) {
-  const res = await request(args, {
+  const payload = preparePayload2(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "image-segmentation"
   });
@@ -688,7 +783,8 @@ async function imageSegmentation(args, options) {
 
 // src/tasks/cv/imageToText.ts
 async function imageToText(args, options) {
-  const res = (await request(args, {
+  const payload = preparePayload2(args);
+  const res = (await request(payload, {
     ...options,
     taskHint: "image-to-text"
   }))?.[0];
@@ -700,7 +796,8 @@ async function imageToText(args, options) {
 
 // src/tasks/cv/objectDetection.ts
 async function objectDetection(args, options) {
-  const res = await request(args, {
+  const payload = preparePayload2(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "object-detection"
   });
@@ -717,15 +814,13 @@ async function objectDetection(args, options) {
 
 // src/tasks/cv/textToImage.ts
 async function textToImage(args, options) {
-  if (args.provider === "together" || args.provider === "fal-ai") {
-    args.prompt = args.inputs;
-    args.inputs = "";
-    args.response_format = "base64";
-  } else if (args.provider === "replicate") {
-    args.input = { prompt: args.inputs };
-    delete args.inputs;
-  }
-  const res = await request(args, {
+  const payload = args.provider === "together" || args.provider === "fal-ai" || args.provider === "replicate" ? {
+    ...omit(args, ["inputs", "parameters"]),
+    ...args.parameters,
+    ...args.provider !== "replicate" ? { response_format: "base64" } : void 0,
+    prompt: args.inputs
+  } : args;
+  const res = await request(payload, {
     ...options,
     taskHint: "text-to-image"
   });
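For third-party providers, textToImage therefore builds a fresh payload instead of mutating `args`: `inputs` is renamed to `prompt`, `parameters` are flattened to the top level, and non-Replicate providers get `response_format: "base64"`. For example (values illustrative):

```ts
// Input:
const args = {
  provider: "together" as const,
  model: "black-forest-labs/FLUX.1-schnell",
  inputs: "A black forest cake",
  parameters: { num_inference_steps: 4 },
};

// Payload sent (Replicate would omit response_format, and makeRequestOptions
// would then wrap this into { input, version }):
const payload = {
  provider: "together",
  model: "black-forest-labs/FLUX.1-schnell",
  num_inference_steps: 4,
  response_format: "base64",
  prompt: "A black forest cake",
};
```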
@@ -782,18 +877,30 @@ async function imageToImage(args, options) {
 }
 
 // src/tasks/cv/zeroShotImageClassification.ts
-async function zeroShotImageClassification(args, options) {
-  const reqArgs = {
-    ...args,
-    inputs: {
-      image: base64FromBytes(
-        new Uint8Array(
-          args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
+async function preparePayload3(args) {
+  if (args.inputs instanceof Blob) {
+    return {
+      ...args,
+      inputs: {
+        image: base64FromBytes(new Uint8Array(await args.inputs.arrayBuffer()))
+      }
+    };
+  } else {
+    return {
+      ...args,
+      inputs: {
+        image: base64FromBytes(
+          new Uint8Array(
+            args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
+          )
         )
-      )
-    }
-  };
-  const res = await request(reqArgs, {
+      }
+    };
+  }
+}
+async function zeroShotImageClassification(args, options) {
+  const payload = await preparePayload3(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "zero-shot-image-classification"
   });
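The split into preparePayload3 means zeroShotImageClassification now also accepts the image directly as `inputs` (a Blob), alongside the legacy `{ image }` wrapper. A hedged sketch, assuming the standard `candidate_labels` parameter of this task (model and labels are illustrative):

```ts
import { zeroShotImageClassification } from "@huggingface/inference";

declare const photo: Blob;

const labels = await zeroShotImageClassification({
  accessToken: "hf_...",
  model: "openai/clip-vit-large-patch14",
  inputs: photo, // new in 3.1.0; { image: photo } still works
  parameters: { candidate_labels: ["cat", "dog"] },
});
```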
@@ -882,17 +989,19 @@ async function questionAnswering(args, options) {
     ...options,
     taskHint: "question-answering"
   });
-  const isValidOutput = typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";
+  const isValidOutput = Array.isArray(res) ? res.every(
+    (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
+  ) : typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";
   if (!isValidOutput) {
-    throw new InferenceOutputError("Expected {answer: string, end: number, score: number, start: number}");
+    throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
   }
-  return res;
+  return Array.isArray(res) ? res[0] : res;
 }
 
 // src/tasks/nlp/sentenceSimilarity.ts
 async function sentenceSimilarity(args, options) {
   const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : void 0;
-  const res = await request(args, {
+  const res = await request(prepareInput(args), {
     ...options,
     taskHint: "sentence-similarity",
     ...defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }
@@ -903,6 +1012,13 @@ async function sentenceSimilarity(args, options) {
   }
   return res;
 }
+function prepareInput(args) {
+  return {
+    ...omit(args, ["inputs", "parameters"]),
+    inputs: { ...omit(args.inputs, "sourceSentence") },
+    parameters: { source_sentence: args.inputs.sourceSentence, ...args.parameters }
+  };
+}
 
 // src/tasks/nlp/summarization.ts
 async function summarization(args, options) {
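prepareInput moves the camelCase `sourceSentence` out of `inputs` into `parameters.source_sentence`, the snake_case field the server-side task spec expects, while leaving the caller-facing signature unchanged. Illustration (values made up):

```ts
// Caller passes:
const args = {
  model: "sentence-transformers/all-MiniLM-L6-v2",
  inputs: {
    sourceSentence: "That is a happy person",
    sentences: ["That is a happy dog", "Today is a sunny day"],
  },
};

// prepareInput rewrites this to:
// {
//   model: "sentence-transformers/all-MiniLM-L6-v2",
//   inputs: { sentences: [...] },
//   parameters: { source_sentence: "That is a happy person" }
// }
```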
@@ -923,13 +1039,18 @@ async function tableQuestionAnswering(args, options) {
     ...options,
     taskHint: "table-question-answering"
   });
-  const isValidOutput = typeof res?.aggregator === "string" && typeof res.answer === "string" && Array.isArray(res.cells) && res.cells.every((x) => typeof x === "string") && Array.isArray(res.coordinates) && res.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number"));
+  const isValidOutput = Array.isArray(res) ? res.every((elem) => validate(elem)) : validate(res);
   if (!isValidOutput) {
     throw new InferenceOutputError(
       "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
     );
   }
-  return res;
+  return Array.isArray(res) ? res[0] : res;
+}
+function validate(elem) {
+  return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
+    (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
+  );
 }
 
 // src/tasks/nlp/textClassification.ts
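questionAnswering above, tableQuestionAnswering here, and the two multimodal tasks below all adopt the same normalization in this release: accept either a bare object or an array from the backend, validate every element, and return the first. The shared pattern as a generic sketch, not an exported helper:

```ts
function normalizeSingle<T>(res: T | T[], isValid: (x: unknown) => x is T): T {
  const ok = Array.isArray(res) ? res.every(isValid) : isValid(res);
  if (!ok) throw new TypeError("Invalid inference output");
  return Array.isArray(res) ? res[0] : res;
}
```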
@@ -1072,11 +1193,7 @@ async function documentQuestionAnswering(args, options) {
     inputs: {
       question: args.inputs.question,
       // convert Blob or ArrayBuffer to base64
-      image: base64FromBytes(
-        new Uint8Array(
-          args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
-        )
-      )
+      image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer()))
     }
   };
   const res = toArray(
@@ -1084,12 +1201,14 @@ async function documentQuestionAnswering(args, options) {
       ...options,
       taskHint: "document-question-answering"
     })
-  )?.[0];
-  const isValidOutput = typeof res?.answer === "string" && (typeof res.end === "number" || typeof res.end === "undefined") && (typeof res.score === "number" || typeof res.score === "undefined") && (typeof res.start === "number" || typeof res.start === "undefined");
+  );
+  const isValidOutput = Array.isArray(res) && res.every(
+    (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
+  );
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>");
   }
-  return res;
+  return res[0];
 }
 
 // src/tasks/multimodal/visualQuestionAnswering.ts
@@ -1099,22 +1218,20 @@ async function visualQuestionAnswering(args, options) {
     inputs: {
       question: args.inputs.question,
       // convert Blob or ArrayBuffer to base64
-      image: base64FromBytes(
-        new Uint8Array(
-          args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
-        )
-      )
+      image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer()))
     }
   };
-  const res = (await request(reqArgs, {
+  const res = await request(reqArgs, {
     ...options,
     taskHint: "visual-question-answering"
-  }))?.[0];
-  const isValidOutput = typeof res?.answer === "string" && typeof res.score === "number";
+  });
+  const isValidOutput = Array.isArray(res) && res.every(
+    (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
+  );
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
   }
-  return res;
+  return res[0];
 }
 
 // src/tasks/tabular/tabularRegression.ts
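One behavioral note shared by both multimodal hunks: `args.inputs.image.arrayBuffer()` is now called unconditionally, so the image must be a Blob (the 3.0.0 ArrayBuffer path is gone, despite the surviving comment). A hedged usage sketch (model id illustrative):

```ts
import { visualQuestionAnswering } from "@huggingface/inference";

declare const photo: Blob; // must be a Blob in 3.1.0

const { answer, score } = await visualQuestionAnswering({
  accessToken: "hf_...",
  model: "dandelin/vilt-b32-finetuned-vqa",
  inputs: { question: "How many cats are in the picture?", image: photo },
});
```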