@huggingface/inference 1.6.2 → 1.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +7 -0
- package/dist/index.d.ts +94 -1
- package/dist/index.js +299 -22
- package/dist/index.mjs +297 -21
- package/package.json +1 -1
- package/src/HfInference.ts +355 -23
- package/src/vendor/fetch-event-source/parse.spec.ts +389 -0
- package/src/vendor/fetch-event-source/parse.ts +216 -0
package/src/HfInference.ts
CHANGED
|
@@ -1,4 +1,8 @@
|
|
|
1
1
|
import { toArray } from "./utils/to-array";
|
|
2
|
+
import type { EventSourceMessage } from "./vendor/fetch-event-source/parse";
|
|
3
|
+
import { getLines, getMessages } from "./vendor/fetch-event-source/parse";
|
|
4
|
+
|
|
5
|
+
/** Base URL of the hosted Inference API; the model id is appended to it to build the request URL. */
const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co/models/";
|
|
2
6
|
|
|
3
7
|
export interface Options {
|
|
4
8
|
/**
|
|
@@ -223,6 +227,86 @@ export interface TextGenerationReturn {
|
|
|
223
227
|
generated_text: string;
|
|
224
228
|
}
|
|
225
229
|
|
|
230
|
+
/** One generated token, as carried by `TextGenerationStreamReturn.token`. */
export interface TextGenerationStreamToken {
	/** ID of the token in the model tokenizer */
	id: number;
	/** Text of the token */
	text: string;
	/** Logprob of the token */
	logprob: number;
	/** Whether the token is a special token; can be used to ignore tokens when concatenating */
	special: boolean;
}
|
|
243
|
+
|
|
244
|
+
/** One prompt ("prefill") token, as listed in the `prefill` arrays of the stream details. */
export interface TextGenerationStreamPrefillToken {
	/** ID of the token in the model tokenizer */
	id: number;
	/** Text of the token */
	text: string;
	/** Logprob of the token; optional since the logprob of the first token cannot be computed */
	logprob?: number;
}
|
|
255
|
+
|
|
256
|
+
/** One additional sequence, as listed in `TextGenerationStreamDetails.best_of_sequences`. */
export interface TextGenerationStreamBestOfSequence {
	/** Text generated for this sequence */
	generated_text: string;
	/** Reason the generation finished */
	finish_reason: TextGenerationStreamFinishReason;
	/** How many tokens were generated */
	generated_tokens: number;
	/** Sampling seed, present if sampling was activated */
	seed?: number;
	/** Tokens of the prompt */
	prefill: TextGenerationStreamPrefillToken[];
	/** Tokens that were generated */
	tokens: TextGenerationStreamToken[];
}
|
|
270
|
+
|
|
271
|
+
/** Reasons a generation can finish, as reported in `finish_reason`. */
export enum TextGenerationStreamFinishReason {
	/** The number of generated tokens reached `max_new_tokens` */
	Length = "length",
	/** The model produced its end of sequence token */
	EndOfSequenceToken = "eos_token",
	/** The model produced a text included in `stop_sequences` */
	StopSequence = "stop_sequence",
}
|
|
279
|
+
|
|
280
|
+
/** Generation details, as carried by `TextGenerationStreamReturn.details`. */
export interface TextGenerationStreamDetails {
	/** Generation finish reason */
	finish_reason: TextGenerationStreamFinishReason;
	/** Number of generated tokens */
	generated_tokens: number;
	/** Sampling seed if sampling was activated */
	seed?: number;
	/** Prompt tokens */
	prefill: TextGenerationStreamPrefillToken[];
	/** Generated tokens */
	tokens: TextGenerationStreamToken[];
	/** Additional sequences when using the `best_of` parameter */
	best_of_sequences?: TextGenerationStreamBestOfSequence[];
}
|
|
294
|
+
|
|
295
|
+
/** One item yielded by `textGenerationStream`. */
export interface TextGenerationStreamReturn {
	/** The generated token; items arrive one token at a time */
	token: TextGenerationStreamToken;
	/** Complete generated text; only available when the generation is finished */
	generated_text?: string;
	/** Generation details; only available when the generation is finished */
	details?: TextGenerationStreamDetails;
}
|
|
309
|
+
|
|
226
310
|
export type TokenClassificationArgs = Args & {
|
|
227
311
|
/**
|
|
228
312
|
* A string to be classified
|
|
@@ -519,21 +603,52 @@ export class HfInference {
|
|
|
519
603
|
* Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
|
|
520
604
|
*/
|
|
521
605
|
public async fillMask(args: FillMaskArgs, options?: Options): Promise<FillMaskReturn> {
	const output = await this.request<FillMaskReturn>(args, options);
	// Reject anything that is not an array of { score, sequence, token, token_str } entries.
	if (
		!Array.isArray(output) ||
		!output.every(
			(candidate) =>
				typeof candidate.score === "number" &&
				typeof candidate.sequence === "string" &&
				typeof candidate.token === "number" &&
				typeof candidate.token_str === "string"
		)
	) {
		throw new TypeError(
			"Invalid inference output: output must be of type Array<score: number, sequence:string, token:number, token_str:string>"
		);
	}
	return output;
}
|
|
524
623
|
|
|
525
624
|
/**
|
|
526
625
|
* This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.
|
|
527
626
|
*/
|
|
528
627
|
public async summarization(args: SummarizationArgs, options?: Options): Promise<SummarizationReturn> {
	const summaries = await this.request<SummarizationReturn[]>(args, options);
	// The API answers with an array; every entry must expose a string `summary_text`.
	if (!Array.isArray(summaries) || !summaries.every((entry) => typeof entry.summary_text === "string")) {
		throw new TypeError("Invalid inference output: output must be of type Array<summary_text: string>");
	}
	// Only the first entry is handed back to the caller.
	return summaries[0];
}
|
|
531
635
|
|
|
532
636
|
/**
|
|
533
637
|
* Want to have a nice know-it-all bot that can answer any question? Recommended model: deepset/roberta-base-squad2
|
|
534
638
|
*/
|
|
535
639
|
public async questionAnswer(args: QuestionAnswerArgs, options?: Options): Promise<QuestionAnswerReturn> {
	const result = await this.request<QuestionAnswerReturn>(args, options);
	// Any field with the wrong primitive type invalidates the whole answer.
	if (
		typeof result.answer !== "string" ||
		typeof result.end !== "number" ||
		typeof result.score !== "number" ||
		typeof result.start !== "number"
	) {
		throw new TypeError(
			"Invalid inference output: output must be of type <answer: string, end: number, score: number, start: number>"
		);
	}
	return result;
}
|
|
538
653
|
|
|
539
654
|
/**
|
|
@@ -543,21 +658,55 @@ export class HfInference {
|
|
|
543
658
|
args: TableQuestionAnswerArgs,
|
|
544
659
|
options?: Options
|
|
545
660
|
): Promise<TableQuestionAnswerReturn> {
|
|
546
|
-
|
|
661
|
+
const res = await this.request<TableQuestionAnswerReturn>(args, options);
|
|
662
|
+
const isValidOutput =
|
|
663
|
+
typeof res.aggregator === "string" &&
|
|
664
|
+
typeof res.answer === "string" &&
|
|
665
|
+
Array.isArray(res.cells) &&
|
|
666
|
+
res.cells.every((x) => typeof x === "string") &&
|
|
667
|
+
Array.isArray(res.coordinates) &&
|
|
668
|
+
res.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number"));
|
|
669
|
+
if (!isValidOutput) {
|
|
670
|
+
throw new TypeError(
|
|
671
|
+
"Invalid inference output: output must be of type <aggregator: string, answer: string, cells: string[], coordinates: number[][]>"
|
|
672
|
+
);
|
|
673
|
+
}
|
|
674
|
+
return res;
|
|
547
675
|
}
|
|
548
676
|
|
|
549
677
|
/**
|
|
550
678
|
* Usually used for sentiment analysis, this will output the likelihood of classes of an input. Recommended model: distilbert-base-uncased-finetuned-sst-2-english
|
|
551
679
|
*/
|
|
552
680
|
public async textClassification(args: TextClassificationArgs, options?: Options): Promise<TextClassificationReturn> {
	const batches = await this.request<TextClassificationReturn[]>(args, options);
	// The response nests the label list one level deep; only the first list is used.
	const labels = batches?.[0];
	if (
		!Array.isArray(labels) ||
		!labels.every((entry) => typeof entry.label === "string" && typeof entry.score === "number")
	) {
		throw new TypeError("Invalid inference output: output must be of type Array<label: string, score: number>");
	}
	return labels;
}
|
|
555
689
|
|
|
556
690
|
/**
|
|
557
691
|
* Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
|
|
558
692
|
*/
|
|
559
693
|
public async textGeneration(args: TextGenerationArgs, options?: Options): Promise<TextGenerationReturn> {
	const generations = await this.request<TextGenerationReturn[]>(args, options);
	// The API answers with an array; every entry must expose a string `generated_text`.
	if (!Array.isArray(generations) || !generations.every((entry) => typeof entry.generated_text === "string")) {
		throw new TypeError("Invalid inference output: output must be of type Array<generated_text: string>");
	}
	// Only the first entry is handed back to the caller.
	return generations[0];
}
|
|
701
|
+
|
|
702
|
+
/**
|
|
703
|
+
* Use to continue text from a prompt. Same as `textGeneration` but returns a generator that can be read one token at a time
|
|
704
|
+
*/
|
|
705
|
+
public async *textGenerationStream(
	args: TextGenerationArgs,
	options?: Options
): AsyncGenerator<TextGenerationStreamReturn> {
	// Delegate to the generic streaming request and forward every item unchanged.
	const tokenStream = this.streamingRequest<TextGenerationStreamReturn>(args, options);
	yield* tokenStream;
}
|
|
562
711
|
|
|
563
712
|
/**
|
|
@@ -567,14 +716,35 @@ export class HfInference {
|
|
|
567
716
|
args: TokenClassificationArgs,
|
|
568
717
|
options?: Options
|
|
569
718
|
): Promise<TokenClassificationReturn> {
|
|
570
|
-
|
|
719
|
+
const res = toArray(await this.request<TokenClassificationReturnValue | TokenClassificationReturn>(args, options));
|
|
720
|
+
const isValidOutput =
|
|
721
|
+
Array.isArray(res) &&
|
|
722
|
+
res.every(
|
|
723
|
+
(x) =>
|
|
724
|
+
typeof x.end === "number" &&
|
|
725
|
+
typeof x.entity_group === "string" &&
|
|
726
|
+
typeof x.score === "number" &&
|
|
727
|
+
typeof x.start === "number" &&
|
|
728
|
+
typeof x.word === "string"
|
|
729
|
+
);
|
|
730
|
+
if (!isValidOutput) {
|
|
731
|
+
throw new TypeError(
|
|
732
|
+
"Invalid inference output: output must be of type Array<end: number, entity_group: string, score: number, start: number, word: string>"
|
|
733
|
+
);
|
|
734
|
+
}
|
|
735
|
+
return res;
|
|
571
736
|
}
|
|
572
737
|
|
|
573
738
|
/**
|
|
574
739
|
* This task is well known to translate text from one language to another. Recommended model: Helsinki-NLP/opus-mt-ru-en.
|
|
575
740
|
*/
|
|
576
741
|
public async translation(args: TranslationArgs, options?: Options): Promise<TranslationReturn> {
	const translations = await this.request<TranslationReturn[]>(args, options);
	// The API answers with an array; every entry must expose a string `translation_text`.
	if (!Array.isArray(translations) || !translations.every((entry) => typeof entry.translation_text === "string")) {
		throw new TypeError("Invalid inference output: output must be of type Array<translation_text: string>");
	}
	// Only the first entry is handed back to the caller.
	return translations[0];
}
|
|
579
749
|
|
|
580
750
|
/**
|
|
@@ -584,9 +754,25 @@ export class HfInference {
|
|
|
584
754
|
args: ZeroShotClassificationArgs,
|
|
585
755
|
options?: Options
|
|
586
756
|
): Promise<ZeroShotClassificationReturn> {
|
|
587
|
-
|
|
588
|
-
await this.request<ZeroShotClassificationReturnValue |
|
|
757
|
+
const res = toArray(
|
|
758
|
+
await this.request<ZeroShotClassificationReturnValue | ZeroShotClassificationReturn>(args, options)
|
|
589
759
|
);
|
|
760
|
+
const isValidOutput =
|
|
761
|
+
Array.isArray(res) &&
|
|
762
|
+
res.every(
|
|
763
|
+
(x) =>
|
|
764
|
+
Array.isArray(x.labels) &&
|
|
765
|
+
x.labels.every((_label) => typeof _label === "string") &&
|
|
766
|
+
Array.isArray(x.scores) &&
|
|
767
|
+
x.scores.every((_score) => typeof _score === "number") &&
|
|
768
|
+
typeof x.sequence === "string"
|
|
769
|
+
);
|
|
770
|
+
if (!isValidOutput) {
|
|
771
|
+
throw new TypeError(
|
|
772
|
+
"Invalid inference output: output must be of type Array<labels: string[], scores: number[], sequence: string>"
|
|
773
|
+
);
|
|
774
|
+
}
|
|
775
|
+
return res;
|
|
590
776
|
}
|
|
591
777
|
|
|
592
778
|
/**
|
|
@@ -594,14 +780,29 @@ export class HfInference {
|
|
|
594
780
|
*
|
|
595
781
|
*/
|
|
596
782
|
public async conversational(args: ConversationalArgs, options?: Options): Promise<ConversationalReturn> {
	const res = await this.request<ConversationalReturn>(args, options);
	// True only for arrays whose elements are all strings.
	const isStringArray = (value: unknown): boolean =>
		Array.isArray(value) && value.every((x) => typeof x === "string");
	// Optional chaining keeps a response without `conversation` (or a nullish response)
	// on the validation path, so the descriptive TypeError below is thrown instead of
	// a native "cannot read properties of undefined" error.
	const isValidOutput =
		isStringArray(res?.conversation?.generated_responses) &&
		isStringArray(res?.conversation?.past_user_inputs) &&
		typeof res.generated_text === "string" &&
		isStringArray(res.warnings);
	if (!isValidOutput) {
		throw new TypeError(
			"Invalid inference output: output must be of type <conversation: {generated_responses: string[], past_user_inputs: string[]}, generated_text: string, warnings: string[]>"
		);
	}
	return res;
}
|
|
599
799
|
|
|
600
800
|
/**
|
|
601
801
|
* This task reads some text and outputs raw float values, that are usually consumed as part of a semantic database/semantic search.
|
|
602
802
|
*/
|
|
603
803
|
public async featureExtraction(args: FeatureExtractionArgs, options?: Options): Promise<FeatureExtractionReturn> {
	// The response is passed through as-is; no shape validation is applied here.
	return await this.request<FeatureExtractionReturn>(args, options);
}
|
|
606
807
|
|
|
607
808
|
/**
|
|
@@ -612,10 +813,15 @@ export class HfInference {
|
|
|
612
813
|
args: AutomaticSpeechRecognitionArgs,
|
|
613
814
|
options?: Options
|
|
614
815
|
): Promise<AutomaticSpeechRecognitionReturn> {
|
|
615
|
-
|
|
816
|
+
const res = await this.request<AutomaticSpeechRecognitionReturn>(args, {
|
|
616
817
|
...options,
|
|
617
818
|
binary: true,
|
|
618
819
|
});
|
|
820
|
+
const isValidOutput = typeof res.text === "string";
|
|
821
|
+
if (!isValidOutput) {
|
|
822
|
+
throw new TypeError("Invalid inference output: output must be of type <text: string>");
|
|
823
|
+
}
|
|
824
|
+
return res;
|
|
619
825
|
}
|
|
620
826
|
|
|
621
827
|
/**
|
|
@@ -626,10 +832,16 @@ export class HfInference {
|
|
|
626
832
|
args: AudioClassificationArgs,
|
|
627
833
|
options?: Options
|
|
628
834
|
): Promise<AudioClassificationReturn> {
|
|
629
|
-
|
|
835
|
+
const res = await this.request<AudioClassificationReturn>(args, {
|
|
630
836
|
...options,
|
|
631
837
|
binary: true,
|
|
632
838
|
});
|
|
839
|
+
const isValidOutput =
|
|
840
|
+
Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
|
|
841
|
+
if (!isValidOutput) {
|
|
842
|
+
throw new TypeError("Invalid inference output: output must be of type Array<label: string, score: number>");
|
|
843
|
+
}
|
|
844
|
+
return res;
|
|
633
845
|
}
|
|
634
846
|
|
|
635
847
|
/**
|
|
@@ -640,10 +852,16 @@ export class HfInference {
|
|
|
640
852
|
args: ImageClassificationArgs,
|
|
641
853
|
options?: Options
|
|
642
854
|
): Promise<ImageClassificationReturn> {
|
|
643
|
-
|
|
855
|
+
const res = await this.request<ImageClassificationReturn>(args, {
|
|
644
856
|
...options,
|
|
645
857
|
binary: true,
|
|
646
858
|
});
|
|
859
|
+
const isValidOutput =
|
|
860
|
+
Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
|
|
861
|
+
if (!isValidOutput) {
|
|
862
|
+
throw new TypeError("Invalid inference output: output must be of type Array<label: string, score: number>");
|
|
863
|
+
}
|
|
864
|
+
return res;
|
|
647
865
|
}
|
|
648
866
|
|
|
649
867
|
/**
|
|
@@ -651,10 +869,27 @@ export class HfInference {
|
|
|
651
869
|
* Recommended model: facebook/detr-resnet-50
|
|
652
870
|
*/
|
|
653
871
|
public async objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionReturn> {
	// The input is sent as a binary body.
	const res = await this.request<ObjectDetectionReturn>(args, {
		...options,
		binary: true,
	});
	// `x.box?.` keeps an entry without a `box` on the validation path, so the
	// descriptive TypeError below is thrown instead of a native property-access error.
	const isValidOutput =
		Array.isArray(res) &&
		res.every(
			(x) =>
				typeof x.label === "string" &&
				typeof x.score === "number" &&
				typeof x.box?.xmin === "number" &&
				typeof x.box?.ymin === "number" &&
				typeof x.box?.xmax === "number" &&
				typeof x.box?.ymax === "number"
		);
	if (!isValidOutput) {
		throw new TypeError(
			"Invalid inference output: output must be of type Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>"
		);
	}
	return res;
}
|
|
659
894
|
|
|
660
895
|
/**
|
|
@@ -662,10 +897,19 @@ export class HfInference {
|
|
|
662
897
|
* Recommended model: facebook/detr-resnet-50-panoptic
|
|
663
898
|
*/
|
|
664
899
|
public async imageSegmentation(args: ImageSegmentationArgs, options?: Options): Promise<ImageSegmentationReturn> {
	// The input is sent as a binary body.
	const binaryOptions = { ...options, binary: true };
	const segments = await this.request<ImageSegmentationReturn>(args, binaryOptions);
	if (
		!Array.isArray(segments) ||
		!segments.every(
			(segment) =>
				typeof segment.label === "string" && typeof segment.mask === "string" && typeof segment.score === "number"
		)
	) {
		throw new TypeError(
			"Invalid inference output: output must be of type Array<label: string, mask: string, score: number>"
		);
	}
	return segments;
}
|
|
670
914
|
|
|
671
915
|
/**
|
|
@@ -673,21 +917,32 @@ export class HfInference {
|
|
|
673
917
|
* Recommended model: stabilityai/stable-diffusion-2
|
|
674
918
|
*/
|
|
675
919
|
public async textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageReturn> {
	// Ask the request helper for a Blob response.
	const blobOptions = { ...options, blob: true };
	const image = await this.request<TextToImageReturn>(args, blobOptions);
	if (!image || !(image instanceof Blob)) {
		throw new TypeError("Invalid inference output: output must be of type object & of instance Blob");
	}
	return image;
}
|
|
681
930
|
|
|
682
|
-
|
|
683
|
-
|
|
931
|
+
/**
|
|
932
|
+
* Helper that prepares request arguments
|
|
933
|
+
*/
|
|
934
|
+
private makeRequestOptions(
|
|
935
|
+
args: Args & {
|
|
936
|
+
data?: Blob | ArrayBuffer;
|
|
937
|
+
stream?: boolean;
|
|
938
|
+
},
|
|
684
939
|
options?: Options & {
|
|
685
940
|
binary?: boolean;
|
|
686
941
|
blob?: boolean;
|
|
687
942
|
/** For internal HF use, which is why it's not exposed in {@link Options} */
|
|
688
943
|
includeCredentials?: boolean;
|
|
689
944
|
}
|
|
690
|
-
)
|
|
945
|
+
) {
|
|
691
946
|
const mergedOptions = { ...this.defaultOptions, ...options };
|
|
692
947
|
const { model, ...otherArgs } = args;
|
|
693
948
|
|
|
@@ -712,7 +967,8 @@ export class HfInference {
|
|
|
712
967
|
}
|
|
713
968
|
}
|
|
714
969
|
|
|
715
|
-
const
|
|
970
|
+
const url = `${HF_INFERENCE_API_BASE_URL}${model}`;
|
|
971
|
+
const info: RequestInit = {
|
|
716
972
|
headers,
|
|
717
973
|
method: "POST",
|
|
718
974
|
body: options?.binary
|
|
@@ -722,7 +978,22 @@ export class HfInference {
|
|
|
722
978
|
options: mergedOptions,
|
|
723
979
|
}),
|
|
724
980
|
credentials: options?.includeCredentials ? "include" : "same-origin",
|
|
725
|
-
}
|
|
981
|
+
};
|
|
982
|
+
|
|
983
|
+
return { url, info, mergedOptions };
|
|
984
|
+
}
|
|
985
|
+
|
|
986
|
+
public async request<T>(
|
|
987
|
+
args: Args & { data?: Blob | ArrayBuffer },
|
|
988
|
+
options?: Options & {
|
|
989
|
+
binary?: boolean;
|
|
990
|
+
blob?: boolean;
|
|
991
|
+
/** For internal HF use, which is why it's not exposed in {@link Options} */
|
|
992
|
+
includeCredentials?: boolean;
|
|
993
|
+
}
|
|
994
|
+
): Promise<T> {
|
|
995
|
+
const { url, info, mergedOptions } = this.makeRequestOptions(args, options);
|
|
996
|
+
const response = await fetch(url, info);
|
|
726
997
|
|
|
727
998
|
if (mergedOptions.retry_on_error !== false && response.status === 503 && !mergedOptions.wait_for_model) {
|
|
728
999
|
return this.request(args, {
|
|
@@ -744,4 +1015,65 @@ export class HfInference {
|
|
|
744
1015
|
}
|
|
745
1016
|
return output;
|
|
746
1017
|
}
|
|
1018
|
+
|
|
1019
|
+
/**
 * Make request that uses server-sent events and returns response as a generator
 *
 * The request is built by `makeRequestOptions` with `stream: true` added to the arguments.
 * If the server answers 503 while `retry_on_error` is not `false` and `wait_for_model`
 * is not set, the request is re-issued once with `wait_for_model: true` and the events
 * of that second request are yielded.
 *
 * @param args - request arguments, forwarded to `makeRequestOptions`
 * @param options - request options, merged with the instance defaults by `makeRequestOptions`
 * @returns an async generator yielding the JSON-parsed `data` of every event whose data is non-empty
 * @throws Error when the response status is not OK, when the response is not an event stream,
 *   or when the response has no body
 */
public async *streamingRequest<T>(
	args: Args & { data?: Blob | ArrayBuffer },
	options?: Options & {
		binary?: boolean;
		blob?: boolean;
		/** For internal HF use, which is why it's not exposed in {@link Options} */
		includeCredentials?: boolean;
	}
): AsyncGenerator<T> {
	const { url, info, mergedOptions } = this.makeRequestOptions({ ...args, stream: true }, options);
	const response = await fetch(url, info);

	if (mergedOptions.retry_on_error !== false && response.status === 503 && !mergedOptions.wait_for_model) {
		// Delegate with `yield*`: returning the generator object from an async generator
		// would finish this generator without emitting any of the retried request's events.
		yield* this.streamingRequest<T>(args, {
			...mergedOptions,
			wait_for_model: true,
		});
		return;
	}
	if (!response.ok) {
		throw new Error(`Server response contains error: ${response.status}`);
	}
	// Match on the media type prefix so that parameters such as "; charset=utf-8" are accepted.
	if (!response.headers.get("content-type")?.startsWith("text/event-stream")) {
		throw new Error(`Server does not support event stream content type`);
	}
	// `Response.body` is nullable; fail explicitly rather than on a property access.
	if (!response.body) {
		throw new Error(`Server response does not contain a body`);
	}

	const reader = response.body.getReader();
	const events: EventSourceMessage[] = [];

	const onEvent = (event: EventSourceMessage) => {
		// accumulate events in array
		events.push(event);
	};

	// Parser pipeline: raw chunks -> lines -> messages; ids and retry hints are ignored.
	const onChunk = getLines(
		getMessages(
			() => {},
			() => {},
			onEvent
		)
	);

	try {
		while (true) {
			const { done, value } = await reader.read();
			if (done) return;
			onChunk(value);
			// Drain everything parsed from this chunk before reading the next one.
			for (const event of events.splice(0)) {
				if (event.data.length > 0) {
					yield JSON.parse(event.data) as T;
				}
			}
		}
	} finally {
		// Runs on normal completion, on error, and when the consumer stops iterating early.
		reader.releaseLock();
	}
}
|
|
747
1079
|
}
|