@huggingface/inference 2.3.3 → 2.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -149,6 +149,16 @@ await hf.audioClassification({
149
149
  data: readFileSync('test/sample1.flac')
150
150
  })
151
151
 
152
+ await hf.textToSpeech({
153
+ model: 'espnet/kan-bayashi_ljspeech_vits',
154
+ inputs: 'Hello world!'
155
+ })
156
+
157
+ await hf.audioToAudio({
158
+ model: 'speechbrain/sepformer-wham',
159
+ data: readFileSync('test/sample1.flac')
160
+ })
161
+
152
162
  // Computer Vision
153
163
 
154
164
  await hf.imageClassification({
@@ -187,6 +197,16 @@ await hf.imageToImage({
187
197
  model: "lllyasviel/sd-controlnet-depth",
188
198
  });
189
199
 
200
+ await hf.zeroShotImageClassification({
201
+ model: 'openai/clip-vit-large-patch14-336',
202
+ inputs: {
203
+ image: await (await fetch('https://placekitten.com/300/300')).blob()
204
+ },
205
+ parameters: {
206
+ candidate_labels: ['cat', 'dog']
207
+ }
208
+ })
209
+
190
210
  // Multimodal
191
211
 
192
212
  await hf.visualQuestionAnswering({
@@ -211,12 +231,31 @@ await hf.tabularRegression({
211
231
  model: "scikit-learn/Fish-Weight",
212
232
  inputs: {
213
233
  data: {
214
- "Height":["11.52", "12.48", "12.3778"],
215
- "Length1":["23.2", "24", "23.9"],
216
- "Length2":["25.4", "26.3", "26.5"],
217
- "Length3":["30", "31.2", "31.1"],
218
- "Species":["Bream", "Bream", "Bream"],
219
- "Width":["4.02", "4.3056", "4.6961"]
234
+ "Height": ["11.52", "12.48", "12.3778"],
235
+ "Length1": ["23.2", "24", "23.9"],
236
+ "Length2": ["25.4", "26.3", "26.5"],
237
+ "Length3": ["30", "31.2", "31.1"],
238
+ "Species": ["Bream", "Bream", "Bream"],
239
+ "Width": ["4.02", "4.3056", "4.6961"]
240
+ },
241
+ },
242
+ })
243
+
244
+ await hf.tabularClassification({
245
+ model: "vvmnnnkv/wine-quality",
246
+ inputs: {
247
+ data: {
248
+ "fixed_acidity": ["7.4", "7.8", "10.3"],
249
+ "volatile_acidity": ["0.7", "0.88", "0.32"],
250
+ "citric_acid": ["0", "0", "0.45"],
251
+ "residual_sugar": ["1.9", "2.6", "6.4"],
252
+ "chlorides": ["0.076", "0.098", "0.073"],
253
+ "free_sulfur_dioxide": ["11", "25", "5"],
254
+ "total_sulfur_dioxide": ["34", "67", "13"],
255
+ "density": ["0.9978", "0.9968", "0.9976"],
256
+ "pH": ["3.51", "3.2", "3.23"],
257
+ "sulphates": ["0.56", "0.68", "0.82"],
258
+ "alcohol": ["9.4", "9.8", "12.6"]
220
259
  },
221
260
  },
222
261
  })
@@ -269,6 +308,8 @@ const { generated_text } = await gpt2.textGeneration({inputs: 'The answer to the
269
308
 
270
309
  - [x] Automatic speech recognition
271
310
  - [x] Audio classification
311
+ - [x] Text to speech
312
+ - [x] Audio to audio
272
313
 
273
314
  ### Computer Vision
274
315
 
@@ -278,6 +319,7 @@ const { generated_text } = await gpt2.textGeneration({inputs: 'The answer to the
278
319
  - [x] Text to image
279
320
  - [x] Image to text - [demo](https://huggingface.co/spaces/huggingfacejs/image-to-text)
280
321
  - [x] Image to Image
322
+ - [x] Zero-shot image classification
281
323
 
282
324
  ### Multimodal
283
325
 
@@ -287,6 +329,7 @@ const { generated_text } = await gpt2.textGeneration({inputs: 'The answer to the
287
329
  ### Tabular
288
330
 
289
331
  - [x] Tabular regression
332
+ - [x] Tabular classification
290
333
 
291
334
  ## Tree-shaking
292
335
 
package/dist/index.d.ts CHANGED
@@ -26,6 +26,8 @@ export interface Options {
26
26
  fetch?: typeof fetch;
27
27
  }
28
28
 
29
+ export type InferenceTask = "text-classification" | "feature-extraction" | "sentence-similarity";
30
+
29
31
  export interface BaseArgs {
30
32
  /**
31
33
  * The access token to use. Without it, you'll get rate-limited quickly.
@@ -72,6 +74,34 @@ export function audioClassification(
72
74
  args: AudioClassificationArgs,
73
75
  options?: Options
74
76
  ): Promise<AudioClassificationReturn>;
77
+ export type AudioToAudioArgs = BaseArgs & {
78
+ /**
79
+ * Binary audio data
80
+ */
81
+ data: Blob | ArrayBuffer;
82
+ };
83
+ export type AudioToAudioReturn = AudioToAudioOutputValue[];
84
+ export interface AudioToAudioOutputValue {
85
+ /**
86
+ * The label for the audio output (model specific)
87
+ */
88
+ label: string;
89
+
90
+ /**
91
+ * Base64 encoded audio output.
92
+ */
93
+ blob: string;
94
+
95
+ /**
96
+ * Content-type for blob, e.g. audio/flac
97
+ */
98
+ "content-type": string;
99
+ }
100
+ /**
101
+ * This task reads some audio input and outputs one or multiple audio files.
102
+ * Example model: speechbrain/sepformer-wham does audio source separation.
103
+ */
104
+ export function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioReturn>;
75
105
  export type AutomaticSpeechRecognitionArgs = BaseArgs & {
76
106
  /**
77
107
  * Binary audio data
@@ -112,6 +142,8 @@ export function request<T>(
112
142
  options?: Options & {
113
143
  /** For internal HF use, which is why it's not exposed in {@link Options} */
114
144
  includeCredentials?: boolean;
145
+ /** When a model can be used for multiple tasks, and we want to run a non-default task */
146
+ task?: string | InferenceTask;
115
147
  }
116
148
  ): Promise<T>;
117
149
  /**
@@ -122,6 +154,8 @@ export function streamingRequest<T>(
122
154
  options?: Options & {
123
155
  /** For internal HF use, which is why it's not exposed in {@link Options} */
124
156
  includeCredentials?: boolean;
157
+ /** When a model can be used for multiple tasks, and we want to run a non-default task */
158
+ task?: string | InferenceTask;
125
159
  }
126
160
  ): AsyncGenerator<T>;
127
161
  export type ImageClassificationArgs = BaseArgs & {
@@ -315,6 +349,33 @@ export type TextToImageOutput = Blob;
315
349
  * Recommended model: stabilityai/stable-diffusion-2
316
350
  */
317
351
  export function textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageOutput>;
352
+ export type ZeroShotImageClassificationArgs = BaseArgs & {
353
+ inputs: {
354
+ /**
355
+ * Binary image data
356
+ */
357
+ image: Blob | ArrayBuffer;
358
+ };
359
+ parameters: {
360
+ /**
361
+ * A list of strings that are potential classes for inputs. (max 10)
362
+ */
363
+ candidate_labels: string[];
364
+ };
365
+ };
366
+ export type ZeroShotImageClassificationOutput = ZeroShotImageClassificationOutputValue[];
367
+ export interface ZeroShotImageClassificationOutputValue {
368
+ label: string;
369
+ score: number;
370
+ }
371
+ /**
372
+ * Classify an image to specified classes.
373
+ * Recommended model: openai/clip-vit-large-patch14-336
374
+ */
375
+ export function zeroShotImageClassification(
376
+ args: ZeroShotImageClassificationArgs,
377
+ options?: Options
378
+ ): Promise<ZeroShotImageClassificationOutput>;
318
379
  export type DocumentQuestionAnsweringArgs = BaseArgs & {
319
380
  inputs: {
320
381
  /**
@@ -448,9 +509,9 @@ export type FeatureExtractionArgs = BaseArgs & {
448
509
  inputs: string | string[];
449
510
  };
450
511
  /**
451
- * Returned values are a list of floats, or a list of list of floats (depending on if you sent a string or a list of string, and if the automatic reduction, usually mean_pooling for instance was applied for you or not. This should be explained on the model's README.
512
+ * Returned values are a list of floats, or a list of list of floats, or a list of list of list of floats (depending on if you sent a string or a list of string, and if the automatic reduction, usually mean_pooling for instance was applied for you or not. This should be explained on the model's README.
452
513
  */
453
- export type FeatureExtractionOutput = (number | number[])[];
514
+ export type FeatureExtractionOutput = (number | number[] | number[][])[];
454
515
  /**
455
516
  * This task reads some text and outputs raw float values, that are usually consumed as part of a semantic database/semantic search.
456
517
  */
@@ -875,6 +936,27 @@ export function zeroShotClassification(
875
936
  args: ZeroShotClassificationArgs,
876
937
  options?: Options
877
938
  ): Promise<ZeroShotClassificationOutput>;
939
+ export type TabularClassificationArgs = BaseArgs & {
940
+ inputs: {
941
+ /**
942
+ * A table of data represented as a dict of list where entries are headers and the lists are all the values, all lists must have the same size.
943
+ */
944
+ data: Record<string, string[]>;
945
+ };
946
+ };
947
+ /**
948
+ * A list of predicted labels for each row
949
+ */
950
+ export type TabularClassificationOutput = number[];
951
+ /**
952
+ * Predicts target label for a given set of features in tabular form.
953
+ * Typically, you will want to train a classification model on your training data and use it with your new data of the same format.
954
+ * Example model: vvmnnnkv/wine-quality
955
+ */
956
+ export function tabularClassification(
957
+ args: TabularClassificationArgs,
958
+ options?: Options
959
+ ): Promise<TabularClassificationOutput>;
878
960
  export type TabularRegressionArgs = BaseArgs & {
879
961
  inputs: {
880
962
  /**
@@ -910,6 +992,11 @@ export class HfInference {
910
992
  args: Omit<AudioClassificationArgs, 'accessToken'>,
911
993
  options?: Options
912
994
  ): Promise<AudioClassificationReturn>;
995
+ /**
996
+ * This task reads some audio input and outputs one or multiple audio files.
997
+ * Example model: speechbrain/sepformer-wham does audio source separation.
998
+ */
999
+ audioToAudio(args: Omit<AudioToAudioArgs, 'accessToken'>, options?: Options): Promise<AudioToAudioReturn>;
913
1000
  /**
914
1001
  * This task reads some audio input and outputs the said words within the audio files.
915
1002
  * Recommended model (english language): facebook/wav2vec2-large-960h-lv60-self
@@ -931,6 +1018,8 @@ export class HfInference {
931
1018
  options?: Options & {
932
1019
  /** For internal HF use, which is why it's not exposed in {@link Options} */
933
1020
  includeCredentials?: boolean;
1021
+ /** When a model can be used for multiple tasks, and we want to run a non-default task */
1022
+ task?: string | InferenceTask;
934
1023
  }
935
1024
  ): Promise<T>;
936
1025
  /**
@@ -941,6 +1030,8 @@ export class HfInference {
941
1030
  options?: Options & {
942
1031
  /** For internal HF use, which is why it's not exposed in {@link Options} */
943
1032
  includeCredentials?: boolean;
1033
+ /** When a model can be used for multiple tasks, and we want to run a non-default task */
1034
+ task?: string | InferenceTask;
944
1035
  }
945
1036
  ): AsyncGenerator<T>;
946
1037
  /**
@@ -978,6 +1069,14 @@ export class HfInference {
978
1069
  * Recommended model: stabilityai/stable-diffusion-2
979
1070
  */
980
1071
  textToImage(args: Omit<TextToImageArgs, 'accessToken'>, options?: Options): Promise<TextToImageOutput>;
1072
+ /**
1073
+ * Classify an image to specified classes.
1074
+ * Recommended model: openai/clip-vit-large-patch14-336
1075
+ */
1076
+ zeroShotImageClassification(
1077
+ args: Omit<ZeroShotImageClassificationArgs, 'accessToken'>,
1078
+ options?: Options
1079
+ ): Promise<ZeroShotImageClassificationOutput>;
981
1080
  /**
982
1081
  * Answers a question on a document image. Recommended model: impira/layoutlm-document-qa.
983
1082
  */
@@ -1069,6 +1168,15 @@ export class HfInference {
1069
1168
  args: Omit<ZeroShotClassificationArgs, 'accessToken'>,
1070
1169
  options?: Options
1071
1170
  ): Promise<ZeroShotClassificationOutput>;
1171
+ /**
1172
+ * Predicts target label for a given set of features in tabular form.
1173
+ * Typically, you will want to train a classification model on your training data and use it with your new data of the same format.
1174
+ * Example model: vvmnnnkv/wine-quality
1175
+ */
1176
+ tabularClassification(
1177
+ args: Omit<TabularClassificationArgs, 'accessToken'>,
1178
+ options?: Options
1179
+ ): Promise<TabularClassificationOutput>;
1072
1180
  /**
1073
1181
  * Predicts target value for a given set of features in tabular form.
1074
1182
  * Typically, you will want to train a regression model on your training data and use it with your new data of the same format.
@@ -1089,6 +1197,11 @@ export class HfInferenceEndpoint {
1089
1197
  args: Omit<AudioClassificationArgs, 'accessToken' | 'model'>,
1090
1198
  options?: Options
1091
1199
  ): Promise<AudioClassificationReturn>;
1200
+ /**
1201
+ * This task reads some audio input and outputs one or multiple audio files.
1202
+ * Example model: speechbrain/sepformer-wham does audio source separation.
1203
+ */
1204
+ audioToAudio(args: Omit<AudioToAudioArgs, 'accessToken' | 'model'>, options?: Options): Promise<AudioToAudioReturn>;
1092
1205
  /**
1093
1206
  * This task reads some audio input and outputs the said words within the audio files.
1094
1207
  * Recommended model (english language): facebook/wav2vec2-large-960h-lv60-self
@@ -1110,6 +1223,8 @@ export class HfInferenceEndpoint {
1110
1223
  options?: Options & {
1111
1224
  /** For internal HF use, which is why it's not exposed in {@link Options} */
1112
1225
  includeCredentials?: boolean;
1226
+ /** When a model can be used for multiple tasks, and we want to run a non-default task */
1227
+ task?: string | InferenceTask;
1113
1228
  }
1114
1229
  ): Promise<T>;
1115
1230
  /**
@@ -1120,6 +1235,8 @@ export class HfInferenceEndpoint {
1120
1235
  options?: Options & {
1121
1236
  /** For internal HF use, which is why it's not exposed in {@link Options} */
1122
1237
  includeCredentials?: boolean;
1238
+ /** When a model can be used for multiple tasks, and we want to run a non-default task */
1239
+ task?: string | InferenceTask;
1123
1240
  }
1124
1241
  ): AsyncGenerator<T>;
1125
1242
  /**
@@ -1157,6 +1274,14 @@ export class HfInferenceEndpoint {
1157
1274
  * Recommended model: stabilityai/stable-diffusion-2
1158
1275
  */
1159
1276
  textToImage(args: Omit<TextToImageArgs, 'accessToken' | 'model'>, options?: Options): Promise<TextToImageOutput>;
1277
+ /**
1278
+ * Classify an image to specified classes.
1279
+ * Recommended model: openai/clip-vit-large-patch14-336
1280
+ */
1281
+ zeroShotImageClassification(
1282
+ args: Omit<ZeroShotImageClassificationArgs, 'accessToken' | 'model'>,
1283
+ options?: Options
1284
+ ): Promise<ZeroShotImageClassificationOutput>;
1160
1285
  /**
1161
1286
  * Answers a question on a document image. Recommended model: impira/layoutlm-document-qa.
1162
1287
  */
@@ -1248,6 +1373,15 @@ export class HfInferenceEndpoint {
1248
1373
  args: Omit<ZeroShotClassificationArgs, 'accessToken' | 'model'>,
1249
1374
  options?: Options
1250
1375
  ): Promise<ZeroShotClassificationOutput>;
1376
+ /**
1377
+ * Predicts target label for a given set of features in tabular form.
1378
+ * Typically, you will want to train a classification model on your training data and use it with your new data of the same format.
1379
+ * Example model: vvmnnnkv/wine-quality
1380
+ */
1381
+ tabularClassification(
1382
+ args: Omit<TabularClassificationArgs, 'accessToken' | 'model'>,
1383
+ options?: Options
1384
+ ): Promise<TabularClassificationOutput>;
1251
1385
  /**
1252
1386
  * Predicts target value for a given set of features in tabular form.
1253
1387
  * Typically, you will want to train a regression model on your training data and use it with your new data of the same format.
package/dist/index.js CHANGED
@@ -25,6 +25,7 @@ __export(src_exports, {
25
25
  HfInferenceEndpoint: () => HfInferenceEndpoint,
26
26
  InferenceOutputError: () => InferenceOutputError,
27
27
  audioClassification: () => audioClassification,
28
+ audioToAudio: () => audioToAudio,
28
29
  automaticSpeechRecognition: () => automaticSpeechRecognition,
29
30
  conversational: () => conversational,
30
31
  documentQuestionAnswering: () => documentQuestionAnswering,
@@ -41,6 +42,7 @@ __export(src_exports, {
41
42
  streamingRequest: () => streamingRequest,
42
43
  summarization: () => summarization,
43
44
  tableQuestionAnswering: () => tableQuestionAnswering,
45
+ tabularClassification: () => tabularClassification,
44
46
  tabularRegression: () => tabularRegression,
45
47
  textClassification: () => textClassification,
46
48
  textGeneration: () => textGeneration,
@@ -50,7 +52,8 @@ __export(src_exports, {
50
52
  tokenClassification: () => tokenClassification,
51
53
  translation: () => translation,
52
54
  visualQuestionAnswering: () => visualQuestionAnswering,
53
- zeroShotClassification: () => zeroShotClassification
55
+ zeroShotClassification: () => zeroShotClassification,
56
+ zeroShotImageClassification: () => zeroShotImageClassification
54
57
  });
55
58
  module.exports = __toCommonJS(src_exports);
56
59
 
@@ -58,6 +61,7 @@ module.exports = __toCommonJS(src_exports);
58
61
  var tasks_exports = {};
59
62
  __export(tasks_exports, {
60
63
  audioClassification: () => audioClassification,
64
+ audioToAudio: () => audioToAudio,
61
65
  automaticSpeechRecognition: () => automaticSpeechRecognition,
62
66
  conversational: () => conversational,
63
67
  documentQuestionAnswering: () => documentQuestionAnswering,
@@ -74,6 +78,7 @@ __export(tasks_exports, {
74
78
  streamingRequest: () => streamingRequest,
75
79
  summarization: () => summarization,
76
80
  tableQuestionAnswering: () => tableQuestionAnswering,
81
+ tabularClassification: () => tabularClassification,
77
82
  tabularRegression: () => tabularRegression,
78
83
  textClassification: () => textClassification,
79
84
  textGeneration: () => textGeneration,
@@ -83,13 +88,20 @@ __export(tasks_exports, {
83
88
  tokenClassification: () => tokenClassification,
84
89
  translation: () => translation,
85
90
  visualQuestionAnswering: () => visualQuestionAnswering,
86
- zeroShotClassification: () => zeroShotClassification
91
+ zeroShotClassification: () => zeroShotClassification,
92
+ zeroShotImageClassification: () => zeroShotImageClassification
87
93
  });
88
94
 
95
+ // src/lib/isUrl.ts
96
+ function isUrl(modelOrUrl) {
97
+ return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
98
+ }
99
+
89
100
  // src/lib/makeRequestOptions.ts
90
- var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co/models/";
101
+ var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
91
102
  function makeRequestOptions(args, options) {
92
103
  const { model, accessToken, ...otherArgs } = args;
104
+ const { task, includeCredentials, ...otherOptions } = options ?? {};
93
105
  const headers = {};
94
106
  if (accessToken) {
95
107
  headers["Authorization"] = `Bearer ${accessToken}`;
@@ -108,15 +120,23 @@ function makeRequestOptions(args, options) {
108
120
  headers["X-Load-Model"] = "0";
109
121
  }
110
122
  }
111
- const url = /^http(s?):/.test(model) || model.startsWith("/") ? model : `${HF_INFERENCE_API_BASE_URL}${model}`;
123
+ const url = (() => {
124
+ if (isUrl(model)) {
125
+ return model;
126
+ }
127
+ if (task) {
128
+ return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
129
+ }
130
+ return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
131
+ })();
112
132
  const info = {
113
133
  headers,
114
134
  method: "POST",
115
135
  body: binary ? args.data : JSON.stringify({
116
136
  ...otherArgs,
117
- options
137
+ options: options && otherOptions
118
138
  }),
119
- credentials: options?.includeCredentials ? "include" : "same-origin"
139
+ credentials: includeCredentials ? "include" : "same-origin"
120
140
  };
121
141
  return { url, info };
122
142
  }
@@ -348,6 +368,18 @@ async function textToSpeech(args, options) {
348
368
  return res;
349
369
  }
350
370
 
371
+ // src/tasks/audio/audioToAudio.ts
372
+ async function audioToAudio(args, options) {
373
+ const res = await request(args, options);
374
+ const isValidOutput = Array.isArray(res) && res.every(
375
+ (x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string"
376
+ );
377
+ if (!isValidOutput) {
378
+ throw new InferenceOutputError("Expected Array<{label: string, blob: string, content-type: string}>");
379
+ }
380
+ return res;
381
+ }
382
+
351
383
  // src/tasks/cv/imageClassification.ts
352
384
  async function imageClassification(args, options) {
353
385
  const res = await request(args, options);
@@ -443,6 +475,26 @@ async function imageToImage(args, options) {
443
475
  return res;
444
476
  }
445
477
 
478
+ // src/tasks/cv/zeroShotImageClassification.ts
479
+ async function zeroShotImageClassification(args, options) {
480
+ const reqArgs = {
481
+ ...args,
482
+ inputs: {
483
+ image: base64FromBytes(
484
+ new Uint8Array(
485
+ args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
486
+ )
487
+ )
488
+ }
489
+ };
490
+ const res = await request(reqArgs, options);
491
+ const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
492
+ if (!isValidOutput) {
493
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
494
+ }
495
+ return res;
496
+ }
497
+
446
498
  // src/tasks/nlp/conversational.ts
447
499
  async function conversational(args, options) {
448
500
  const res = await request(args, options);
@@ -455,27 +507,60 @@ async function conversational(args, options) {
455
507
  return res;
456
508
  }
457
509
 
510
+ // src/lib/getDefaultTask.ts
511
+ var taskCache = /* @__PURE__ */ new Map();
512
+ var CACHE_DURATION = 10 * 60 * 1e3;
513
+ var MAX_CACHE_ITEMS = 1e3;
514
+ var HF_HUB_URL = "https://huggingface.co";
515
+ async function getDefaultTask(model, accessToken) {
516
+ if (isUrl(model)) {
517
+ return null;
518
+ }
519
+ const key = `${model}:${accessToken}`;
520
+ let cachedTask = taskCache.get(key);
521
+ if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
522
+ taskCache.delete(key);
523
+ cachedTask = void 0;
524
+ }
525
+ if (cachedTask === void 0) {
526
+ const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
527
+ headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
528
+ }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
529
+ if (!modelTask) {
530
+ return null;
531
+ }
532
+ cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
533
+ taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
534
+ if (taskCache.size > MAX_CACHE_ITEMS) {
535
+ taskCache.delete(taskCache.keys().next().value);
536
+ }
537
+ }
538
+ return cachedTask.task;
539
+ }
540
+
458
541
  // src/tasks/nlp/featureExtraction.ts
459
542
  async function featureExtraction(args, options) {
460
- const res = await request(args, options);
543
+ const defaultTask = await getDefaultTask(args.model, args.accessToken);
544
+ const res = await request(
545
+ args,
546
+ defaultTask === "sentence-similarity" ? {
547
+ ...options,
548
+ task: "feature-extraction"
549
+ } : options
550
+ );
461
551
  let isValidOutput = true;
462
- if (Array.isArray(res)) {
463
- for (const e of res) {
464
- if (Array.isArray(e)) {
465
- isValidOutput = e.every((x) => typeof x === "number");
466
- if (!isValidOutput) {
467
- break;
468
- }
469
- } else if (typeof e !== "number") {
470
- isValidOutput = false;
471
- break;
472
- }
552
+ const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
553
+ if (curDepth > maxDepth)
554
+ return false;
555
+ if (arr.every((x) => Array.isArray(x))) {
556
+ return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
557
+ } else {
558
+ return arr.every((x) => typeof x === "number");
473
559
  }
474
- } else {
475
- isValidOutput = false;
476
- }
560
+ };
561
+ isValidOutput = Array.isArray(res) && isNumArrayRec(res, 2, 0);
477
562
  if (!isValidOutput) {
478
- throw new InferenceOutputError("Expected Array<number[] | number>");
563
+ throw new InferenceOutputError("Expected Array<number[][] | number[] | number>");
479
564
  }
480
565
  return res;
481
566
  }
@@ -506,7 +591,14 @@ async function questionAnswering(args, options) {
506
591
 
507
592
  // src/tasks/nlp/sentenceSimilarity.ts
508
593
  async function sentenceSimilarity(args, options) {
509
- const res = await request(args, options);
594
+ const defaultTask = await getDefaultTask(args.model, args.accessToken);
595
+ const res = await request(
596
+ args,
597
+ defaultTask === "feature-extraction" ? {
598
+ ...options,
599
+ task: "sentence-similarity"
600
+ } : options
601
+ );
510
602
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
511
603
  if (!isValidOutput) {
512
604
  throw new InferenceOutputError("Expected number[]");
@@ -663,6 +755,16 @@ async function tabularRegression(args, options) {
663
755
  return res;
664
756
  }
665
757
 
758
+ // src/tasks/tabular/tabularClassification.ts
759
+ async function tabularClassification(args, options) {
760
+ const res = await request(args, options);
761
+ const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
762
+ if (!isValidOutput) {
763
+ throw new InferenceOutputError("Expected number[]");
764
+ }
765
+ return res;
766
+ }
767
+
666
768
  // src/HfInference.ts
667
769
  var HfInference = class {
668
770
  accessToken;
@@ -708,6 +810,7 @@ var HfInferenceEndpoint = class {
708
810
  HfInferenceEndpoint,
709
811
  InferenceOutputError,
710
812
  audioClassification,
813
+ audioToAudio,
711
814
  automaticSpeechRecognition,
712
815
  conversational,
713
816
  documentQuestionAnswering,
@@ -724,6 +827,7 @@ var HfInferenceEndpoint = class {
724
827
  streamingRequest,
725
828
  summarization,
726
829
  tableQuestionAnswering,
830
+ tabularClassification,
727
831
  tabularRegression,
728
832
  textClassification,
729
833
  textGeneration,
@@ -733,5 +837,6 @@ var HfInferenceEndpoint = class {
733
837
  tokenClassification,
734
838
  translation,
735
839
  visualQuestionAnswering,
736
- zeroShotClassification
840
+ zeroShotClassification,
841
+ zeroShotImageClassification
737
842
  });
package/dist/index.mjs CHANGED
@@ -9,6 +9,7 @@ var __export = (target, all) => {
9
9
  var tasks_exports = {};
10
10
  __export(tasks_exports, {
11
11
  audioClassification: () => audioClassification,
12
+ audioToAudio: () => audioToAudio,
12
13
  automaticSpeechRecognition: () => automaticSpeechRecognition,
13
14
  conversational: () => conversational,
14
15
  documentQuestionAnswering: () => documentQuestionAnswering,
@@ -25,6 +26,7 @@ __export(tasks_exports, {
25
26
  streamingRequest: () => streamingRequest,
26
27
  summarization: () => summarization,
27
28
  tableQuestionAnswering: () => tableQuestionAnswering,
29
+ tabularClassification: () => tabularClassification,
28
30
  tabularRegression: () => tabularRegression,
29
31
  textClassification: () => textClassification,
30
32
  textGeneration: () => textGeneration,
@@ -34,13 +36,20 @@ __export(tasks_exports, {
34
36
  tokenClassification: () => tokenClassification,
35
37
  translation: () => translation,
36
38
  visualQuestionAnswering: () => visualQuestionAnswering,
37
- zeroShotClassification: () => zeroShotClassification
39
+ zeroShotClassification: () => zeroShotClassification,
40
+ zeroShotImageClassification: () => zeroShotImageClassification
38
41
  });
39
42
 
43
+ // src/lib/isUrl.ts
44
+ function isUrl(modelOrUrl) {
45
+ return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
46
+ }
47
+
40
48
  // src/lib/makeRequestOptions.ts
41
- var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co/models/";
49
+ var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
42
50
  function makeRequestOptions(args, options) {
43
51
  const { model, accessToken, ...otherArgs } = args;
52
+ const { task, includeCredentials, ...otherOptions } = options ?? {};
44
53
  const headers = {};
45
54
  if (accessToken) {
46
55
  headers["Authorization"] = `Bearer ${accessToken}`;
@@ -59,15 +68,23 @@ function makeRequestOptions(args, options) {
59
68
  headers["X-Load-Model"] = "0";
60
69
  }
61
70
  }
62
- const url = /^http(s?):/.test(model) || model.startsWith("/") ? model : `${HF_INFERENCE_API_BASE_URL}${model}`;
71
+ const url = (() => {
72
+ if (isUrl(model)) {
73
+ return model;
74
+ }
75
+ if (task) {
76
+ return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
77
+ }
78
+ return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
79
+ })();
63
80
  const info = {
64
81
  headers,
65
82
  method: "POST",
66
83
  body: binary ? args.data : JSON.stringify({
67
84
  ...otherArgs,
68
- options
85
+ options: options && otherOptions
69
86
  }),
70
- credentials: options?.includeCredentials ? "include" : "same-origin"
87
+ credentials: includeCredentials ? "include" : "same-origin"
71
88
  };
72
89
  return { url, info };
73
90
  }
@@ -299,6 +316,18 @@ async function textToSpeech(args, options) {
299
316
  return res;
300
317
  }
301
318
 
319
+ // src/tasks/audio/audioToAudio.ts
320
+ async function audioToAudio(args, options) {
321
+ const res = await request(args, options);
322
+ const isValidOutput = Array.isArray(res) && res.every(
323
+ (x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string"
324
+ );
325
+ if (!isValidOutput) {
326
+ throw new InferenceOutputError("Expected Array<{label: string, blob: string, content-type: string}>");
327
+ }
328
+ return res;
329
+ }
330
+
302
331
  // src/tasks/cv/imageClassification.ts
303
332
  async function imageClassification(args, options) {
304
333
  const res = await request(args, options);
@@ -394,6 +423,26 @@ async function imageToImage(args, options) {
394
423
  return res;
395
424
  }
396
425
 
426
+ // src/tasks/cv/zeroShotImageClassification.ts
427
+ async function zeroShotImageClassification(args, options) {
428
+ const reqArgs = {
429
+ ...args,
430
+ inputs: {
431
+ image: base64FromBytes(
432
+ new Uint8Array(
433
+ args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
434
+ )
435
+ )
436
+ }
437
+ };
438
+ const res = await request(reqArgs, options);
439
+ const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
440
+ if (!isValidOutput) {
441
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
442
+ }
443
+ return res;
444
+ }
445
+
397
446
  // src/tasks/nlp/conversational.ts
398
447
  async function conversational(args, options) {
399
448
  const res = await request(args, options);
@@ -406,27 +455,60 @@ async function conversational(args, options) {
406
455
  return res;
407
456
  }
408
457
 
458
+ // src/lib/getDefaultTask.ts
459
+ var taskCache = /* @__PURE__ */ new Map();
460
+ var CACHE_DURATION = 10 * 60 * 1e3;
461
+ var MAX_CACHE_ITEMS = 1e3;
462
+ var HF_HUB_URL = "https://huggingface.co";
463
+ async function getDefaultTask(model, accessToken) {
464
+ if (isUrl(model)) {
465
+ return null;
466
+ }
467
+ const key = `${model}:${accessToken}`;
468
+ let cachedTask = taskCache.get(key);
469
+ if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
470
+ taskCache.delete(key);
471
+ cachedTask = void 0;
472
+ }
473
+ if (cachedTask === void 0) {
474
+ const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
475
+ headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
476
+ }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
477
+ if (!modelTask) {
478
+ return null;
479
+ }
480
+ cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
481
+ taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
482
+ if (taskCache.size > MAX_CACHE_ITEMS) {
483
+ taskCache.delete(taskCache.keys().next().value);
484
+ }
485
+ }
486
+ return cachedTask.task;
487
+ }
488
+
409
489
  // src/tasks/nlp/featureExtraction.ts
410
490
  async function featureExtraction(args, options) {
411
- const res = await request(args, options);
491
+ const defaultTask = await getDefaultTask(args.model, args.accessToken);
492
+ const res = await request(
493
+ args,
494
+ defaultTask === "sentence-similarity" ? {
495
+ ...options,
496
+ task: "feature-extraction"
497
+ } : options
498
+ );
412
499
  let isValidOutput = true;
413
- if (Array.isArray(res)) {
414
- for (const e of res) {
415
- if (Array.isArray(e)) {
416
- isValidOutput = e.every((x) => typeof x === "number");
417
- if (!isValidOutput) {
418
- break;
419
- }
420
- } else if (typeof e !== "number") {
421
- isValidOutput = false;
422
- break;
423
- }
500
+ const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
501
+ if (curDepth > maxDepth)
502
+ return false;
503
+ if (arr.every((x) => Array.isArray(x))) {
504
+ return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
505
+ } else {
506
+ return arr.every((x) => typeof x === "number");
424
507
  }
425
- } else {
426
- isValidOutput = false;
427
- }
508
+ };
509
+ isValidOutput = Array.isArray(res) && isNumArrayRec(res, 2, 0);
428
510
  if (!isValidOutput) {
429
- throw new InferenceOutputError("Expected Array<number[] | number>");
511
+ throw new InferenceOutputError("Expected Array<number[][] | number[] | number>");
430
512
  }
431
513
  return res;
432
514
  }
@@ -457,7 +539,14 @@ async function questionAnswering(args, options) {
457
539
 
458
540
  // src/tasks/nlp/sentenceSimilarity.ts
459
541
  async function sentenceSimilarity(args, options) {
460
- const res = await request(args, options);
542
+ const defaultTask = await getDefaultTask(args.model, args.accessToken);
543
+ const res = await request(
544
+ args,
545
+ defaultTask === "feature-extraction" ? {
546
+ ...options,
547
+ task: "sentence-similarity"
548
+ } : options
549
+ );
461
550
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
462
551
  if (!isValidOutput) {
463
552
  throw new InferenceOutputError("Expected number[]");
@@ -614,6 +703,16 @@ async function tabularRegression(args, options) {
614
703
  return res;
615
704
  }
616
705
 
706
+ // src/tasks/tabular/tabularClassification.ts
707
+ async function tabularClassification(args, options) {
708
+ const res = await request(args, options);
709
+ const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
710
+ if (!isValidOutput) {
711
+ throw new InferenceOutputError("Expected number[]");
712
+ }
713
+ return res;
714
+ }
715
+
617
716
  // src/HfInference.ts
618
717
  var HfInference = class {
619
718
  accessToken;
@@ -658,6 +757,7 @@ export {
658
757
  HfInferenceEndpoint,
659
758
  InferenceOutputError,
660
759
  audioClassification,
760
+ audioToAudio,
661
761
  automaticSpeechRecognition,
662
762
  conversational,
663
763
  documentQuestionAnswering,
@@ -674,6 +774,7 @@ export {
674
774
  streamingRequest,
675
775
  summarization,
676
776
  tableQuestionAnswering,
777
+ tabularClassification,
677
778
  tabularRegression,
678
779
  textClassification,
679
780
  textGeneration,
@@ -683,5 +784,6 @@ export {
683
784
  tokenClassification,
684
785
  translation,
685
786
  visualQuestionAnswering,
686
- zeroShotClassification
787
+ zeroShotClassification,
788
+ zeroShotImageClassification
687
789
  };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@huggingface/inference",
3
- "version": "2.3.3",
3
+ "version": "2.5.0",
4
4
  "packageManager": "pnpm@8.3.1",
5
5
  "license": "MIT",
6
6
  "author": "Tim Mikeladze <tim.mikeladze@gmail.com>",
@@ -0,0 +1,53 @@
1
+ import { isUrl } from "./isUrl";
2
+
3
+ /**
4
+ * We want to make calls to the huggingface hub the least possible, eg if
5
+ * someone is calling the inference API 1000 times per second, we don't want
6
+ * to make 1000 calls to the hub to get the task name.
7
+ */
8
+ const taskCache = new Map<string, { task: string; date: Date }>();
9
+ const CACHE_DURATION = 10 * 60 * 1000;
10
+ const MAX_CACHE_ITEMS = 1000;
11
+ const HF_HUB_URL = "https://huggingface.co";
12
+
13
+ /**
14
+ * Get the default task. Use a LRU cache of 1000 items with 10 minutes expiration
15
+ * to avoid making too many calls to the HF hub.
16
+ *
17
+ * @returns The default task for the model, or `null` if it was impossible to get it
18
+ */
19
+ export async function getDefaultTask(model: string, accessToken: string | undefined): Promise<string | null> {
20
+ if (isUrl(model)) {
21
+ return null;
22
+ }
23
+
24
+ const key = `${model}:${accessToken}`;
25
+ let cachedTask = taskCache.get(key);
26
+
27
+ if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
28
+ taskCache.delete(key);
29
+ cachedTask = undefined;
30
+ }
31
+
32
+ if (cachedTask === undefined) {
33
+ const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
34
+ headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {},
35
+ })
36
+ .then((resp) => resp.json())
37
+ .then((json) => json.pipeline_tag)
38
+ .catch(() => null);
39
+
40
+ if (!modelTask) {
41
+ return null;
42
+ }
43
+
44
+ cachedTask = { task: modelTask, date: new Date() };
45
+ taskCache.set(key, { task: modelTask, date: new Date() });
46
+
47
+ if (taskCache.size > MAX_CACHE_ITEMS) {
48
+ taskCache.delete(taskCache.keys().next().value);
49
+ }
50
+ }
51
+
52
+ return cachedTask.task;
53
+ }
@@ -0,0 +1,3 @@
1
+ export function isUrl(modelOrUrl: string): boolean {
2
+ return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
3
+ }
@@ -1,6 +1,7 @@
1
- import type { Options, RequestArgs } from "../types";
1
+ import type { InferenceTask, Options, RequestArgs } from "../types";
2
+ import { isUrl } from "./isUrl";
2
3
 
3
- const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co/models/";
4
+ const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
4
5
 
5
6
  /**
6
7
  * Helper that prepares request arguments
@@ -13,9 +14,12 @@ export function makeRequestOptions(
13
14
  options?: Options & {
14
15
  /** For internal HF use, which is why it's not exposed in {@link Options} */
15
16
  includeCredentials?: boolean;
17
+ /** When a model can be used for multiple tasks, and we want to run a non-default task */
18
+ task?: string | InferenceTask;
16
19
  }
17
20
  ): { url: string; info: RequestInit } {
18
21
  const { model, accessToken, ...otherArgs } = args;
22
+ const { task, includeCredentials, ...otherOptions } = options ?? {};
19
23
 
20
24
  const headers: Record<string, string> = {};
21
25
  if (accessToken) {
@@ -38,7 +42,18 @@ export function makeRequestOptions(
38
42
  }
39
43
  }
40
44
 
41
- const url = /^http(s?):/.test(model) || model.startsWith("/") ? model : `${HF_INFERENCE_API_BASE_URL}${model}`;
45
+ const url = (() => {
46
+ if (isUrl(model)) {
47
+ return model;
48
+ }
49
+
50
+ if (task) {
51
+ return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
52
+ }
53
+
54
+ return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
55
+ })();
56
+
42
57
  const info: RequestInit = {
43
58
  headers,
44
59
  method: "POST",
@@ -46,9 +61,9 @@ export function makeRequestOptions(
46
61
  ? args.data
47
62
  : JSON.stringify({
48
63
  ...otherArgs,
49
- options,
64
+ options: options && otherOptions,
50
65
  }),
51
- credentials: options?.includeCredentials ? "include" : "same-origin",
66
+ credentials: includeCredentials ? "include" : "same-origin",
52
67
  };
53
68
 
54
69
  return { url, info };
@@ -0,0 +1,46 @@
1
+ import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import type { BaseArgs, Options } from "../../types";
3
+ import { request } from "../custom/request";
4
+
5
+ export type AudioToAudioArgs = BaseArgs & {
6
+ /**
7
+ * Binary audio data
8
+ */
9
+ data: Blob | ArrayBuffer;
10
+ };
11
+
12
+ export interface AudioToAudioOutputValue {
13
+ /**
14
+ * The label for the audio output (model specific)
15
+ */
16
+ label: string;
17
+
18
+ /**
19
+ * Base64 encoded audio output.
20
+ */
21
+ blob: string;
22
+
23
+ /**
24
+ * Content-type for blob, e.g. audio/flac
25
+ */
26
+ "content-type": string;
27
+ }
28
+
29
+ export type AudioToAudioReturn = AudioToAudioOutputValue[];
30
+
31
+ /**
32
+ * This task reads some audio input and outputs one or multiple audio files.
33
+ * Example model: speechbrain/sepformer-wham does audio source separation.
34
+ */
35
+ export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioReturn> {
36
+ const res = await request<AudioToAudioReturn>(args, options);
37
+ const isValidOutput =
38
+ Array.isArray(res) &&
39
+ res.every(
40
+ (x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string"
41
+ );
42
+ if (!isValidOutput) {
43
+ throw new InferenceOutputError("Expected Array<{label: string, blob: string, content-type: string}>");
44
+ }
45
+ return res;
46
+ }
@@ -1,4 +1,4 @@
1
- import type { Options, RequestArgs } from "../../types";
1
+ import type { InferenceTask, Options, RequestArgs } from "../../types";
2
2
  import { makeRequestOptions } from "../../lib/makeRequestOptions";
3
3
 
4
4
  /**
@@ -9,6 +9,8 @@ export async function request<T>(
9
9
  options?: Options & {
10
10
  /** For internal HF use, which is why it's not exposed in {@link Options} */
11
11
  includeCredentials?: boolean;
12
+ /** When a model can be used for multiple tasks, and we want to run a non-default task */
13
+ task?: string | InferenceTask;
12
14
  }
13
15
  ): Promise<T> {
14
16
  const { url, info } = makeRequestOptions(args, options);
@@ -1,4 +1,4 @@
1
- import type { Options, RequestArgs } from "../../types";
1
+ import type { InferenceTask, Options, RequestArgs } from "../../types";
2
2
  import { makeRequestOptions } from "../../lib/makeRequestOptions";
3
3
  import type { EventSourceMessage } from "../../vendor/fetch-event-source/parse";
4
4
  import { getLines, getMessages } from "../../vendor/fetch-event-source/parse";
@@ -11,6 +11,8 @@ export async function* streamingRequest<T>(
11
11
  options?: Options & {
12
12
  /** For internal HF use, which is why it's not exposed in {@link Options} */
13
13
  includeCredentials?: boolean;
14
+ /** When a model can be used for multiple tasks, and we want to run a non-default task */
15
+ task?: string | InferenceTask;
14
16
  }
15
17
  ): AsyncGenerator<T> {
16
18
  const { url, info } = makeRequestOptions({ ...args, stream: true }, options);
@@ -0,0 +1,55 @@
1
+ import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import type { BaseArgs, Options } from "../../types";
3
+ import { request } from "../custom/request";
4
+ import type { RequestArgs } from "../../types";
5
+ import { base64FromBytes } from "../../../../shared";
6
+
7
+ export type ZeroShotImageClassificationArgs = BaseArgs & {
8
+ inputs: {
9
+ /**
10
+ * Binary image data
11
+ */
12
+ image: Blob | ArrayBuffer;
13
+ };
14
+ parameters: {
15
+ /**
16
+ * A list of strings that are potential classes for inputs. (max 10)
17
+ */
18
+ candidate_labels: string[];
19
+ };
20
+ };
21
+
22
+ export interface ZeroShotImageClassificationOutputValue {
23
+ label: string;
24
+ score: number;
25
+ }
26
+
27
+ export type ZeroShotImageClassificationOutput = ZeroShotImageClassificationOutputValue[];
28
+
29
+ /**
30
+ * Classify an image to specified classes.
31
+ * Recommended model: openai/clip-vit-large-patch14-336
32
+ */
33
+ export async function zeroShotImageClassification(
34
+ args: ZeroShotImageClassificationArgs,
35
+ options?: Options
36
+ ): Promise<ZeroShotImageClassificationOutput> {
37
+ const reqArgs: RequestArgs = {
38
+ ...args,
39
+ inputs: {
40
+ image: base64FromBytes(
41
+ new Uint8Array(
42
+ args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
43
+ )
44
+ ),
45
+ },
46
+ } as RequestArgs;
47
+
48
+ const res = await request<ZeroShotImageClassificationOutput>(reqArgs, options);
49
+ const isValidOutput =
50
+ Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
51
+ if (!isValidOutput) {
52
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
53
+ }
54
+ return res;
55
+ }
@@ -6,6 +6,7 @@ export * from "./custom/streamingRequest";
6
6
  export * from "./audio/audioClassification";
7
7
  export * from "./audio/automaticSpeechRecognition";
8
8
  export * from "./audio/textToSpeech";
9
+ export * from "./audio/audioToAudio";
9
10
 
10
11
  // Computer Vision tasks
11
12
  export * from "./cv/imageClassification";
@@ -14,6 +15,7 @@ export * from "./cv/imageToText";
14
15
  export * from "./cv/objectDetection";
15
16
  export * from "./cv/textToImage";
16
17
  export * from "./cv/imageToImage";
18
+ export * from "./cv/zeroShotImageClassification";
17
19
 
18
20
  // Natural Language Processing tasks
19
21
  export * from "./nlp/conversational";
@@ -36,3 +38,4 @@ export * from "./multimodal/visualQuestionAnswering";
36
38
 
37
39
  // Tabular tasks
38
40
  export * from "./tabular/tabularRegression";
41
+ export * from "./tabular/tabularClassification";
@@ -1,4 +1,5 @@
1
1
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getDefaultTask } from "../../lib/getDefaultTask";
2
3
  import type { BaseArgs, Options } from "../../types";
3
4
  import { request } from "../custom/request";
4
5
 
@@ -13,9 +14,9 @@ export type FeatureExtractionArgs = BaseArgs & {
13
14
  };
14
15
 
15
16
  /**
16
- * Returned values are a list of floats, or a list of list of floats (depending on if you sent a string or a list of string, and if the automatic reduction, usually mean_pooling for instance was applied for you or not. This should be explained on the model's README.
17
+ * Returned values are a list of floats, or a list of list of floats, or a list of list of list of floats (depending on if you sent a string or a list of string, and if the automatic reduction, usually mean_pooling for instance was applied for you or not. This should be explained on the model's README.
17
18
  */
18
- export type FeatureExtractionOutput = (number | number[])[];
19
+ export type FeatureExtractionOutput = (number | number[] | number[][])[];
19
20
 
20
21
  /**
21
22
  * This task reads some text and outputs raw float values, that are usually consumed as part of a semantic database/semantic search.
@@ -24,28 +25,31 @@ export async function featureExtraction(
24
25
  args: FeatureExtractionArgs,
25
26
  options?: Options
26
27
  ): Promise<FeatureExtractionOutput> {
27
- const res = await request<FeatureExtractionOutput>(args, options);
28
+ const defaultTask = await getDefaultTask(args.model, args.accessToken);
29
+ const res = await request<FeatureExtractionOutput>(
30
+ args,
31
+ defaultTask === "sentence-similarity"
32
+ ? {
33
+ ...options,
34
+ task: "feature-extraction",
35
+ }
36
+ : options
37
+ );
28
38
  let isValidOutput = true;
29
- // Check if output is an array
30
- if (Array.isArray(res)) {
31
- for (const e of res) {
32
- // Check if output is an array of arrays or numbers
33
- if (Array.isArray(e)) {
34
- // if all elements are numbers, continue
35
- isValidOutput = e.every((x) => typeof x === "number");
36
- if (!isValidOutput) {
37
- break;
38
- }
39
- } else if (typeof e !== "number") {
40
- isValidOutput = false;
41
- break;
42
- }
39
+
40
+ const isNumArrayRec = (arr: unknown[], maxDepth: number, curDepth = 0): boolean => {
41
+ if (curDepth > maxDepth) return false;
42
+ if (arr.every((x) => Array.isArray(x))) {
43
+ return arr.every((x) => isNumArrayRec(x as unknown[], maxDepth, curDepth + 1));
44
+ } else {
45
+ return arr.every((x) => typeof x === "number");
43
46
  }
44
- } else {
45
- isValidOutput = false;
46
- }
47
+ };
48
+
49
+ isValidOutput = Array.isArray(res) && isNumArrayRec(res, 2, 0);
50
+
47
51
  if (!isValidOutput) {
48
- throw new InferenceOutputError("Expected Array<number[] | number>");
52
+ throw new InferenceOutputError("Expected Array<number[][] | number[] | number>");
49
53
  }
50
54
  return res;
51
55
  }
@@ -1,4 +1,5 @@
1
1
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getDefaultTask } from "../../lib/getDefaultTask";
2
3
  import type { BaseArgs, Options } from "../../types";
3
4
  import { request } from "../custom/request";
4
5
 
@@ -24,7 +25,16 @@ export async function sentenceSimilarity(
24
25
  args: SentenceSimilarityArgs,
25
26
  options?: Options
26
27
  ): Promise<SentenceSimilarityOutput> {
27
- const res = await request<SentenceSimilarityOutput>(args, options);
28
+ const defaultTask = await getDefaultTask(args.model, args.accessToken);
29
+ const res = await request<SentenceSimilarityOutput>(
30
+ args,
31
+ defaultTask === "feature-extraction"
32
+ ? {
33
+ ...options,
34
+ task: "sentence-similarity",
35
+ }
36
+ : options
37
+ );
28
38
 
29
39
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
30
40
  if (!isValidOutput) {
@@ -0,0 +1,34 @@
1
+ import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import type { BaseArgs, Options } from "../../types";
3
+ import { request } from "../custom/request";
4
+
5
+ export type TabularClassificationArgs = BaseArgs & {
6
+ inputs: {
7
+ /**
8
+ * A table of data represented as a dict of list where entries are headers and the lists are all the values, all lists must have the same size.
9
+ */
10
+ data: Record<string, string[]>;
11
+ };
12
+ };
13
+
14
+ /**
15
+ * A list of predicted labels for each row
16
+ */
17
+ export type TabularClassificationOutput = number[];
18
+
19
+ /**
20
+ * Predicts target label for a given set of features in tabular form.
21
+ * Typically, you will want to train a classification model on your training data and use it with your new data of the same format.
22
+ * Example model: vvmnnnkv/wine-quality
23
+ */
24
+ export async function tabularClassification(
25
+ args: TabularClassificationArgs,
26
+ options?: Options
27
+ ): Promise<TabularClassificationOutput> {
28
+ const res = await request<TabularClassificationOutput>(args, options);
29
+ const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
30
+ if (!isValidOutput) {
31
+ throw new InferenceOutputError("Expected number[]");
32
+ }
33
+ return res;
34
+ }
package/src/types.ts CHANGED
@@ -26,6 +26,8 @@ export interface Options {
26
26
  fetch?: typeof fetch;
27
27
  }
28
28
 
29
+ export type InferenceTask = "text-classification" | "feature-extraction" | "sentence-similarity";
30
+
29
31
  export interface BaseArgs {
30
32
  /**
31
33
  * The access token to use. Without it, you'll get rate-limited quickly.