@huggingface/inference 1.6.2 → 1.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,8 @@
1
1
  import { toArray } from "./utils/to-array";
2
+ import type { EventSourceMessage } from "./vendor/fetch-event-source/parse";
3
+ import { getLines, getMessages } from "./vendor/fetch-event-source/parse";
4
+
5
+ const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co/models/";
2
6
 
3
7
  export interface Options {
4
8
  /**
@@ -223,6 +227,86 @@ export interface TextGenerationReturn {
223
227
  generated_text: string;
224
228
  }
225
229
 
230
+ export interface TextGenerationStreamToken {
231
+ /** Token ID from the model tokenizer */
232
+ id: number;
233
+ /** Token text */
234
+ text: string;
235
+ /** Logprob */
236
+ logprob: number;
237
+ /**
238
+ * Is the token a special token
239
+ * Can be used to ignore tokens when concatenating
240
+ */
241
+ special: boolean;
242
+ }
243
+
244
+ export interface TextGenerationStreamPrefillToken {
245
+ /** Token ID from the model tokenizer */
246
+ id: number;
247
+ /** Token text */
248
+ text: string;
249
+ /**
250
+ * Logprob
251
+ * Optional since the logprob of the first token cannot be computed
252
+ */
253
+ logprob?: number;
254
+ }
255
+
256
+ export interface TextGenerationStreamBestOfSequence {
257
+ /** Generated text */
258
+ generated_text: string;
259
+ /** Generation finish reason */
260
+ finish_reason: TextGenerationStreamFinishReason;
261
+ /** Number of generated tokens */
262
+ generated_tokens: number;
263
+ /** Sampling seed if sampling was activated */
264
+ seed?: number;
265
+ /** Prompt tokens */
266
+ prefill: TextGenerationStreamPrefillToken[];
267
+ /** Generated tokens */
268
+ tokens: TextGenerationStreamToken[];
269
+ }
270
+
271
+ export enum TextGenerationStreamFinishReason {
272
+ /** number of generated tokens == `max_new_tokens` */
273
+ Length = "length",
274
+ /** the model generated its end of sequence token */
275
+ EndOfSequenceToken = "eos_token",
276
+ /** the model generated a text included in `stop_sequences` */
277
+ StopSequence = "stop_sequence",
278
+ }
279
+
280
+ export interface TextGenerationStreamDetails {
281
+ /** Generation finish reason */
282
+ finish_reason: TextGenerationStreamFinishReason;
283
+ /** Number of generated tokens */
284
+ generated_tokens: number;
285
+ /** Sampling seed if sampling was activated */
286
+ seed?: number;
287
+ /** Prompt tokens */
288
+ prefill: TextGenerationStreamPrefillToken[];
289
+ /** Generated tokens */
290
+ tokens: TextGenerationStreamToken[];
291
+ /** Additional sequences when using the `best_of` parameter */
292
+ best_of_sequences?: TextGenerationStreamBestOfSequence[];
293
+ }
294
+
295
+ export interface TextGenerationStreamReturn {
296
+ /** Generated token, one at a time */
297
+ token: TextGenerationStreamToken;
298
+ /**
299
+ * Complete generated text
300
+ * Only available when the generation is finished
301
+ */
302
+ generated_text?: string;
303
+ /**
304
+ * Generation details
305
+ * Only available when the generation is finished
306
+ */
307
+ details?: TextGenerationStreamDetails;
308
+ }
309
+
226
310
  export type TokenClassificationArgs = Args & {
227
311
  /**
228
312
  * A string to be classified
@@ -519,21 +603,52 @@ export class HfInference {
519
603
  * Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
520
604
  */
521
605
  public async fillMask(args: FillMaskArgs, options?: Options): Promise<FillMaskReturn> {
522
- return this.request(args, options);
606
+ const res = await this.request<FillMaskReturn>(args, options);
607
+ const isValidOutput =
608
+ Array.isArray(res) &&
609
+ res.every(
610
+ (x) =>
611
+ typeof x.score === "number" &&
612
+ typeof x.sequence === "string" &&
613
+ typeof x.token === "number" &&
614
+ typeof x.token_str === "string"
615
+ );
616
+ if (!isValidOutput) {
617
+ throw new TypeError(
618
+ "Invalid inference output: output must be of type Array<score: number, sequence:string, token:number, token_str:string>"
619
+ );
620
+ }
621
+ return res;
523
622
  }
524
623
 
525
624
  /**
526
625
  * This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.
527
626
  */
528
627
  public async summarization(args: SummarizationArgs, options?: Options): Promise<SummarizationReturn> {
529
- return (await this.request<SummarizationReturn[]>(args, options))?.[0];
628
+ const res = await this.request<SummarizationReturn[]>(args, options);
629
+ const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.summary_text === "string");
630
+ if (!isValidOutput) {
631
+ throw new TypeError("Invalid inference output: output must be of type Array<summary_text: string>");
632
+ }
633
+ return res?.[0];
530
634
  }
