@huggingface/inference 3.0.0 → 3.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +11 -6
- package/dist/index.cjs +193 -76
- package/dist/index.js +193 -76
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/fal-ai.d.ts.map +1 -1
- package/dist/src/providers/replicate.d.ts.map +1 -1
- package/dist/src/providers/together.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts +4 -18
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioToAudio.d.ts +10 -9
- package/dist/src/tasks/audio/audioToAudio.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts +3 -12
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts +4 -8
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/audio/utils.d.ts +11 -0
- package/dist/src/tasks/audio/utils.d.ts.map +1 -0
- package/dist/src/tasks/cv/imageClassification.d.ts +3 -17
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts +3 -21
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts +3 -49
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts +3 -12
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts +3 -26
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts +3 -38
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts +6 -0
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -0
- package/dist/src/tasks/cv/utils.d.ts +11 -0
- package/dist/src/tasks/cv/utils.d.ts.map +1 -0
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts +7 -15
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts +5 -28
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts +5 -20
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts +2 -21
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts +3 -25
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts +2 -13
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts +2 -42
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts +3 -31
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts +2 -16
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts +2 -45
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts +2 -13
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts +2 -22
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/types.d.ts +4 -0
- package/dist/src/types.d.ts.map +1 -1
- package/package.json +2 -2
- package/src/lib/makeRequestOptions.ts +7 -5
- package/src/providers/fal-ai.ts +12 -0
- package/src/providers/replicate.ts +6 -3
- package/src/providers/together.ts +2 -0
- package/src/tasks/audio/audioClassification.ts +7 -22
- package/src/tasks/audio/audioToAudio.ts +43 -23
- package/src/tasks/audio/automaticSpeechRecognition.ts +35 -23
- package/src/tasks/audio/textToSpeech.ts +23 -14
- package/src/tasks/audio/utils.ts +18 -0
- package/src/tasks/cv/imageClassification.ts +5 -20
- package/src/tasks/cv/imageSegmentation.ts +5 -24
- package/src/tasks/cv/imageToImage.ts +4 -52
- package/src/tasks/cv/imageToText.ts +6 -15
- package/src/tasks/cv/objectDetection.ts +5 -30
- package/src/tasks/cv/textToImage.ts +14 -50
- package/src/tasks/cv/textToVideo.ts +67 -0
- package/src/tasks/cv/utils.ts +13 -0
- package/src/tasks/cv/zeroShotImageClassification.ts +32 -31
- package/src/tasks/multimodal/documentQuestionAnswering.ts +25 -43
- package/src/tasks/multimodal/visualQuestionAnswering.ts +20 -36
- package/src/tasks/nlp/fillMask.ts +2 -22
- package/src/tasks/nlp/questionAnswering.ts +22 -36
- package/src/tasks/nlp/sentenceSimilarity.ts +12 -15
- package/src/tasks/nlp/summarization.ts +2 -43
- package/src/tasks/nlp/tableQuestionAnswering.ts +25 -41
- package/src/tasks/nlp/textClassification.ts +3 -18
- package/src/tasks/nlp/tokenClassification.ts +2 -47
- package/src/tasks/nlp/translation.ts +3 -17
- package/src/tasks/nlp/zeroShotClassification.ts +2 -24
- package/src/types.ts +7 -1
package/README.md
CHANGED
@@ -42,15 +42,15 @@ const hf = new HfInference('your access token')
 
 Your access token should be kept private. If you need to protect it in front-end applications, we suggest setting up a proxy server that stores the access token.
 
-### 
+### Third-party inference providers
 
-You can 
+You can send inference requests to third-party providers with the inference client.
 
 Currently, we support the following providers: [Fal.ai](https://fal.ai), [Replicate](https://replicate.com), [Together](https://together.xyz) and [Sambanova](https://sambanova.ai).
 
-To 
+To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
 ```ts
-const accessToken = "hf_..."; // Either a HF access token, or an API key from the 
+const accessToken = "hf_..."; // Either a HF access token, or an API key from the third-party provider (Replicate in this example)
 
 const client = new HfInference(accessToken);
 await client.textToImage({
@@ -63,14 +63,19 @@ await client.textToImage({
 When authenticated with a Hugging Face access token, the request is routed through https://huggingface.co.
 When authenticated with a third-party provider key, the request is made directly against that provider's inference API.
 
-Only a subset of models are supported when requesting 
+Only a subset of models are supported when requesting third-party providers. You can check the list of supported models per pipeline tasks here:
 - [Fal.ai supported models](./src/providers/fal-ai.ts)
 - [Replicate supported models](./src/providers/replicate.ts)
 - [Sambanova supported models](./src/providers/sambanova.ts)
 - [Together supported models](./src/providers/together.ts)
 - [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
 
-
+❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
+This is not an issue for LLMs as everyone converged on the OpenAI API anyways, but can be more tricky for other tasks like "text-to-image" or "automatic-speech-recognition" where there exists no standard API. Let us know if any help is needed or if we can make things easier for you!
+
+👋**Want to add another provider?** Get in touch if you'd like to add support for another Inference provider, and/or request it on https://huggingface.co/spaces/huggingface/HuggingDiscussions/discussions/49
+
+### Tree-shaking
 
 You can import the functions you need directly from the module instead of using the `HfInference` class.
 
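For reference, this is the call pattern the updated README section describes, sketched end to end; the model and prompt values below are illustrative, not taken from the diff:

```ts
import { HfInference } from "@huggingface/inference";

// Either a HF access token or a key from the chosen provider (see the README changes above).
const client = new HfInference("hf_...");

// Passing `provider` routes the request to a third-party inference provider.
const image = await client.textToImage({
  provider: "fal-ai",
  model: "black-forest-labs/FLUX.1-schnell", // must appear in the provider's supported-model mapping
  inputs: "a picture of a green bird",       // illustrative prompt
});
// `image` is a Blob containing the generated picture.
```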
package/dist/index.cjs
CHANGED
@@ -107,10 +107,22 @@ var FAL_AI_API_BASE_URL = "https://fal.run";
 var FAL_AI_SUPPORTED_MODEL_IDS = {
   "text-to-image": {
     "black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
-    "black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev"
+    "black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
+    "playgroundai/playground-v2.5-1024px-aesthetic": "fal-ai/playground-v25",
+    "ByteDance/SDXL-Lightning": "fal-ai/lightning-models",
+    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "fal-ai/pixart-sigma",
+    "stabilityai/stable-diffusion-3-medium": "fal-ai/stable-diffusion-v3-medium",
+    "Warlord-K/Sana-1024": "fal-ai/sana",
+    "fal/AuraFlow-v0.2": "fal-ai/aura-flow",
+    "stabilityai/stable-diffusion-3.5-large": "fal-ai/stable-diffusion-v35-large",
+    "Kwai-Kolors/Kolors": "fal-ai/kolors"
   },
   "automatic-speech-recognition": {
     "openai/whisper-large-v3": "fal-ai/whisper"
+  },
+  "text-to-video": {
+    "genmo/mochi-1-preview": "fal-ai/mochi-v1",
+    "tencent/HunyuanVideo": "fal-ai/hunyuan-video"
   }
 };
 
@@ -120,10 +132,13 @@ var REPLICATE_SUPPORTED_MODEL_IDS = {
   "text-to-image": {
     "black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
     "ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637"
+  },
+  "text-to-speech": {
+    "OuteAI/OuteTTS-0.3-500M": "jbilcke/oute-tts:39a59319327b27327fa3095149c5a746e7f2aee18c75055c3368237a6503cd26"
+  },
+  "text-to-video": {
+    "genmo/mochi-1-preview": "genmoai/mochi-1:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460"
   }
-  // "text-to-speech": {
-  //   "SWivid/F5-TTS": "x-lance/f5-tts:87faf6dd7a692dd82043f662e76369cab126a2cf1937e25a9d41e0b834fd230e"
-  // },
 };
 
 // src/providers/sambanova.ts
@@ -159,6 +174,8 @@ var TOGETHER_SUPPORTED_MODEL_IDS = {
   },
   conversational: {
     "databricks/dbrx-instruct": "databricks/dbrx-instruct",
+    "deepseek-ai/DeepSeek-R1": "deepseek-ai/DeepSeek-R1",
+    "deepseek-ai/DeepSeek-V3": "deepseek-ai/DeepSeek-V3",
     "deepseek-ai/deepseek-llm-67b-chat": "deepseek-ai/deepseek-llm-67b-chat",
     "google/gemma-2-9b-it": "google/gemma-2-9b-it",
     "google/gemma-2b-it": "google/gemma-2-27b-it",
@@ -204,7 +221,8 @@ function isUrl(modelOrUrl) {
 var HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_HUB_URL}/api/inference-proxy/{{PROVIDER}}`;
 var tasks = null;
 async function makeRequestOptions(args, options) {
-  const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...
+  const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
+  let otherArgs = remainingArgs;
   const provider = maybeProvider ?? "hf-inference";
   const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion: chatCompletion2 } = options ?? {};
   if (endpointUrl && provider !== "hf-inference") {
@@ -263,9 +281,9 @@ async function makeRequestOptions(args, options) {
   } else if (includeCredentials === true) {
     credentials = "include";
   }
-  if (provider === "replicate"
-    const version = model.split(":")[1];
-    otherArgs
+  if (provider === "replicate") {
+    const version = model.includes(":") ? model.split(":")[1] : void 0;
+    otherArgs = { input: otherArgs, version };
   }
   const info = {
     headers,
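Worked through on a concrete value, the new Replicate branch above reshapes the request body like this (the model id below is hypothetical; the logic mirrors the diff):

```ts
// Mirrors the new `provider === "replicate"` branch of makeRequestOptions.
const model = "owner/some-model:0123456789abcdef"; // hypothetical "<model>:<version>" id
let otherArgs: Record<string, unknown> = { prompt: "a picture of a green bird" };

const version = model.includes(":") ? model.split(":")[1] : undefined;
otherArgs = { input: otherArgs, version };
// otherArgs === { input: { prompt: "a picture of a green bird" }, version: "0123456789abcdef" }
// i.e. user arguments are nested under `input`, and the version hash (if any) is sent alongside.
```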
@@ -585,9 +603,42 @@ var InferenceOutputError = class extends TypeError {
   }
 };
 
+// src/utils/pick.ts
+function pick(o, props) {
+  return Object.assign(
+    {},
+    ...props.map((prop) => {
+      if (o[prop] !== void 0) {
+        return { [prop]: o[prop] };
+      }
+    })
+  );
+}
+
+// src/utils/typedInclude.ts
+function typedInclude(arr, v) {
+  return arr.includes(v);
+}
+
+// src/utils/omit.ts
+function omit(o, props) {
+  const propsArr = Array.isArray(props) ? props : [props];
+  const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
+  return pick(o, letsKeep);
+}
+
+// src/tasks/audio/utils.ts
+function preparePayload(args) {
+  return "data" in args ? args : {
+    ...omit(args, "inputs"),
+    data: args.inputs
+  };
+}
+
 // src/tasks/audio/audioClassification.ts
 async function audioClassification(args, options) {
-  const 
+  const payload = preparePayload(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "audio-classification"
   });
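A quick sketch of what the new helpers do to audio task arguments (the model id and blob are illustrative):

```ts
// pick/omit/preparePayload as introduced above, applied to an audio-classification call.
const args = { model: "some-org/audio-model", inputs: new Blob(["…"], { type: "audio/wav" }) };

// omit(args, "inputs")  -> { model: "some-org/audio-model" }
// preparePayload(args)  -> { model: "some-org/audio-model", data: args.inputs }
// Calls that already pass the legacy `data` field are forwarded unchanged.
```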
@@ -613,15 +664,8 @@ function base64FromBytes(arr) {
 
 // src/tasks/audio/automaticSpeechRecognition.ts
 async function automaticSpeechRecognition(args, options) {
-
-
-    const base64audio = base64FromBytes(
-      new Uint8Array(args.data instanceof ArrayBuffer ? args.data : await args.data.arrayBuffer())
-    );
-    args.audio_url = `data:${contentType};base64,${base64audio}`;
-    delete args.data;
-  }
-  const res = await request(args, {
+  const payload = await buildPayload(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "automatic-speech-recognition"
   });
@@ -631,6 +675,32 @@ async function automaticSpeechRecognition(args, options) {
   }
   return res;
 }
+var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
+async function buildPayload(args) {
+  if (args.provider === "fal-ai") {
+    const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : void 0;
+    const contentType = blob?.type;
+    if (!contentType) {
+      throw new Error(
+        `Unable to determine the input's content-type. Make sure your are passing a Blob when using provider fal-ai.`
+      );
+    }
+    if (!FAL_AI_SUPPORTED_BLOB_TYPES.includes(contentType)) {
+      throw new Error(
+        `Provider fal-ai does not support blob type ${contentType} - supported content types are: ${FAL_AI_SUPPORTED_BLOB_TYPES.join(
+          ", "
+        )}`
+      );
+    }
+    const base64audio = base64FromBytes(new Uint8Array(await blob.arrayBuffer()));
+    return {
+      ..."data" in args ? omit(args, "data") : omit(args, "inputs"),
+      audio_url: `data:${contentType};base64,${base64audio}`
+    };
+  } else {
+    return preparePayload(args);
+  }
+}
 
 // src/tasks/audio/textToSpeech.ts
 async function textToSpeech(args, options) {
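Illustrative usage of the fal-ai path handled by `buildPayload` above (the audio URL is hypothetical):

```ts
import { HfInference } from "@huggingface/inference";

const client = new HfInference("hf_...");
// Must be one of the supported content-types: audio/mpeg, audio/mp4, audio/wav, audio/x-wav.
const audio = await (await fetch("https://example.com/sample.mp3")).blob();

const result = await client.automaticSpeechRecognition({
  provider: "fal-ai",
  model: "openai/whisper-large-v3", // mapped to "fal-ai/whisper" in FAL_AI_SUPPORTED_MODEL_IDS
  data: audio,                      // buildPayload converts the Blob to an `audio_url` data URI
});
// result.text holds the transcription.
```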
@@ -638,31 +708,55 @@ async function textToSpeech(args, options) {
     ...options,
     taskHint: "text-to-speech"
   });
-
-
-    throw new InferenceOutputError("Expected Blob");
+  if (res instanceof Blob) {
+    return res;
   }
-
+  if (res && typeof res === "object") {
+    if ("output" in res) {
+      if (typeof res.output === "string") {
+        const urlResponse = await fetch(res.output);
+        const blob = await urlResponse.blob();
+        return blob;
+      } else if (Array.isArray(res.output)) {
+        const urlResponse = await fetch(res.output[0]);
+        const blob = await urlResponse.blob();
+        return blob;
+      }
+    }
+  }
+  throw new InferenceOutputError("Expected Blob or object with output");
 }
 
 // src/tasks/audio/audioToAudio.ts
 async function audioToAudio(args, options) {
-  const 
+  const payload = preparePayload(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "audio-to-audio"
   });
-
-
-
-  if (!
-    throw new InferenceOutputError("Expected Array
+  return validateOutput(res);
+}
+function validateOutput(output) {
+  if (!Array.isArray(output)) {
+    throw new InferenceOutputError("Expected Array");
   }
-
+  if (!output.every((elem) => {
+    return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
+  })) {
+    throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
+  }
+  return output;
+}
+
+// src/tasks/cv/utils.ts
+function preparePayload2(args) {
+  return "data" in args ? args : { ...omit(args, "inputs"), data: args.inputs };
 }
 
 // src/tasks/cv/imageClassification.ts
 async function imageClassification(args, options) {
-  const 
+  const payload = preparePayload2(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "image-classification"
   });
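With the new output handling above, a Replicate-style `{ output: "<url>" }` response is fetched and returned as a Blob, so callers keep receiving a Blob either way; a minimal sketch, assuming the HfInference client from the README example:

```ts
import { HfInference } from "@huggingface/inference";

const client = new HfInference("hf_...");
const speech = await client.textToSpeech({
  provider: "replicate",
  model: "OuteAI/OuteTTS-0.3-500M", // mapped to the jbilcke/oute-tts version in REPLICATE_SUPPORTED_MODEL_IDS
  inputs: "Hello world",
});
// `speech` is a Blob whether the provider answered with raw audio bytes or with an output URL.
```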
@@ -675,7 +769,8 @@ async function imageClassification(args, options) {
 
 // src/tasks/cv/imageSegmentation.ts
 async function imageSegmentation(args, options) {
-  const 
+  const payload = preparePayload2(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "image-segmentation"
   });
@@ -688,7 +783,8 @@ async function imageSegmentation(args, options) {
 
 // src/tasks/cv/imageToText.ts
 async function imageToText(args, options) {
-  const 
+  const payload = preparePayload2(args);
+  const res = (await request(payload, {
     ...options,
     taskHint: "image-to-text"
   }))?.[0];
@@ -700,7 +796,8 @@ async function imageToText(args, options) {
 
 // src/tasks/cv/objectDetection.ts
 async function objectDetection(args, options) {
-  const 
+  const payload = preparePayload2(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "object-detection"
   });
@@ -717,15 +814,13 @@ async function objectDetection(args, options) {
 
 // src/tasks/cv/textToImage.ts
 async function textToImage(args, options) {
-
-  args
-  args.
-  args.
-
-
-
-  }
-  const res = await request(args, {
+  const payload = args.provider === "together" || args.provider === "fal-ai" || args.provider === "replicate" ? {
+    ...omit(args, ["inputs", "parameters"]),
+    ...args.parameters,
+    ...args.provider !== "replicate" ? { response_format: "base64" } : void 0,
+    prompt: args.inputs
+  } : args;
+  const res = await request(payload, {
     ...options,
     taskHint: "text-to-image"
   });
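Applied to a concrete call, the textToImage payload reshaping above produces the following (parameter values are illustrative):

```ts
// What the new textToImage branch builds for together / fal-ai / replicate requests.
const args = {
  provider: "fal-ai" as const,
  model: "black-forest-labs/FLUX.1-schnell",
  inputs: "a picture of a green bird",
  parameters: { num_inference_steps: 4 },
};
// payload -> {
//   provider: "fal-ai",
//   model: "black-forest-labs/FLUX.1-schnell",
//   num_inference_steps: 4,
//   response_format: "base64",   // omitted when provider === "replicate"
//   prompt: "a picture of a green bird",
// }
```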
@@ -782,18 +877,30 @@ async function imageToImage(args, options) {
 }
 
 // src/tasks/cv/zeroShotImageClassification.ts
-async function 
-
-
-
-
-        new Uint8Array(
-
+async function preparePayload3(args) {
+  if (args.inputs instanceof Blob) {
+    return {
+      ...args,
+      inputs: {
+        image: base64FromBytes(new Uint8Array(await args.inputs.arrayBuffer()))
+      }
+    };
+  } else {
+    return {
+      ...args,
+      inputs: {
+        image: base64FromBytes(
+          new Uint8Array(
+            args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
+          )
         )
-
-  }
-  }
-
+      }
+    };
+  }
+}
+async function zeroShotImageClassification(args, options) {
+  const payload = await preparePayload3(args);
+  const res = await request(payload, {
     ...options,
     taskHint: "zero-shot-image-classification"
   });
@@ -882,17 +989,19 @@ async function questionAnswering(args, options) {
     ...options,
     taskHint: "question-answering"
   });
-  const isValidOutput =
+  const isValidOutput = Array.isArray(res) ? res.every(
+    (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
+  ) : typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";
   if (!isValidOutput) {
-    throw new InferenceOutputError("Expected {answer: string, end: number, score: number, start: number}");
+    throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
   }
-  return res;
+  return Array.isArray(res) ? res[0] : res;
 }
 
 // src/tasks/nlp/sentenceSimilarity.ts
 async function sentenceSimilarity(args, options) {
   const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : void 0;
-  const res = await request(args, {
+  const res = await request(prepareInput(args), {
     ...options,
     taskHint: "sentence-similarity",
     ...defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }
@@ -903,6 +1012,13 @@ async function sentenceSimilarity(args, options) {
   }
   return res;
 }
+function prepareInput(args) {
+  return {
+    ...omit(args, ["inputs", "parameters"]),
+    inputs: { ...omit(args.inputs, "sourceSentence") },
+    parameters: { source_sentence: args.inputs.sourceSentence, ...args.parameters }
+  };
+}
 
 // src/tasks/nlp/summarization.ts
 async function summarization(args, options) {
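The `prepareInput` helper above moves `sourceSentence` out of `inputs` and into `parameters`; on a typical call (model id and sentences are illustrative):

```ts
// Input reshaping for sentence-similarity, mirroring prepareInput above.
const args = {
  model: "sentence-transformers/all-MiniLM-L6-v2",
  inputs: {
    sourceSentence: "That is a happy person",
    sentences: ["That is a happy dog", "Today is a sunny day"],
  },
};
// prepareInput(args) -> {
//   model: "sentence-transformers/all-MiniLM-L6-v2",
//   inputs: { sentences: ["That is a happy dog", "Today is a sunny day"] },
//   parameters: { source_sentence: "That is a happy person" },
// }
```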
@@ -923,13 +1039,18 @@ async function tableQuestionAnswering(args, options) {
     ...options,
     taskHint: "table-question-answering"
   });
-  const isValidOutput =
+  const isValidOutput = Array.isArray(res) ? res.every((elem) => validate(elem)) : validate(res);
   if (!isValidOutput) {
     throw new InferenceOutputError(
       "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
     );
   }
-  return res;
+  return Array.isArray(res) ? res[0] : res;
+}
+function validate(elem) {
+  return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
+    (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
+  );
 }
 
 // src/tasks/nlp/textClassification.ts
@@ -1072,11 +1193,7 @@ async function documentQuestionAnswering(args, options) {
     inputs: {
       question: args.inputs.question,
       // convert Blob or ArrayBuffer to base64
-      image: base64FromBytes(
-        new Uint8Array(
-          args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
-        )
-      )
+      image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer()))
     }
   };
   const res = toArray(
@@ -1084,12 +1201,14 @@ async function documentQuestionAnswering(args, options) {
       ...options,
       taskHint: "document-question-answering"
     })
-  )
-  const isValidOutput =
+  );
+  const isValidOutput = Array.isArray(res) && res.every(
+    (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
+  );
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>");
   }
-  return res;
+  return res[0];
 }
 
 // src/tasks/multimodal/visualQuestionAnswering.ts
@@ -1099,22 +1218,20 @@ async function visualQuestionAnswering(args, options) {
     inputs: {
       question: args.inputs.question,
       // convert Blob or ArrayBuffer to base64
-      image: base64FromBytes(
-        new Uint8Array(
-          args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
-        )
-      )
+      image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer()))
     }
   };
-  const res =
+  const res = await request(reqArgs, {
     ...options,
     taskHint: "visual-question-answering"
-  })
-  const isValidOutput =
+  });
+  const isValidOutput = Array.isArray(res) && res.every(
+    (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
+  );
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
   }
-  return res;
+  return res[0];
 }
 
 // src/tasks/tabular/tabularRegression.ts