@huggingface/inference 2.4.0 → 2.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +23 -0
- package/dist/index.d.ts +97 -2
- package/dist/index.js +106 -9
- package/dist/index.mjs +103 -8
- package/package.json +1 -1
- package/src/lib/getDefaultTask.ts +53 -0
- package/src/lib/isUrl.ts +3 -0
- package/src/lib/makeRequestOptions.ts +20 -5
- package/src/tasks/audio/audioToAudio.ts +46 -0
- package/src/tasks/custom/request.ts +3 -1
- package/src/tasks/custom/streamingRequest.ts +3 -1
- package/src/tasks/cv/imageClassification.ts +2 -2
- package/src/tasks/cv/zeroShotImageClassification.ts +55 -0
- package/src/tasks/index.ts +2 -0
- package/src/tasks/nlp/featureExtraction.ts +11 -1
- package/src/tasks/nlp/sentenceSimilarity.ts +11 -1
- package/src/types.ts +2 -0
package/README.md
CHANGED
@@ -149,6 +149,16 @@ await hf.audioClassification({
   data: readFileSync('test/sample1.flac')
 })
 
+await hf.textToSpeech({
+  model: 'espnet/kan-bayashi_ljspeech_vits',
+  inputs: 'Hello world!'
+})
+
+await hf.audioToAudio({
+  model: 'speechbrain/sepformer-wham',
+  data: readFileSync('test/sample1.flac')
+})
+
 // Computer Vision
 
 await hf.imageClassification({
@@ -187,6 +197,16 @@ await hf.imageToImage({
   model: "lllyasviel/sd-controlnet-depth",
 });
 
+await hf.zeroShotImageClassification({
+  model: 'openai/clip-vit-large-patch14-336',
+  inputs: {
+    image: await (await fetch('https://placekitten.com/300/300')).blob()
+  },
+  parameters: {
+    candidate_labels: ['cat', 'dog']
+  }
+})
+
 // Multimodal
 
 await hf.visualQuestionAnswering({
@@ -288,6 +308,8 @@ const { generated_text } = await gpt2.textGeneration({inputs: 'The answer to the
 
 - [x] Automatic speech recognition
 - [x] Audio classification
+- [x] Text to speech
+- [x] Audio to audio
 
 ### Computer Vision
 
@@ -297,6 +319,7 @@ const { generated_text } = await gpt2.textGeneration({inputs: 'The answer to the
 - [x] Text to image
 - [x] Image to text - [demo](https://huggingface.co/spaces/huggingfacejs/image-to-text)
 - [x] Image to Image
+- [x] Zero-shot image classification
 
 ### Multimodal
 
package/dist/index.d.ts
CHANGED
@@ -26,6 +26,8 @@ export interface Options {
   fetch?: typeof fetch;
 }
 
+export type InferenceTask = "text-classification" | "feature-extraction" | "sentence-similarity";
+
 export interface BaseArgs {
   /**
    * The access token to use. Without it, you'll get rate-limited quickly.
@@ -72,6 +74,34 @@ export function audioClassification(
   args: AudioClassificationArgs,
   options?: Options
 ): Promise<AudioClassificationReturn>;
+export type AudioToAudioArgs = BaseArgs & {
+  /**
+   * Binary audio data
+   */
+  data: Blob | ArrayBuffer;
+};
+export type AudioToAudioReturn = AudioToAudioOutputValue[];
+export interface AudioToAudioOutputValue {
+  /**
+   * The label for the audio output (model specific)
+   */
+  label: string;
+
+  /**
+   * Base64 encoded audio output.
+   */
+  blob: string;
+
+  /**
+   * Content-type for blob, e.g. audio/flac
+   */
+  "content-type": string;
+}
+/**
+ * This task reads some audio input and outputs one or multiple audio files.
+ * Example model: speechbrain/sepformer-wham does audio source separation.
+ */
+export function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioReturn>;
 export type AutomaticSpeechRecognitionArgs = BaseArgs & {
   /**
    * Binary audio data
@@ -112,6 +142,8 @@ export function request<T>(
   options?: Options & {
     /** For internal HF use, which is why it's not exposed in {@link Options} */
     includeCredentials?: boolean;
+    /** When a model can be used for multiple tasks, and we want to run a non-default task */
+    task?: string | InferenceTask;
   }
 ): Promise<T>;
 /**
@@ -122,6 +154,8 @@ export function streamingRequest<T>(
   options?: Options & {
     /** For internal HF use, which is why it's not exposed in {@link Options} */
     includeCredentials?: boolean;
+    /** When a model can be used for multiple tasks, and we want to run a non-default task */
+    task?: string | InferenceTask;
   }
 ): AsyncGenerator<T>;
 export type ImageClassificationArgs = BaseArgs & {
@@ -133,11 +167,11 @@ export type ImageClassificationArgs = BaseArgs & {
 export type ImageClassificationOutput = ImageClassificationOutputValue[];
 export interface ImageClassificationOutputValue {
   /**
-   *
+   * The label for the class (model specific)
    */
   label: string;
   /**
-   *
+   * A float that represents how likely it is that the image file belongs to this class.
    */
   score: number;
 }
@@ -315,6 +349,33 @@ export type TextToImageOutput = Blob;
  * Recommended model: stabilityai/stable-diffusion-2
  */
 export function textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageOutput>;
+export type ZeroShotImageClassificationArgs = BaseArgs & {
+  inputs: {
+    /**
+     * Binary image data
+     */
+    image: Blob | ArrayBuffer;
+  };
+  parameters: {
+    /**
+     * A list of strings that are potential classes for inputs. (max 10)
+     */
+    candidate_labels: string[];
+  };
+};
+export type ZeroShotImageClassificationOutput = ZeroShotImageClassificationOutputValue[];
+export interface ZeroShotImageClassificationOutputValue {
+  label: string;
+  score: number;
+}
+/**
+ * Classify an image to specified classes.
+ * Recommended model: openai/clip-vit-large-patch14-336
+ */
+export function zeroShotImageClassification(
+  args: ZeroShotImageClassificationArgs,
+  options?: Options
+): Promise<ZeroShotImageClassificationOutput>;
 export type DocumentQuestionAnsweringArgs = BaseArgs & {
   inputs: {
     /**
@@ -931,6 +992,11 @@ export class HfInference {
     args: Omit<AudioClassificationArgs, 'accessToken'>,
     options?: Options
   ): Promise<AudioClassificationReturn>;
+  /**
+   * This task reads some audio input and outputs one or multiple audio files.
+   * Example model: speechbrain/sepformer-wham does audio source separation.
+   */
+  audioToAudio(args: Omit<AudioToAudioArgs, 'accessToken'>, options?: Options): Promise<AudioToAudioReturn>;
   /**
    * This task reads some audio input and outputs the said words within the audio files.
    * Recommended model (english language): facebook/wav2vec2-large-960h-lv60-self
@@ -952,6 +1018,8 @@ export class HfInference {
     options?: Options & {
       /** For internal HF use, which is why it's not exposed in {@link Options} */
       includeCredentials?: boolean;
+      /** When a model can be used for multiple tasks, and we want to run a non-default task */
+      task?: string | InferenceTask;
     }
   ): Promise<T>;
   /**
@@ -962,6 +1030,8 @@ export class HfInference {
     options?: Options & {
       /** For internal HF use, which is why it's not exposed in {@link Options} */
       includeCredentials?: boolean;
+      /** When a model can be used for multiple tasks, and we want to run a non-default task */
+      task?: string | InferenceTask;
     }
   ): AsyncGenerator<T>;
   /**
@@ -999,6 +1069,14 @@ export class HfInference {
    * Recommended model: stabilityai/stable-diffusion-2
    */
   textToImage(args: Omit<TextToImageArgs, 'accessToken'>, options?: Options): Promise<TextToImageOutput>;
+  /**
+   * Classify an image to specified classes.
+   * Recommended model: openai/clip-vit-large-patch14-336
+   */
+  zeroShotImageClassification(
+    args: Omit<ZeroShotImageClassificationArgs, 'accessToken'>,
+    options?: Options
+  ): Promise<ZeroShotImageClassificationOutput>;
   /**
    * Answers a question on a document image. Recommended model: impira/layoutlm-document-qa.
    */
@@ -1119,6 +1197,11 @@ export class HfInferenceEndpoint {
     args: Omit<AudioClassificationArgs, 'accessToken' | 'model'>,
     options?: Options
   ): Promise<AudioClassificationReturn>;
+  /**
+   * This task reads some audio input and outputs one or multiple audio files.
+   * Example model: speechbrain/sepformer-wham does audio source separation.
+   */
+  audioToAudio(args: Omit<AudioToAudioArgs, 'accessToken' | 'model'>, options?: Options): Promise<AudioToAudioReturn>;
   /**
    * This task reads some audio input and outputs the said words within the audio files.
    * Recommended model (english language): facebook/wav2vec2-large-960h-lv60-self
@@ -1140,6 +1223,8 @@ export class HfInferenceEndpoint {
     options?: Options & {
       /** For internal HF use, which is why it's not exposed in {@link Options} */
       includeCredentials?: boolean;
+      /** When a model can be used for multiple tasks, and we want to run a non-default task */
+      task?: string | InferenceTask;
     }
   ): Promise<T>;
   /**
@@ -1150,6 +1235,8 @@ export class HfInferenceEndpoint {
     options?: Options & {
       /** For internal HF use, which is why it's not exposed in {@link Options} */
       includeCredentials?: boolean;
+      /** When a model can be used for multiple tasks, and we want to run a non-default task */
+      task?: string | InferenceTask;
     }
   ): AsyncGenerator<T>;
   /**
@@ -1187,6 +1274,14 @@ export class HfInferenceEndpoint {
    * Recommended model: stabilityai/stable-diffusion-2
    */
   textToImage(args: Omit<TextToImageArgs, 'accessToken' | 'model'>, options?: Options): Promise<TextToImageOutput>;
+  /**
+   * Classify an image to specified classes.
+   * Recommended model: openai/clip-vit-large-patch14-336
+   */
+  zeroShotImageClassification(
+    args: Omit<ZeroShotImageClassificationArgs, 'accessToken' | 'model'>,
+    options?: Options
+  ): Promise<ZeroShotImageClassificationOutput>;
   /**
    * Answers a question on a document image. Recommended model: impira/layoutlm-document-qa.
    */
package/dist/index.js
CHANGED
@@ -25,6 +25,7 @@ __export(src_exports, {
   HfInferenceEndpoint: () => HfInferenceEndpoint,
   InferenceOutputError: () => InferenceOutputError,
   audioClassification: () => audioClassification,
+  audioToAudio: () => audioToAudio,
   automaticSpeechRecognition: () => automaticSpeechRecognition,
   conversational: () => conversational,
   documentQuestionAnswering: () => documentQuestionAnswering,
@@ -51,7 +52,8 @@ __export(src_exports, {
   tokenClassification: () => tokenClassification,
   translation: () => translation,
   visualQuestionAnswering: () => visualQuestionAnswering,
-  zeroShotClassification: () => zeroShotClassification
+  zeroShotClassification: () => zeroShotClassification,
+  zeroShotImageClassification: () => zeroShotImageClassification
 });
 module.exports = __toCommonJS(src_exports);
 
@@ -59,6 +61,7 @@ module.exports = __toCommonJS(src_exports);
 var tasks_exports = {};
 __export(tasks_exports, {
   audioClassification: () => audioClassification,
+  audioToAudio: () => audioToAudio,
   automaticSpeechRecognition: () => automaticSpeechRecognition,
   conversational: () => conversational,
   documentQuestionAnswering: () => documentQuestionAnswering,
@@ -85,13 +88,20 @@ __export(tasks_exports, {
   tokenClassification: () => tokenClassification,
   translation: () => translation,
   visualQuestionAnswering: () => visualQuestionAnswering,
-  zeroShotClassification: () => zeroShotClassification
+  zeroShotClassification: () => zeroShotClassification,
+  zeroShotImageClassification: () => zeroShotImageClassification
 });
 
+// src/lib/isUrl.ts
+function isUrl(modelOrUrl) {
+  return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
+}
+
 // src/lib/makeRequestOptions.ts
-var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co
+var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
 function makeRequestOptions(args, options) {
   const { model, accessToken, ...otherArgs } = args;
+  const { task, includeCredentials, ...otherOptions } = options ?? {};
   const headers = {};
   if (accessToken) {
     headers["Authorization"] = `Bearer ${accessToken}`;
@@ -110,15 +120,23 @@ function makeRequestOptions(args, options) {
       headers["X-Load-Model"] = "0";
     }
   }
-  const url =
+  const url = (() => {
+    if (isUrl(model)) {
+      return model;
+    }
+    if (task) {
+      return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
+    }
+    return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
+  })();
   const info = {
     headers,
     method: "POST",
     body: binary ? args.data : JSON.stringify({
       ...otherArgs,
-      options
+      options: options && otherOptions
     }),
-    credentials:
+    credentials: includeCredentials ? "include" : "same-origin"
   };
   return { url, info };
 }
@@ -350,6 +368,18 @@ async function textToSpeech(args, options) {
   return res;
 }
 
+// src/tasks/audio/audioToAudio.ts
+async function audioToAudio(args, options) {
+  const res = await request(args, options);
+  const isValidOutput = Array.isArray(res) && res.every(
+    (x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string"
+  );
+  if (!isValidOutput) {
+    throw new InferenceOutputError("Expected Array<{label: string, blob: string, content-type: string}>");
+  }
+  return res;
+}
+
 // src/tasks/cv/imageClassification.ts
 async function imageClassification(args, options) {
   const res = await request(args, options);
@@ -445,6 +475,26 @@ async function imageToImage(args, options) {
   return res;
 }
 
+// src/tasks/cv/zeroShotImageClassification.ts
+async function zeroShotImageClassification(args, options) {
+  const reqArgs = {
+    ...args,
+    inputs: {
+      image: base64FromBytes(
+        new Uint8Array(
+          args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
+        )
+      )
+    }
+  };
+  const res = await request(reqArgs, options);
+  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
+  if (!isValidOutput) {
+    throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
+  }
+  return res;
+}
+
 // src/tasks/nlp/conversational.ts
 async function conversational(args, options) {
   const res = await request(args, options);
@@ -457,9 +507,47 @@ async function conversational(args, options) {
   return res;
 }
 
+// src/lib/getDefaultTask.ts
+var taskCache = /* @__PURE__ */ new Map();
+var CACHE_DURATION = 10 * 60 * 1e3;
+var MAX_CACHE_ITEMS = 1e3;
+var HF_HUB_URL = "https://huggingface.co";
+async function getDefaultTask(model, accessToken) {
+  if (isUrl(model)) {
+    return null;
+  }
+  const key = `${model}:${accessToken}`;
+  let cachedTask = taskCache.get(key);
+  if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
+    taskCache.delete(key);
+    cachedTask = void 0;
+  }
+  if (cachedTask === void 0) {
+    const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
+      headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
+    }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
+    if (!modelTask) {
+      return null;
+    }
+    cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
+    taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
+    if (taskCache.size > MAX_CACHE_ITEMS) {
+      taskCache.delete(taskCache.keys().next().value);
+    }
+  }
+  return cachedTask.task;
+}
+
 // src/tasks/nlp/featureExtraction.ts
 async function featureExtraction(args, options) {
-  const
+  const defaultTask = await getDefaultTask(args.model, args.accessToken);
+  const res = await request(
+    args,
+    defaultTask === "sentence-similarity" ? {
+      ...options,
+      task: "feature-extraction"
+    } : options
+  );
   let isValidOutput = true;
   const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
     if (curDepth > maxDepth)
@@ -503,7 +591,14 @@ async function questionAnswering(args, options) {
 
 // src/tasks/nlp/sentenceSimilarity.ts
 async function sentenceSimilarity(args, options) {
-  const
+  const defaultTask = await getDefaultTask(args.model, args.accessToken);
+  const res = await request(
+    args,
+    defaultTask === "feature-extraction" ? {
+      ...options,
+      task: "sentence-similarity"
+    } : options
+  );
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected number[]");
@@ -715,6 +810,7 @@ var HfInferenceEndpoint = class {
   HfInferenceEndpoint,
   InferenceOutputError,
   audioClassification,
+  audioToAudio,
   automaticSpeechRecognition,
   conversational,
   documentQuestionAnswering,
@@ -741,5 +837,6 @@ var HfInferenceEndpoint = class {
   tokenClassification,
   translation,
   visualQuestionAnswering,
-  zeroShotClassification
+  zeroShotClassification,
+  zeroShotImageClassification
 });
package/dist/index.mjs
CHANGED
@@ -9,6 +9,7 @@ var __export = (target, all) => {
 var tasks_exports = {};
 __export(tasks_exports, {
   audioClassification: () => audioClassification,
+  audioToAudio: () => audioToAudio,
   automaticSpeechRecognition: () => automaticSpeechRecognition,
   conversational: () => conversational,
   documentQuestionAnswering: () => documentQuestionAnswering,
@@ -35,13 +36,20 @@ __export(tasks_exports, {
   tokenClassification: () => tokenClassification,
   translation: () => translation,
   visualQuestionAnswering: () => visualQuestionAnswering,
-  zeroShotClassification: () => zeroShotClassification
+  zeroShotClassification: () => zeroShotClassification,
+  zeroShotImageClassification: () => zeroShotImageClassification
 });
 
+// src/lib/isUrl.ts
+function isUrl(modelOrUrl) {
+  return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
+}
+
 // src/lib/makeRequestOptions.ts
-var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co
+var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
 function makeRequestOptions(args, options) {
   const { model, accessToken, ...otherArgs } = args;
+  const { task, includeCredentials, ...otherOptions } = options ?? {};
   const headers = {};
   if (accessToken) {
     headers["Authorization"] = `Bearer ${accessToken}`;
@@ -60,15 +68,23 @@ function makeRequestOptions(args, options) {
       headers["X-Load-Model"] = "0";
     }
   }
-  const url =
+  const url = (() => {
+    if (isUrl(model)) {
+      return model;
+    }
+    if (task) {
+      return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
+    }
+    return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
+  })();
   const info = {
     headers,
     method: "POST",
     body: binary ? args.data : JSON.stringify({
      ...otherArgs,
-      options
+      options: options && otherOptions
     }),
-    credentials:
+    credentials: includeCredentials ? "include" : "same-origin"
   };
   return { url, info };
 }
@@ -300,6 +316,18 @@ async function textToSpeech(args, options) {
   return res;
 }
 
+// src/tasks/audio/audioToAudio.ts
+async function audioToAudio(args, options) {
+  const res = await request(args, options);
+  const isValidOutput = Array.isArray(res) && res.every(
+    (x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string"
+  );
+  if (!isValidOutput) {
+    throw new InferenceOutputError("Expected Array<{label: string, blob: string, content-type: string}>");
+  }
+  return res;
+}
+
 // src/tasks/cv/imageClassification.ts
 async function imageClassification(args, options) {
   const res = await request(args, options);
@@ -395,6 +423,26 @@ async function imageToImage(args, options) {
   return res;
 }
 
+// src/tasks/cv/zeroShotImageClassification.ts
+async function zeroShotImageClassification(args, options) {
+  const reqArgs = {
+    ...args,
+    inputs: {
+      image: base64FromBytes(
+        new Uint8Array(
+          args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
+        )
+      )
+    }
+  };
+  const res = await request(reqArgs, options);
+  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
+  if (!isValidOutput) {
+    throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
+  }
+  return res;
+}
+
 // src/tasks/nlp/conversational.ts
 async function conversational(args, options) {
   const res = await request(args, options);
@@ -407,9 +455,47 @@ async function conversational(args, options) {
   return res;
 }
 
+// src/lib/getDefaultTask.ts
+var taskCache = /* @__PURE__ */ new Map();
+var CACHE_DURATION = 10 * 60 * 1e3;
+var MAX_CACHE_ITEMS = 1e3;
+var HF_HUB_URL = "https://huggingface.co";
+async function getDefaultTask(model, accessToken) {
+  if (isUrl(model)) {
+    return null;
+  }
+  const key = `${model}:${accessToken}`;
+  let cachedTask = taskCache.get(key);
+  if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
+    taskCache.delete(key);
+    cachedTask = void 0;
+  }
+  if (cachedTask === void 0) {
+    const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
+      headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
+    }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
+    if (!modelTask) {
+      return null;
+    }
+    cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
+    taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
+    if (taskCache.size > MAX_CACHE_ITEMS) {
+      taskCache.delete(taskCache.keys().next().value);
+    }
+  }
+  return cachedTask.task;
+}
+
 // src/tasks/nlp/featureExtraction.ts
 async function featureExtraction(args, options) {
-  const
+  const defaultTask = await getDefaultTask(args.model, args.accessToken);
+  const res = await request(
+    args,
+    defaultTask === "sentence-similarity" ? {
+      ...options,
+      task: "feature-extraction"
+    } : options
+  );
   let isValidOutput = true;
   const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
     if (curDepth > maxDepth)
@@ -453,7 +539,14 @@ async function questionAnswering(args, options) {
 
 // src/tasks/nlp/sentenceSimilarity.ts
 async function sentenceSimilarity(args, options) {
-  const
+  const defaultTask = await getDefaultTask(args.model, args.accessToken);
+  const res = await request(
+    args,
+    defaultTask === "feature-extraction" ? {
+      ...options,
+      task: "sentence-similarity"
+    } : options
+  );
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected number[]");
@@ -664,6 +757,7 @@ export {
   HfInferenceEndpoint,
   InferenceOutputError,
   audioClassification,
+  audioToAudio,
   automaticSpeechRecognition,
   conversational,
   documentQuestionAnswering,
@@ -690,5 +784,6 @@ export {
   tokenClassification,
   translation,
   visualQuestionAnswering,
-  zeroShotClassification
+  zeroShotClassification,
+  zeroShotImageClassification
 };
package/package.json
CHANGED
(1 line changed: the version field, bumped from 2.4.0 to 2.5.1)
package/src/lib/getDefaultTask.ts
ADDED
@@ -0,0 +1,53 @@
+import { isUrl } from "./isUrl";
+
+/**
+ * We want to make calls to the huggingface hub the least possible, eg if
+ * someone is calling the inference API 1000 times per second, we don't want
+ * to make 1000 calls to the hub to get the task name.
+ */
+const taskCache = new Map<string, { task: string; date: Date }>();
+const CACHE_DURATION = 10 * 60 * 1000;
+const MAX_CACHE_ITEMS = 1000;
+const HF_HUB_URL = "https://huggingface.co";
+
+/**
+ * Get the default task. Use a LRU cache of 1000 items with 10 minutes expiration
+ * to avoid making too many calls to the HF hub.
+ *
+ * @returns The default task for the model, or `null` if it was impossible to get it
+ */
+export async function getDefaultTask(model: string, accessToken: string | undefined): Promise<string | null> {
+  if (isUrl(model)) {
+    return null;
+  }
+
+  const key = `${model}:${accessToken}`;
+  let cachedTask = taskCache.get(key);
+
+  if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
+    taskCache.delete(key);
+    cachedTask = undefined;
+  }
+
+  if (cachedTask === undefined) {
+    const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
+      headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {},
+    })
+      .then((resp) => resp.json())
+      .then((json) => json.pipeline_tag)
+      .catch(() => null);
+
+    if (!modelTask) {
+      return null;
+    }
+
+    cachedTask = { task: modelTask, date: new Date() };
+    taskCache.set(key, { task: modelTask, date: new Date() });
+
+    if (taskCache.size > MAX_CACHE_ITEMS) {
+      taskCache.delete(taskCache.keys().next().value);
+    }
+  }
+
+  return cachedTask.task;
+}
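A minimal usage sketch for this new helper, as if called from inside src/lib (the model name is only an illustrative assumption; the returned pipeline tag depends on what the Hub reports):

  import { getDefaultTask } from "./getDefaultTask";

  // The first call fetches https://huggingface.co/api/models/<model>?expand[]=pipeline_tag;
  // repeat calls for the same model/token pair within 10 minutes are served from the in-memory cache.
  const task = await getDefaultTask("sentence-transformers/all-MiniLM-L6-v2", undefined);
  // task is e.g. "sentence-similarity", or null when the model is a URL or the lookup fails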
package/src/lib/isUrl.ts
ADDED
@@ -0,0 +1,3 @@
+export function isUrl(modelOrUrl: string): boolean {
+  return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
+}

package/src/lib/makeRequestOptions.ts
CHANGED
@@ -1,6 +1,7 @@
-import type { Options, RequestArgs } from "../types";
+import type { InferenceTask, Options, RequestArgs } from "../types";
+import { isUrl } from "./isUrl";
 
-const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co
+const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
 
 /**
  * Helper that prepares request arguments
@@ -13,9 +14,12 @@ export function makeRequestOptions(
   options?: Options & {
     /** For internal HF use, which is why it's not exposed in {@link Options} */
     includeCredentials?: boolean;
+    /** When a model can be used for multiple tasks, and we want to run a non-default task */
+    task?: string | InferenceTask;
   }
 ): { url: string; info: RequestInit } {
   const { model, accessToken, ...otherArgs } = args;
+  const { task, includeCredentials, ...otherOptions } = options ?? {};
 
   const headers: Record<string, string> = {};
   if (accessToken) {
@@ -38,7 +42,18 @@ export function makeRequestOptions(
     }
   }
 
-  const url =
+  const url = (() => {
+    if (isUrl(model)) {
+      return model;
+    }
+
+    if (task) {
+      return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
+    }
+
+    return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
+  })();
+
   const info: RequestInit = {
     headers,
     method: "POST",
@@ -46,9 +61,9 @@ export function makeRequestOptions(
       ? args.data
       : JSON.stringify({
           ...otherArgs,
-          options,
+          options: options && otherOptions,
         }),
-    credentials:
+    credentials: includeCredentials ? "include" : "same-origin",
   };
 
   return { url, info };
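A small sketch of the URL routing introduced here (the model names and token are illustrative assumptions, not part of the diff):

  import { makeRequestOptions } from "./makeRequestOptions";

  // Default: route to the model endpoint
  const a = makeRequestOptions({ accessToken: "hf_xxx", model: "bert-base-uncased", inputs: "Hello" });
  // a.url === "https://api-inference.huggingface.co/models/bert-base-uncased"

  // Explicit non-default task: route to the pipeline endpoint
  const b = makeRequestOptions(
    { accessToken: "hf_xxx", model: "bert-base-uncased", inputs: "Hello" },
    { task: "feature-extraction" }
  );
  // b.url === "https://api-inference.huggingface.co/pipeline/feature-extraction/bert-base-uncased"

  // A model given as a URL (for example an Inference Endpoint) is used as-is
  const c = makeRequestOptions({ accessToken: "hf_xxx", model: "https://my-endpoint.example", inputs: "Hello" });
  // c.url === "https://my-endpoint.example"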
package/src/tasks/audio/audioToAudio.ts
ADDED
@@ -0,0 +1,46 @@
+import { InferenceOutputError } from "../../lib/InferenceOutputError";
+import type { BaseArgs, Options } from "../../types";
+import { request } from "../custom/request";
+
+export type AudioToAudioArgs = BaseArgs & {
+  /**
+   * Binary audio data
+   */
+  data: Blob | ArrayBuffer;
+};
+
+export interface AudioToAudioOutputValue {
+  /**
+   * The label for the audio output (model specific)
+   */
+  label: string;
+
+  /**
+   * Base64 encoded audio output.
+   */
+  blob: string;
+
+  /**
+   * Content-type for blob, e.g. audio/flac
+   */
+  "content-type": string;
+}
+
+export type AudioToAudioReturn = AudioToAudioOutputValue[];
+
+/**
+ * This task reads some audio input and outputs one or multiple audio files.
+ * Example model: speechbrain/sepformer-wham does audio source separation.
+ */
+export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioReturn> {
+  const res = await request<AudioToAudioReturn>(args, options);
+  const isValidOutput =
+    Array.isArray(res) &&
+    res.every(
+      (x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string"
+    );
+  if (!isValidOutput) {
+    throw new InferenceOutputError("Expected Array<{label: string, blob: string, content-type: string}>");
+  }
+  return res;
+}
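A hedged usage sketch for the new task in Node (the model and file paths are illustrative assumptions; decoding relies on the base64 blob field documented above):

  import { readFileSync, writeFileSync } from "node:fs";
  import { HfInference } from "@huggingface/inference";

  const hf = new HfInference(process.env.HF_TOKEN);

  const outputs = await hf.audioToAudio({
    model: "speechbrain/sepformer-wham",
    data: readFileSync("test/sample1.flac"),
  });

  // Each entry carries a label, a base64-encoded audio payload and its content-type.
  for (const { label, blob, "content-type": contentType } of outputs) {
    const extension = contentType.split("/")[1] ?? "bin";
    writeFileSync(`${label}.${extension}`, Buffer.from(blob, "base64"));
  }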
package/src/tasks/custom/request.ts
CHANGED
@@ -1,4 +1,4 @@
-import type { Options, RequestArgs } from "../../types";
+import type { InferenceTask, Options, RequestArgs } from "../../types";
 import { makeRequestOptions } from "../../lib/makeRequestOptions";
 
 /**
@@ -9,6 +9,8 @@ export async function request<T>(
   options?: Options & {
     /** For internal HF use, which is why it's not exposed in {@link Options} */
     includeCredentials?: boolean;
+    /** When a model can be used for multiple tasks, and we want to run a non-default task */
+    task?: string | InferenceTask;
   }
 ): Promise<T> {
   const { url, info } = makeRequestOptions(args, options);
package/src/tasks/custom/streamingRequest.ts
CHANGED
@@ -1,4 +1,4 @@
-import type { Options, RequestArgs } from "../../types";
+import type { InferenceTask, Options, RequestArgs } from "../../types";
 import { makeRequestOptions } from "../../lib/makeRequestOptions";
 import type { EventSourceMessage } from "../../vendor/fetch-event-source/parse";
 import { getLines, getMessages } from "../../vendor/fetch-event-source/parse";
@@ -11,6 +11,8 @@ export async function* streamingRequest<T>(
   options?: Options & {
     /** For internal HF use, which is why it's not exposed in {@link Options} */
     includeCredentials?: boolean;
+    /** When a model can be used for multiple tasks, and we want to run a non-default task */
+    task?: string | InferenceTask;
   }
 ): AsyncGenerator<T> {
   const { url, info } = makeRequestOptions({ ...args, stream: true }, options);
package/src/tasks/cv/imageClassification.ts
CHANGED
@@ -11,11 +11,11 @@ export type ImageClassificationArgs = BaseArgs & {
 
 export interface ImageClassificationOutputValue {
   /**
-   *
+   * The label for the class (model specific)
    */
   label: string;
   /**
-   *
+   * A float that represents how likely it is that the image file belongs to this class.
    */
   score: number;
 }
package/src/tasks/cv/zeroShotImageClassification.ts
ADDED
@@ -0,0 +1,55 @@
+import { InferenceOutputError } from "../../lib/InferenceOutputError";
+import type { BaseArgs, Options } from "../../types";
+import { request } from "../custom/request";
+import type { RequestArgs } from "../../types";
+import { base64FromBytes } from "../../../../shared";
+
+export type ZeroShotImageClassificationArgs = BaseArgs & {
+  inputs: {
+    /**
+     * Binary image data
+     */
+    image: Blob | ArrayBuffer;
+  };
+  parameters: {
+    /**
+     * A list of strings that are potential classes for inputs. (max 10)
+     */
+    candidate_labels: string[];
+  };
+};
+
+export interface ZeroShotImageClassificationOutputValue {
+  label: string;
+  score: number;
+}
+
+export type ZeroShotImageClassificationOutput = ZeroShotImageClassificationOutputValue[];
+
+/**
+ * Classify an image to specified classes.
+ * Recommended model: openai/clip-vit-large-patch14-336
+ */
+export async function zeroShotImageClassification(
+  args: ZeroShotImageClassificationArgs,
+  options?: Options
+): Promise<ZeroShotImageClassificationOutput> {
+  const reqArgs: RequestArgs = {
+    ...args,
+    inputs: {
+      image: base64FromBytes(
+        new Uint8Array(
+          args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
+        )
+      ),
+    },
+  } as RequestArgs;
+
+  const res = await request<ZeroShotImageClassificationOutput>(reqArgs, options);
+  const isValidOutput =
+    Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
+  if (!isValidOutput) {
+    throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
+  }
+  return res;
+}
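A usage sketch mirroring the README example above (the image URL, labels and scores are illustrative assumptions):

  import { HfInference } from "@huggingface/inference";

  const hf = new HfInference(process.env.HF_TOKEN);

  const result = await hf.zeroShotImageClassification({
    model: "openai/clip-vit-large-patch14-336",
    inputs: {
      // Any Blob or ArrayBuffer works; it is base64-encoded before being sent.
      image: await (await fetch("https://placekitten.com/300/300")).blob(),
    },
    parameters: {
      candidate_labels: ["cat", "dog"],
    },
  });
  // result: e.g. [{ label: "cat", score: 0.98 }, { label: "dog", score: 0.02 }]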
package/src/tasks/index.ts
CHANGED
@@ -6,6 +6,7 @@ export * from "./custom/streamingRequest";
 export * from "./audio/audioClassification";
 export * from "./audio/automaticSpeechRecognition";
 export * from "./audio/textToSpeech";
+export * from "./audio/audioToAudio";
 
 // Computer Vision tasks
 export * from "./cv/imageClassification";
@@ -14,6 +15,7 @@ export * from "./cv/imageToText";
 export * from "./cv/objectDetection";
 export * from "./cv/textToImage";
 export * from "./cv/imageToImage";
+export * from "./cv/zeroShotImageClassification";
 
 // Natural Language Processing tasks
 export * from "./nlp/conversational";
package/src/tasks/nlp/featureExtraction.ts
CHANGED
@@ -1,4 +1,5 @@
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
+import { getDefaultTask } from "../../lib/getDefaultTask";
 import type { BaseArgs, Options } from "../../types";
 import { request } from "../custom/request";
 
@@ -24,7 +25,16 @@ export async function featureExtraction(
   args: FeatureExtractionArgs,
   options?: Options
 ): Promise<FeatureExtractionOutput> {
-  const
+  const defaultTask = await getDefaultTask(args.model, args.accessToken);
+  const res = await request<FeatureExtractionOutput>(
+    args,
+    defaultTask === "sentence-similarity"
+      ? {
+          ...options,
+          task: "feature-extraction",
+        }
+      : options
+  );
   let isValidOutput = true;
 
   const isNumArrayRec = (arr: unknown[], maxDepth: number, curDepth = 0): boolean => {
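The effect of this change, sketched below (the model name is an illustrative assumption): when a model's default Hub pipeline is sentence-similarity, featureExtraction now explicitly targets the feature-extraction pipeline instead of the model's default one.

  import { HfInference } from "@huggingface/inference";

  const hf = new HfInference(process.env.HF_TOKEN);

  // The Hub reports this model's default task as sentence-similarity, so the call
  // is routed to /pipeline/feature-extraction/<model> under the hood.
  const embedding = await hf.featureExtraction({
    model: "sentence-transformers/paraphrase-albert-small-v2",
    inputs: "That is a happy person",
  });
  // embedding: number[] (or nested number arrays, depending on the model)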
package/src/tasks/nlp/sentenceSimilarity.ts
CHANGED
@@ -1,4 +1,5 @@
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
+import { getDefaultTask } from "../../lib/getDefaultTask";
 import type { BaseArgs, Options } from "../../types";
 import { request } from "../custom/request";
 
@@ -24,7 +25,16 @@ export async function sentenceSimilarity(
   args: SentenceSimilarityArgs,
   options?: Options
 ): Promise<SentenceSimilarityOutput> {
-  const
+  const defaultTask = await getDefaultTask(args.model, args.accessToken);
+  const res = await request<SentenceSimilarityOutput>(
+    args,
+    defaultTask === "feature-extraction"
+      ? {
+          ...options,
+          task: "sentence-similarity",
+        }
+      : options
+  );
 
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
   if (!isValidOutput) {
package/src/types.ts
CHANGED
@@ -26,6 +26,8 @@ export interface Options {
   fetch?: typeof fetch;
 }
 
+export type InferenceTask = "text-classification" | "feature-extraction" | "sentence-similarity";
+
 export interface BaseArgs {
   /**
    * The access token to use. Without it, you'll get rate-limited quickly.