531
635
 
532
636
  /**
533
637
  * Want to have a nice know-it-all bot that can answer any question? Recommended model: deepset/roberta-base-squad2
534
638
  */
535
639
  public async questionAnswer(args: QuestionAnswerArgs, options?: Options): Promise<QuestionAnswerReturn> {
536
- return await this.request(args, options);
640
+ const res = await this.request<QuestionAnswerReturn>(args, options);
641
+ const isValidOutput =
642
+ typeof res.answer === "string" &&
643
+ typeof res.end === "number" &&
644
+ typeof res.score === "number" &&
645
+ typeof res.start === "number";
646
+ if (!isValidOutput) {
647
+ throw new TypeError(
648
+ "Invalid inference output: output must be of type <answer: string, end: number, score: number, start: number>"
649
+ );
650
+ }
651
+ return res;
537
652
  }
538
653
 
539
654
  /**
@@ -543,21 +658,55 @@ export class HfInference {
543
658
  args: TableQuestionAnswerArgs,
544
659
  options?: Options
545
660
  ): Promise<TableQuestionAnswerReturn> {
546
- return await this.request(args, options);
661
+ const res = await this.request<TableQuestionAnswerReturn>(args, options);
662
+ const isValidOutput =
663
+ typeof res.aggregator === "string" &&
664
+ typeof res.answer === "string" &&
665
+ Array.isArray(res.cells) &&
666
+ res.cells.every((x) => typeof x === "string") &&
667
+ Array.isArray(res.coordinates) &&
668
+ res.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number"));
669
+ if (!isValidOutput) {
670
+ throw new TypeError(
671
+ "Invalid inference output: output must be of type <aggregator: string, answer: string, cells: string[], coordinates: number[][]>"
672
+ );
673
+ }
674
+ return res;
547
675
  }
548
676
 
549
677
  /**
550
678
  * Usually used for sentiment-analysis this will output the likelihood of classes of an input. Recommended model: distilbert-base-uncased-finetuned-sst-2-english
551
679
  */
552
680
  public async textClassification(args: TextClassificationArgs, options?: Options): Promise<TextClassificationReturn> {
553
- return (await this.request<TextClassificationReturn[]>(args, options))?.[0];
681
+ const res = (await this.request<TextClassificationReturn[]>(args, options))?.[0];
682
+ const isValidOutput =
683
+ Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
684
+ if (!isValidOutput) {
685
+ throw new TypeError("Invalid inference output: output must be of type Array<label: string, score: number>");
686
+ }
687
+ return res;
554
688
  }
555
689
 
556
690
  /**
557
691
  * Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
558
692
  */
559
693
  public async textGeneration(args: TextGenerationArgs, options?: Options): Promise<TextGenerationReturn> {
560
- return (await this.request<TextGenerationReturn[]>(args, options))?.[0];
694
+ const res = await this.request<TextGenerationReturn[]>(args, options);
695
+ const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.generated_text === "string");
696
+ if (!isValidOutput) {
697
+ throw new TypeError("Invalid inference output: output must be of type Array<generated_text: string>");
698
+ }
699
+ return res?.[0];
700
+ }
701
+
702
+ /**
703
+ * Use to continue text from a prompt. Same as `textGeneration`, but returns an async generator that yields the output one token at a time
704
+ */
705
+ public async *textGenerationStream(
706
+ args: TextGenerationArgs,
707
+ options?: Options
708
+ ): AsyncGenerator<TextGenerationStreamReturn> {
709
+ yield* this.streamingRequest<TextGenerationStreamReturn>(args, options);
561
710
  }
562
711
 
