@huggingface/inference 3.0.1 → 3.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.cjs +195 -69
- package/dist/index.js +194 -69
- package/dist/src/providers/fal-ai.d.ts.map +1 -1
- package/dist/src/providers/replicate.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts +4 -18
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioToAudio.d.ts +10 -9
- package/dist/src/tasks/audio/audioToAudio.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts +3 -12
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts +4 -8
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/audio/utils.d.ts +11 -0
- package/dist/src/tasks/audio/utils.d.ts.map +1 -0
- package/dist/src/tasks/cv/imageClassification.d.ts +3 -17
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts +3 -21
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts +3 -49
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts +3 -12
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts +3 -26
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts +3 -38
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts +6 -0
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -0
- package/dist/src/tasks/cv/utils.d.ts +11 -0
- package/dist/src/tasks/cv/utils.d.ts.map +1 -0
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts +7 -15
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/index.d.ts +1 -0
- package/dist/src/tasks/index.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts +5 -28
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts +5 -20
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts +2 -21
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts +3 -25
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts +2 -13
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts +2 -42
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts +3 -31
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts +2 -16
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts +2 -45
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts +2 -13
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts +2 -22
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/types.d.ts +4 -0
- package/dist/src/types.d.ts.map +1 -1
- package/package.json +2 -2
- package/src/providers/fal-ai.ts +4 -0
- package/src/providers/replicate.ts +3 -0
- package/src/tasks/audio/audioClassification.ts +7 -22
- package/src/tasks/audio/audioToAudio.ts +43 -23
- package/src/tasks/audio/automaticSpeechRecognition.ts +35 -23
- package/src/tasks/audio/textToSpeech.ts +8 -14
- package/src/tasks/audio/utils.ts +18 -0
- package/src/tasks/cv/imageClassification.ts +5 -20
- package/src/tasks/cv/imageSegmentation.ts +5 -24
- package/src/tasks/cv/imageToImage.ts +4 -52
- package/src/tasks/cv/imageToText.ts +6 -15
- package/src/tasks/cv/objectDetection.ts +5 -30
- package/src/tasks/cv/textToImage.ts +14 -50
- package/src/tasks/cv/textToVideo.ts +67 -0
- package/src/tasks/cv/utils.ts +13 -0
- package/src/tasks/cv/zeroShotImageClassification.ts +32 -31
- package/src/tasks/index.ts +1 -0
- package/src/tasks/multimodal/documentQuestionAnswering.ts +25 -43
- package/src/tasks/multimodal/visualQuestionAnswering.ts +20 -36
- package/src/tasks/nlp/fillMask.ts +2 -22
- package/src/tasks/nlp/questionAnswering.ts +22 -36
- package/src/tasks/nlp/sentenceSimilarity.ts +12 -15
- package/src/tasks/nlp/summarization.ts +2 -43
- package/src/tasks/nlp/tableQuestionAnswering.ts +25 -41
- package/src/tasks/nlp/textClassification.ts +3 -18
- package/src/tasks/nlp/tokenClassification.ts +2 -47
- package/src/tasks/nlp/translation.ts +3 -17
- package/src/tasks/nlp/zeroShotClassification.ts +2 -24
- package/src/types.ts +7 -1
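
In short, 3.1.1 adds a textToVideo task (served through the fal-ai and replicate providers), makes automatic-speech-recognition upload Blobs to fal-ai as base64 data: URLs, factors payload preparation into shared pick/omit/preparePayload helpers used across the task modules, and relaxes several output validators to accept array-wrapped responses.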
package/dist/index.cjs
CHANGED
@@ -54,6 +54,7 @@ __export(src_exports, {
   textGenerationStream: () => textGenerationStream,
   textToImage: () => textToImage,
   textToSpeech: () => textToSpeech,
+  textToVideo: () => textToVideo,
   tokenClassification: () => tokenClassification,
   translation: () => translation,
   visualQuestionAnswering: () => visualQuestionAnswering,
@@ -91,6 +92,7 @@ __export(tasks_exports, {
   textGenerationStream: () => textGenerationStream,
   textToImage: () => textToImage,
   textToSpeech: () => textToSpeech,
+  textToVideo: () => textToVideo,
   tokenClassification: () => tokenClassification,
   translation: () => translation,
   visualQuestionAnswering: () => visualQuestionAnswering,
@@ -119,6 +121,10 @@ var FAL_AI_SUPPORTED_MODEL_IDS = {
   },
   "automatic-speech-recognition": {
     "openai/whisper-large-v3": "fal-ai/whisper"
+  },
+  "text-to-video": {
+    "genmo/mochi-1-preview": "fal-ai/mochi-v1",
+    "tencent/HunyuanVideo": "fal-ai/hunyuan-video"
   }
 };
 
@@ -131,6 +137,9 @@ var REPLICATE_SUPPORTED_MODEL_IDS = {
   },
   "text-to-speech": {
     "OuteAI/OuteTTS-0.3-500M": "jbilcke/oute-tts:39a59319327b27327fa3095149c5a746e7f2aee18c75055c3368237a6503cd26"
+  },
+  "text-to-video": {
+    "genmo/mochi-1-preview": "genmoai/mochi-1:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460"
   }
 };
 
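
Both tables map canonical Hub model IDs to the provider's own identifiers, so callers keep passing the Hub ID and the library resolves it per task. A minimal sketch of the lookup these tables enable (the helper name and types are illustrative, not part of the published API):

    type ModelMap = Record<string, Record<string, string>>;

    const falAiModels: ModelMap = {
      "text-to-video": {
        "genmo/mochi-1-preview": "fal-ai/mochi-v1",
        "tencent/HunyuanVideo": "fal-ai/hunyuan-video"
      }
    };

    // Resolve a Hub model ID to the provider-specific ID for a given task.
    function resolveProviderModel(task: string, hubModelId: string): string | undefined {
      return falAiModels[task]?.[hubModelId];
    }

    // resolveProviderModel("text-to-video", "genmo/mochi-1-preview") === "fal-ai/mochi-v1"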
@@ -596,9 +605,42 @@ var InferenceOutputError = class extends TypeError {
   }
 };
 
+// src/utils/pick.ts
+function pick(o, props) {
+  return Object.assign(
+    {},
+    ...props.map((prop) => {
+      if (o[prop] !== void 0) {
+        return { [prop]: o[prop] };
+      }
+    })
+  );
+}
+
+// src/utils/typedInclude.ts
+function typedInclude(arr, v) {
+  return arr.includes(v);
+}
+
+// src/utils/omit.ts
+function omit(o, props) {
+  const propsArr = Array.isArray(props) ? props : [props];
+  const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
+  return pick(o, letsKeep);
+}
+
+// src/tasks/audio/utils.ts
+function preparePayload(args) {
+  return "data" in args ? args : {
+    ...omit(args, "inputs"),
+    data: args.inputs
+  };
+}
+
 // src/tasks/audio/audioClassification.ts
 async function audioClassification(args, options) {
-  const res = await request(args, {
+  const payload = preparePayload(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "audio-classification"
   });
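
The new pick/omit/typedInclude utilities and the audio preparePayload helper normalize the legacy inputs field into the data field that audio tasks send on the wire. The same logic in typed form, for readability (signatures are illustrative, not the published ones):

    // Copy only the defined properties named in props.
    function pick(o: Record<string, unknown>, props: string[]): Record<string, unknown> {
      return Object.assign(
        {},
        ...props.map((prop) => (o[prop] !== undefined ? { [prop]: o[prop] } : undefined))
      );
    }

    // Keep every key that is not named in props.
    function omit(o: Record<string, unknown>, props: string | string[]): Record<string, unknown> {
      const propsArr = Array.isArray(props) ? props : [props];
      return pick(o, Object.keys(o).filter((prop) => !propsArr.includes(prop)));
    }

    // omit({ inputs: "hi", model: "m" }, "inputs")  =>  { model: "m" }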
@@ -624,15 +666,8 @@ function base64FromBytes(arr) {
 
 // src/tasks/audio/automaticSpeechRecognition.ts
 async function automaticSpeechRecognition(args, options) {
-  if (args.provider === "fal-ai") {
-    const contentType = args.data instanceof Blob ? args.data.type : "audio/mpeg";
-    const base64audio = base64FromBytes(
-      new Uint8Array(args.data instanceof ArrayBuffer ? args.data : await args.data.arrayBuffer())
-    );
-    args.audio_url = `data:${contentType};base64,${base64audio}`;
-    delete args.data;
-  }
-  const res = await request(args, {
+  const payload = await buildPayload(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "automatic-speech-recognition"
   });
@@ -642,6 +677,32 @@ async function automaticSpeechRecognition(args, options) {
   }
   return res;
 }
+var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
+async function buildPayload(args) {
+  if (args.provider === "fal-ai") {
+    const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : void 0;
+    const contentType = blob?.type;
+    if (!contentType) {
+      throw new Error(
+        `Unable to determine the input's content-type. Make sure your are passing a Blob when using provider fal-ai.`
+      );
+    }
+    if (!FAL_AI_SUPPORTED_BLOB_TYPES.includes(contentType)) {
+      throw new Error(
+        `Provider fal-ai does not support blob type ${contentType} - supported content types are: ${FAL_AI_SUPPORTED_BLOB_TYPES.join(
+          ", "
+        )}`
+      );
+    }
+    const base64audio = base64FromBytes(new Uint8Array(await blob.arrayBuffer()));
+    return {
+      ..."data" in args ? omit(args, "data") : omit(args, "inputs"),
+      audio_url: `data:${contentType};base64,${base64audio}`
+    };
+  } else {
+    return preparePayload(args);
+  }
+}
 
 // src/tasks/audio/textToSpeech.ts
 async function textToSpeech(args, options) {
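
With buildPayload, the fal-ai route now requires a Blob whose MIME type is one of the four listed audio types and inlines it as a data: URL; other providers fall back to preparePayload. A hedged usage sketch (URL, token, and model are placeholders):

    import { automaticSpeechRecognition } from "@huggingface/inference";

    // fal-ai needs a Blob with a recognized audio MIME type; a bare
    // ArrayBuffer would now fail the content-type check above.
    const bytes = await (await fetch("https://example.com/sample.wav")).arrayBuffer();
    const audio = new Blob([bytes], { type: "audio/wav" });

    const { text } = await automaticSpeechRecognition({
      accessToken: "hf_...", // placeholder
      provider: "fal-ai",
      model: "openai/whisper-large-v3", // mapped to fal-ai/whisper above
      data: audio,
    });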
@@ -649,6 +710,9 @@ async function textToSpeech(args, options) {
     ...options,
     taskHint: "text-to-speech"
   });
+  if (res instanceof Blob) {
+    return res;
+  }
   if (res && typeof res === "object") {
     if ("output" in res) {
       if (typeof res.output === "string") {
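
textToSpeech now returns immediately when the response is already a Blob and otherwise falls through to the output-URL handling, so callers get a Blob either way. Sketch (token and model ID are just examples):

    import { textToSpeech } from "@huggingface/inference";

    const speech: Blob = await textToSpeech({
      accessToken: "hf_...", // placeholder
      model: "espnet/kan-bayashi_ljspeech_vits",
      inputs: "Hello world",
    });
    // Raw audio and URL-shaped provider responses both resolve to a Blob.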
@@ -662,31 +726,39 @@ async function textToSpeech(args, options) {
       }
     }
   }
-  const isValidOutput = res && res instanceof Blob;
-  if (!isValidOutput) {
-    throw new InferenceOutputError("Expected Blob");
-  }
-  return res;
+  throw new InferenceOutputError("Expected Blob or object with output");
 }
 
 // src/tasks/audio/audioToAudio.ts
 async function audioToAudio(args, options) {
-  const res = await request(args, {
+  const payload = preparePayload(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "audio-to-audio"
   });
-  const isValidOutput = Array.isArray(res) && res.every(
-    (x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string"
-  );
-  if (!isValidOutput) {
-    throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
+  return validateOutput(res);
+}
+function validateOutput(output) {
+  if (!Array.isArray(output)) {
+    throw new InferenceOutputError("Expected Array");
   }
-  return res;
+  if (!output.every((elem) => {
+    return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
+  })) {
+    throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
+  }
+  return output;
+}
+
+// src/tasks/cv/utils.ts
+function preparePayload2(args) {
+  return "data" in args ? args : { ...omit(args, "inputs"), data: args.inputs };
 }
 
 // src/tasks/cv/imageClassification.ts
 async function imageClassification(args, options) {
-  const res = await request(args, {
+  const payload = preparePayload2(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "image-classification"
   });
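
The extracted validateOutput enforces that every audioToAudio element carries a label, a content-type, and a base64 blob string. Decoding one element into a playable Blob might look like this (a sketch, assuming the validated shape above):

    interface AudioToAudioElement {
      label: string;
      "content-type": string;
      blob: string; // base64-encoded audio
    }

    function toAudioBlob(elem: AudioToAudioElement): Blob {
      const bytes = Uint8Array.from(atob(elem.blob), (c) => c.charCodeAt(0));
      return new Blob([bytes], { type: elem["content-type"] });
    }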
@@ -699,7 +771,8 @@ async function imageClassification(args, options) {
 
 // src/tasks/cv/imageSegmentation.ts
 async function imageSegmentation(args, options) {
-  const res = await request(args, {
+  const payload = preparePayload2(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "image-segmentation"
   });
@@ -712,7 +785,8 @@ async function imageSegmentation(args, options) {
 
 // src/tasks/cv/imageToText.ts
 async function imageToText(args, options) {
-  const res = (await request(args, {
+  const payload = preparePayload2(args);
+  const res = (await request(payload, {
     ...options,
     taskHint: "image-to-text"
   }))?.[0];
@@ -724,7 +798,8 @@ async function imageToText(args, options) {
 
 // src/tasks/cv/objectDetection.ts
 async function objectDetection(args, options) {
-  const res = await request(args, {
+  const payload = preparePayload2(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "object-detection"
   });
@@ -741,15 +816,13 @@ async function objectDetection(args, options) {
 
 // src/tasks/cv/textToImage.ts
 async function textToImage(args, options) {
-  if (args.provider === "together" || args.provider === "fal-ai") {
-    args.prompt = args.inputs;
-    delete args.inputs;
-    args.response_format = "base64";
-  } else if (args.provider === "replicate") {
-    args.prompt = args.inputs;
-    delete args.inputs;
-  }
-  const res = await request(args, {
+  const payload = args.provider === "together" || args.provider === "fal-ai" || args.provider === "replicate" ? {
+    ...omit(args, ["inputs", "parameters"]),
+    ...args.parameters,
+    ...args.provider !== "replicate" ? { response_format: "base64" } : void 0,
+    prompt: args.inputs
+  } : args;
+  const res = await request(payload, {
     ...options,
     taskHint: "text-to-image"
   });
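
For together, fal-ai, and replicate, textToImage now builds a flat payload: parameters are spread inline, inputs becomes prompt, and non-replicate providers also get response_format: "base64". Usage stays the same on the caller's side (token and model are placeholders):

    import { textToImage } from "@huggingface/inference";

    const image: Blob = await textToImage({
      accessToken: "hf_...", // placeholder
      provider: "fal-ai",
      model: "black-forest-labs/FLUX.1-dev", // example model
      inputs: "An astronaut riding a horse",
      parameters: { num_inference_steps: 28 }, // spread into the provider payload
    });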
@@ -806,18 +879,30 @@ async function imageToImage(args, options) {
 }
 
 // src/tasks/cv/zeroShotImageClassification.ts
-async function zeroShotImageClassification(args, options) {
-  const reqArgs = {
-    ...args,
-    inputs: {
-      image: base64FromBytes(
-        new Uint8Array(
-          args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
+async function preparePayload3(args) {
+  if (args.inputs instanceof Blob) {
+    return {
+      ...args,
+      inputs: {
+        image: base64FromBytes(new Uint8Array(await args.inputs.arrayBuffer()))
+      }
+    };
+  } else {
+    return {
+      ...args,
+      inputs: {
+        image: base64FromBytes(
+          new Uint8Array(
+            args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
+          )
         )
-      )
-    }
-  };
-  const res = await request(reqArgs, {
+      }
+    };
+  }
+}
+async function zeroShotImageClassification(args, options) {
+  const payload = await preparePayload3(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "zero-shot-image-classification"
   });
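
Thanks to preparePayload3, zeroShotImageClassification now accepts either a bare Blob as inputs or the legacy { image: Blob | ArrayBuffer } object, serializing both to a base64 image field. A sketch of the new calling style (token, model, and labels are examples):

    import { zeroShotImageClassification } from "@huggingface/inference";

    const image: Blob = await (await fetch("https://example.com/cat.png")).blob();

    const labels = await zeroShotImageClassification({
      accessToken: "hf_...", // placeholder
      model: "openai/clip-vit-large-patch14",
      inputs: image, // a bare Blob now works; { image } still does too
      parameters: { candidate_labels: ["cat", "dog"] },
    });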
@@ -828,6 +913,36 @@ async function zeroShotImageClassification(args, options) {
   return res;
 }
 
+// src/tasks/cv/textToVideo.ts
+var SUPPORTED_PROVIDERS = ["fal-ai", "replicate"];
+async function textToVideo(args, options) {
+  if (!args.provider || !typedInclude(SUPPORTED_PROVIDERS, args.provider)) {
+    throw new Error(
+      `textToVideo inference is only supported for the following providers: ${SUPPORTED_PROVIDERS.join(", ")}`
+    );
+  }
+  const payload = args.provider === "fal-ai" || args.provider === "replicate" ? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs } : args;
+  const res = await request(payload, {
+    ...options,
+    taskHint: "text-to-video"
+  });
+  if (args.provider === "fal-ai") {
+    const isValidOutput = typeof res === "object" && !!res && "video" in res && typeof res.video === "object" && !!res.video && "url" in res.video && typeof res.video.url === "string" && isUrl(res.video.url);
+    if (!isValidOutput) {
+      throw new InferenceOutputError("Expected { video: { url: string } }");
+    }
+    const urlResponse = await fetch(res.video.url);
+    return await urlResponse.blob();
+  } else {
+    const isValidOutput = typeof res === "object" && !!res && "output" in res && typeof res.output === "string" && isUrl(res.output);
+    if (!isValidOutput) {
+      throw new InferenceOutputError("Expected { output: string }");
+    }
+    const urlResponse = await fetch(res.output);
+    return await urlResponse.blob();
+  }
+}
+
 // src/lib/getDefaultTask.ts
 var taskCache = /* @__PURE__ */ new Map();
 var CACHE_DURATION = 10 * 60 * 1e3;
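
The new textToVideo task validates the provider's response shape, then downloads the generated video and resolves to a Blob. Typical usage, per the model mappings added earlier (token is a placeholder):

    import { textToVideo } from "@huggingface/inference";

    const video: Blob = await textToVideo({
      accessToken: "hf_...", // placeholder
      provider: "fal-ai", // or "replicate"
      model: "genmo/mochi-1-preview", // resolved to fal-ai/mochi-v1 via the table above
      inputs: "A cat chasing a laser pointer, cinematic lighting",
    });
    // fal-ai answers { video: { url } }, replicate { output: url };
    // either way the URL is fetched and returned as a Blob.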
@@ -906,17 +1021,19 @@ async function questionAnswering(args, options) {
     ...options,
     taskHint: "question-answering"
   });
-  const isValidOutput = typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";
+  const isValidOutput = Array.isArray(res) ? res.every(
+    (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
+  ) : typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";
   if (!isValidOutput) {
-    throw new InferenceOutputError("Expected {answer: string, end: number, score: number, start: number}");
+    throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
   }
-  return res;
+  return Array.isArray(res) ? res[0] : res;
 }
 
 // src/tasks/nlp/sentenceSimilarity.ts
 async function sentenceSimilarity(args, options) {
   const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : void 0;
-  const res = await request(args, {
+  const res = await request(prepareInput(args), {
     ...options,
     taskHint: "sentence-similarity",
     ...defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }
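
This is the first of several hunks (see also table, document, and visual question answering below) that let a task accept either a bare object or a single-element array from the provider, validate whichever arrived, and return one object. The recurring normalization distills to:

    // Accept T or T[], hand back a single T.
    function firstIfArray<T>(res: T | T[]): T {
      return Array.isArray(res) ? res[0] : res;
    }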
@@ -927,6 +1044,13 @@ async function sentenceSimilarity(args, options) {
   }
   return res;
 }
+function prepareInput(args) {
+  return {
+    ...omit(args, ["inputs", "parameters"]),
+    inputs: { ...omit(args.inputs, "sourceSentence") },
+    parameters: { source_sentence: args.inputs.sourceSentence, ...args.parameters }
+  };
+}
 
 // src/tasks/nlp/summarization.ts
 async function summarization(args, options) {
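
prepareInput moves the camelCase sourceSentence out of inputs and into parameters.source_sentence, matching the serverless API's snake_case schema, while the rest of inputs (e.g. sentences) passes through. So a call such as (sketch; token and model are placeholders):

    import { sentenceSimilarity } from "@huggingface/inference";

    const scores: number[] = await sentenceSimilarity({
      accessToken: "hf_...", // placeholder
      model: "sentence-transformers/all-MiniLM-L6-v2",
      inputs: {
        sourceSentence: "That is a happy person",
        sentences: ["That is a happy dog", "Today is a sunny day"],
      },
    });
    // is serialized as { inputs: { sentences }, parameters: { source_sentence } }.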
@@ -947,13 +1071,18 @@ async function tableQuestionAnswering(args, options) {
     ...options,
     taskHint: "table-question-answering"
   });
-  const isValidOutput = typeof res === "object" && !!res && "aggregator" in res && typeof res.aggregator === "string" && "answer" in res && typeof res.answer === "string" && "cells" in res && Array.isArray(res.cells) && res.cells.every((x) => typeof x === "string") && "coordinates" in res && Array.isArray(res.coordinates) && res.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number"));
+  const isValidOutput = Array.isArray(res) ? res.every((elem) => validate(elem)) : validate(res);
   if (!isValidOutput) {
     throw new InferenceOutputError(
       "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
     );
   }
-  return res;
+  return Array.isArray(res) ? res[0] : res;
+}
+function validate(elem) {
+  return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
+    (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
+  );
 }
 
 // src/tasks/nlp/textClassification.ts
@@ -1096,11 +1225,7 @@ async function documentQuestionAnswering(args, options) {
     inputs: {
       question: args.inputs.question,
       // convert Blob or ArrayBuffer to base64
-      image: base64FromBytes(
-        new Uint8Array(
-          args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
-        )
-      )
+      image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer()))
     }
   };
   const res = toArray(
@@ -1108,12 +1233,14 @@ async function documentQuestionAnswering(args, options) {
     ...options,
     taskHint: "document-question-answering"
   })
-  )?.[0];
-  const isValidOutput = typeof res === "object" && !!res && typeof res?.answer === "string" && (typeof res.end === "number" || typeof res.end === "undefined") && (typeof res.score === "number" || typeof res.score === "undefined") && (typeof res.start === "number" || typeof res.start === "undefined");
+  );
+  const isValidOutput = Array.isArray(res) && res.every(
+    (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
+  );
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>");
   }
-  return res;
+  return res[0];
 }
 
 // src/tasks/multimodal/visualQuestionAnswering.ts
@@ -1123,22 +1250,20 @@ async function visualQuestionAnswering(args, options) {
     inputs: {
       question: args.inputs.question,
       // convert Blob or ArrayBuffer to base64
-      image: base64FromBytes(
-        new Uint8Array(
-          args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
-        )
-      )
+      image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer()))
     }
   };
-  const res = (await request(reqArgs, {
+  const res = await request(reqArgs, {
     ...options,
     taskHint: "visual-question-answering"
-  }))?.[0];
-  const isValidOutput = typeof res === "object" && !!res && typeof res?.answer === "string" && typeof res.score === "number";
+  });
+  const isValidOutput = Array.isArray(res) && res.every(
+    (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
+  );
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
   }
-  return res;
+  return res[0];
 }
 
 // src/tasks/tabular/tabularRegression.ts
@@ -1245,6 +1370,7 @@ var INFERENCE_PROVIDERS = ["fal-ai", "replicate", "sambanova", "together", "hf-inference"];
   textGenerationStream,
   textToImage,
   textToSpeech,
+  textToVideo,
   tokenClassification,
   translation,
   visualQuestionAnswering,