563
712
  /**
@@ -567,14 +716,35 @@ export class HfInference {
567
716
  args: TokenClassificationArgs,
568
717
  options?: Options
569
718
  ): Promise<TokenClassificationReturn> {
570
- return toArray(await this.request(args, options));
719
+ const res = toArray(await this.request<TokenClassificationReturnValue | TokenClassificationReturn>(args, options));
720
+ const isValidOutput =
721
+ Array.isArray(res) &&
722
+ res.every(
723
+ (x) =>
724
+ typeof x.end === "number" &&
725
+ typeof x.entity_group === "string" &&
726
+ typeof x.score === "number" &&
727
+ typeof x.start === "number" &&
728
+ typeof x.word === "string"
729
+ );
730
+ if (!isValidOutput) {
731
+ throw new TypeError(
732
+ "Invalid inference output: output must be of type Array<end: number, entity_group: string, score: number, start: number, word: string>"
733
+ );
734
+ }
735
+ return res;
571
736
  }
572
737
 
573
738
  /**
574
739
  * This task is well known to translate text from one language to another. Recommended model: Helsinki-NLP/opus-mt-ru-en.
575
740
  */
576
741
  public async translation(args: TranslationArgs, options?: Options): Promise<TranslationReturn> {
577
- return (await this.request<TranslationReturn[]>(args, options))?.[0];
742
+ const res = await this.request<TranslationReturn[]>(args, options);
743
+ const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.translation_text === "string");
744
+ if (!isValidOutput) {
745
+ throw new TypeError("Invalid inference output: output must be of type Array<translation_text: string>");
746
+ }
747
+ return res?.[0];
578
748
  }
579
749
 
580
750
  /**
@@ -584,9 +754,25 @@ export class HfInference {
584
754
  args: ZeroShotClassificationArgs,
585
755
  options?: Options
586
756
  ): Promise<ZeroShotClassificationReturn> {
587
- return toArray(
588
- await this.request<ZeroShotClassificationReturnValue | ZeroShotClassificationReturnValue[]>(args, options)
757
+ const res = toArray(
758
+ await this.request<ZeroShotClassificationReturnValue | ZeroShotClassificationReturn>(args, options)
589
759
  );
760
+ const isValidOutput =
761
+ Array.isArray(res) &&
762
+ res.every(
763
+ (x) =>
764
+ Array.isArray(x.labels) &&
765
+ x.labels.every((_label) => typeof _label === "string") &&
766
+ Array.isArray(x.scores) &&
767
+ x.scores.every((_score) => typeof _score === "number") &&
768
+ typeof x.sequence === "string"
769
+ );
770
+ if (!isValidOutput) {
771
+ throw new TypeError(
772
+ "Invalid inference output: output must be of type Array<labels: string[], scores: number[], sequence: string>"
773
+ );
774
+ }
775
+ return res;
590
776
  }
591
777
 
592
778
  /**
@@ -594,14 +780,29 @@ export class HfInference {
594
780
  *
595
781
  */
596
782
  public async conversational(args: ConversationalArgs, options?: Options): Promise<ConversationalReturn> {
597
- return await this.request(args, options);
783
+ const res = await this.request<ConversationalReturn>(args, options);
784
+ const isValidOutput =
785
+ Array.isArray(res.conversation.generated_responses) &&
786
+ res.conversation.generated_responses.every((x) => typeof x === "string") &&
787
+ Array.isArray(res.conversation.past_user_inputs) &&
788
+ res.conversation.past_user_inputs.every((x) => typeof x === "string") &&
789
+ typeof res.generated_text === "string" &&
790
+ Array.isArray(res.warnings) &&
791
+ res.warnings.every((x) => typeof x === "string");
792
+ if (!isValidOutput) {
793
+ throw new TypeError(
794
+ "Invalid inference output: output must be of type <conversation: {generated_responses: string[], past_user_inputs: string[]}, generated_text: string, warnings: string[]>"
795
+ );
796
+ }
797
+ return res;
598
798
  }
599
799
 
600
800
  /**
601
801
  * This task reads some text and outputs raw float values, that are usually consumed as part of a semantic database/semantic search.
602
802
  */
603
803
  public async featureExtraction(args: FeatureExtractionArgs, options?: Options): Promise<FeatureExtractionReturn> {
604
- return await this.request(args, options);
804
+ const res = await this.request<FeatureExtractionReturn>(args, options);
805
+ return res;
605
806
  }
606
807
 
607
808
  /**
@@ -612,10 +813,15 @@ export class HfInference {
612
813
  args: AutomaticSpeechRecognitionArgs,
613
814
  options?: Options
614
815
  ): Promise<AutomaticSpeechRecognitionReturn> {
615
- return await this.request(args, {
816
+ const res = await this.request<AutomaticSpeechRecognitionReturn>(args, {
616
817
  ...options,
617
818
  binary: true,
618
819
  });
820
+ const isValidOutput = typeof res.text === "string";
821
+ if (!isValidOutput) {
822
+ throw new TypeError("Invalid inference output: output must be of type <text: string>");
823
+ }
824
+ return res;
619
825
  }
620
826
 
621
827
  /**
@@ -626,10 +832,16 @@ export class HfInference {
626
832
  args: AudioClassificationArgs,
627
833
  options?: Options
628
834
  ): Promise<AudioClassificationReturn> {
629
- return await this.request(args, {
835
+ const res = await this.request<AudioClassificationReturn>(args, {
630
836
  ...options,
631
837
  binary: true,
632
838
  });
839
+ const isValidOutput =
840
+ Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
841
+ if (!isValidOutput) {
842
+ throw new TypeError("Invalid inference output: output must be of type Array<label: string, score: number>");
843
+ }
844
+ return res;
633
845
  }
634
846
 
635
847
  /**
@@ -640,10 +852,16 @@ export class HfInference {
640
852
  args: ImageClassificationArgs,
641
853
  options?: Options
642
854
  ): Promise<ImageClassificationReturn> {
643
- return await this.request(args, {
855
+ const res = await this.request<ImageClassificationReturn>(args, {
644
856
  ...options,
645
857
  binary: true,
646
858
  });
859
+ const isValidOutput =
860
+ Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
861
+ if (!isValidOutput) {
862
+ throw new TypeError("Invalid inference output: output must be of type Array<label: string, score: number>");
863
+ }
864
+ return res;
647
865
  }
648
866
 
649
867
  /**
@@ -651,10 +869,27 @@ export class HfInference {
651
869
  * Recommended model: facebook/detr-resnet-50
652
870
  */
653
871
  public async objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionReturn> {
654
- return await this.request(args, {
872
+ const res = await this.request<ObjectDetectionReturn>(args, {
655
873
  ...options,
656
874
  binary: true,
657
875
  });
876
+ const isValidOutput =
877
+ Array.isArray(res) &&
878
+ res.every(
879
+ (x) =>
880
+ typeof x.label === "string" &&
881
+ typeof x.score === "number" &&
882
+ typeof x.box.xmin === "number" &&
883
+ typeof x.box.ymin === "number" &&
884
+ typeof x.box.xmax === "number" &&
885
+ typeof x.box.ymax === "number"
886
+ );
887
+ if (!isValidOutput) {
888
+ throw new TypeError(
889
+ "Invalid inference output: output must be of type Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>"
890
+ );
891
+ }
892
+ return res;
658
893
  }
659
894
 
660
895
  /**
@@ -662,10 +897,19 @@ export class HfInference {
662
897
  * Recommended model: facebook/detr-resnet-50-panoptic
663
898
  */
664
899
  public async imageSegmentation(args: ImageSegmentationArgs, options?: Options): Promise<ImageSegmentationReturn> {
665
- return await this.request(args, {
900
+ const res = await this.request<ImageSegmentationReturn>(args, {
666
901
  ...options,
667
902
  binary: true,
668
903
  });
904
+ const isValidOutput =
905
+ Array.isArray(res) &&
906
+ res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
907
+ if (!isValidOutput) {
908
+ throw new TypeError(
909
+ "Invalid inference output: output must be of type Array<label: string, mask: string, score: number>"
910
+ );
911
+ }
912
+ return res;
669
913
  }
670
914
 
671
915
  /**
@@ -673,21 +917,32 @@ export class HfInference {
673
917
  * Recommended model: stabilityai/stable-diffusion-2
674
918
  */
675
919
  public async textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageReturn> {
676
- return await this.request(args, {
920
+ const res = await this.request<TextToImageReturn>(args, {
677
921
  ...options,
678
922
  blob: true,
679
923
  });
924
+ const isValidOutput = res && res instanceof Blob;
925
+ if (!isValidOutput) {
926
+ throw new TypeError("Invalid inference output: output must be of type object & of instance Blob");
927
+ }
928
+ return res;
680
929
  }
681
930
 
682
- public async request<T>(
683
- args: Args & { data?: Blob | ArrayBuffer },
931
+ /**
932
+ * Helper that prepares request arguments
933
+ */
934
+ private makeRequestOptions(
935
+ args: Args & {
936
+ data?: Blob | ArrayBuffer;
937
+ stream?: boolean;
938
+ },
684
939
  options?: Options & {
685
940
  binary?: boolean;
686
941
  blob?: boolean;
687
942
  /** For internal HF use, which is why it's not exposed in {@link Options} */
688
943
  includeCredentials?: boolean;
689
944
  }
690
- ): Promise<T> {
945
+ ) {
691
946
  const mergedOptions = { ...this.defaultOptions, ...options };
692
947
  const { model, ...otherArgs } = args;
693
948
 
@@ -712,7 +967,8 @@ export class HfInference {
712
967
  }
713
968
  }
714
969
 
715
- const response = await fetch(`https://api-inference.huggingface.co/models/${model}`, {
970
+ const url = `${HF_INFERENCE_API_BASE_URL}${model}`;
971
+ const info: RequestInit = {
716
972
  headers,
717
973
  method: "POST",
718
974
  body: options?.binary
@@ -722,7 +978,22 @@ export class HfInference {
722
978
  options: mergedOptions,
723
979
  }),
724
980
  credentials: options?.includeCredentials ? "include" : "same-origin",
725
- });
981
+ };
982
+
983
+ return { url, info, mergedOptions };
984
+ }
985
+
986
+ public async request<T>(
987
+ args: Args & { data?: Blob | ArrayBuffer },
988
+ options?: Options & {
989
+ binary?: boolean;
990
+ blob?: boolean;
991
+ /** For internal HF use, which is why it's not exposed in {@link Options} */
992
+ includeCredentials?: boolean;
993
+ }
994
+ ): Promise<T> {
995
+ const { url, info, mergedOptions } = this.makeRequestOptions(args, options);
996
+ const response = await fetch(url, info);
726
997
 
727
998
  if (mergedOptions.retry_on_error !== false && response.status === 503 && !mergedOptions.wait_for_model) {
728
999
  return this.request(args, {
@@ -744,4 +1015,65 @@ export class HfInference {
744
1015
  }
745
1016
  return output;
746
1017
  }
1018
+
1019
+ /**
1020
+ * Make a request that uses server-sent events and yield the JSON-parsed `data` of each event through an async generator; throws if the response is not OK or is not an event stream
1021
+ */
1022
+ public async *streamingRequest<T>(
1023
+ args: Args & { data?: Blob | ArrayBuffer },
1024
+ options?: Options & {
1025
+ binary?: boolean;
1026
+ blob?: boolean;
1027
+ /** For internal HF use, which is why it's not exposed in {@link Options} */
1028
+ includeCredentials?: boolean;
1029
+ }
1030
+ ): AsyncGenerator<T> {
1031
+ const { url, info, mergedOptions } = this.makeRequestOptions({ ...args, stream: true }, options);
1032
+ const response = await fetch(url, info);
1033
+ // On a 503 (retries enabled, wait_for_model not yet set) retry with wait_for_model: true; `yield*` is required so the retried stream's events actually reach the caller
1034
+ if (mergedOptions.retry_on_error !== false && response.status === 503 && !mergedOptions.wait_for_model) {
1035
+ return yield* this.streamingRequest<T>(args, {
1036
+ ...mergedOptions,
1037
+ wait_for_model: true,
1038
+ });
1039
+ }
1040
+ if (!response.ok) {
1041
+ throw new Error(`Server response contains error: ${response.status}`);
1042
+ }
1043
+ if (!response.headers.get("content-type")?.startsWith("text/event-stream")) { // header may carry parameters such as "; charset=utf-8"
1044
+ throw new Error(`Server does not support event stream content type`);
1045
+ }
1046
+ // Events produced by the parser callbacks are buffered in `events` and drained after every chunk read
1047
+ const reader = response.body.getReader();
1048
+ const events: EventSourceMessage[] = [];
1049
+
1050
+ const onEvent = (event: EventSourceMessage) => {
1051
+ // accumulate events in array
1052
+ events.push(event);
1053
+ };
1054
+
1055
+ const onChunk = getLines(
1056
+ getMessages(
1057
+ () => {},
1058
+ () => {},
1059
+ onEvent
1060
+ )
1061
+ );
1062
+ // Feed each chunk to the parser, then yield the parsed payload of every buffered event whose data is non-empty
1063
+ try {
1064
+ while (true) {
1065
+ const { done, value } = await reader.read();
1066
+ if (done) return;
1067
+ onChunk(value);
1068
+ while (events.length > 0) {
1069
+ const event = events.shift();
1070
+ if (event.data.length > 0) {
1071
+ yield JSON.parse(event.data) as T;
1072
+ }
1073
+ }
1074
+ }
1075
+ } finally {
1076
+ reader.releaseLock();
1077
+ }
1078
+ }
747
1079
  }