@huggingface/inference 2.5.2 → 2.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.d.ts +48 -2
- package/dist/index.js +168 -78
- package/dist/index.mjs +168 -78
- package/package.json +1 -1
- package/src/lib/getDefaultTask.ts +1 -1
- package/src/lib/makeRequestOptions.ts +34 -5
- package/src/tasks/audio/audioClassification.ts +4 -1
- package/src/tasks/audio/audioToAudio.ts +4 -1
- package/src/tasks/audio/automaticSpeechRecognition.ts +4 -1
- package/src/tasks/audio/textToSpeech.ts +4 -1
- package/src/tasks/custom/request.ts +3 -1
- package/src/tasks/custom/streamingRequest.ts +3 -1
- package/src/tasks/cv/imageClassification.ts +4 -1
- package/src/tasks/cv/imageSegmentation.ts +4 -1
- package/src/tasks/cv/imageToImage.ts +4 -1
- package/src/tasks/cv/imageToText.ts +6 -1
- package/src/tasks/cv/objectDetection.ts +4 -1
- package/src/tasks/cv/textToImage.ts +4 -1
- package/src/tasks/cv/zeroShotImageClassification.ts +4 -1
- package/src/tasks/multimodal/documentQuestionAnswering.ts +4 -1
- package/src/tasks/multimodal/visualQuestionAnswering.ts +6 -1
- package/src/tasks/nlp/conversational.ts +1 -1
- package/src/tasks/nlp/featureExtraction.ts +7 -10
- package/src/tasks/nlp/fillMask.ts +4 -1
- package/src/tasks/nlp/questionAnswering.ts +4 -1
- package/src/tasks/nlp/sentenceSimilarity.ts +6 -10
- package/src/tasks/nlp/summarization.ts +4 -1
- package/src/tasks/nlp/tableQuestionAnswering.ts +4 -1
- package/src/tasks/nlp/textClassification.ts +6 -1
- package/src/tasks/nlp/textGeneration.ts +4 -1
- package/src/tasks/nlp/textGenerationStream.ts +4 -1
- package/src/tasks/nlp/tokenClassification.ts +6 -1
- package/src/tasks/nlp/translation.ts +4 -1
- package/src/tasks/nlp/zeroShotClassification.ts +4 -1
- package/src/tasks/tabular/tabularClassification.ts +4 -1
- package/src/tasks/tabular/tabularRegression.ts +4 -1
- package/src/types.ts +36 -2
package/dist/index.d.ts
CHANGED
|
@@ -26,7 +26,39 @@ export interface Options {
|
|
|
26
26
|
fetch?: typeof fetch;
|
|
27
27
|
}
|
|
28
28
|
|
|
29
|
-
export type InferenceTask =
|
|
29
|
+
export type InferenceTask =
|
|
30
|
+
| "audio-classification"
|
|
31
|
+
| "audio-to-audio"
|
|
32
|
+
| "automatic-speech-recognition"
|
|
33
|
+
| "conversational"
|
|
34
|
+
| "depth-estimation"
|
|
35
|
+
| "document-question-answering"
|
|
36
|
+
| "feature-extraction"
|
|
37
|
+
| "fill-mask"
|
|
38
|
+
| "image-classification"
|
|
39
|
+
| "image-segmentation"
|
|
40
|
+
| "image-to-image"
|
|
41
|
+
| "image-to-text"
|
|
42
|
+
| "object-detection"
|
|
43
|
+
| "video-classification"
|
|
44
|
+
| "question-answering"
|
|
45
|
+
| "reinforcement-learning"
|
|
46
|
+
| "sentence-similarity"
|
|
47
|
+
| "summarization"
|
|
48
|
+
| "table-question-answering"
|
|
49
|
+
| "tabular-classification"
|
|
50
|
+
| "tabular-regression"
|
|
51
|
+
| "text-classification"
|
|
52
|
+
| "text-generation"
|
|
53
|
+
| "text-to-image"
|
|
54
|
+
| "text-to-speech"
|
|
55
|
+
| "text-to-video"
|
|
56
|
+
| "token-classification"
|
|
57
|
+
| "translation"
|
|
58
|
+
| "unconditional-image-generation"
|
|
59
|
+
| "visual-question-answering"
|
|
60
|
+
| "zero-shot-classification"
|
|
61
|
+
| "zero-shot-image-classification";
|
|
30
62
|
|
|
31
63
|
export interface BaseArgs {
|
|
32
64
|
/**
|
|
@@ -37,8 +69,10 @@ export interface BaseArgs {
|
|
|
37
69
|
accessToken?: string;
|
|
38
70
|
/**
|
|
39
71
|
* The model to use. Can be a full URL for HF inference endpoints.
|
|
72
|
+
*
|
|
73
|
+
* If not specified, will call huggingface.co/api/tasks to get the default model for the task.
|
|
40
74
|
*/
|
|
41
|
-
model
|
|
75
|
+
model?: string;
|
|
42
76
|
}
|
|
43
77
|
|
|
44
78
|
export type RequestArgs = BaseArgs &
|
|
@@ -144,6 +178,8 @@ export function request<T>(
|
|
|
144
178
|
includeCredentials?: boolean;
|
|
145
179
|
/** When a model can be used for multiple tasks, and we want to run a non-default task */
|
|
146
180
|
task?: string | InferenceTask;
|
|
181
|
+
/** To load default model if needed */
|
|
182
|
+
taskHint?: InferenceTask;
|
|
147
183
|
}
|
|
148
184
|
): Promise<T>;
|
|
149
185
|
/**
|
|
@@ -156,6 +192,8 @@ export function streamingRequest<T>(
|
|
|
156
192
|
includeCredentials?: boolean;
|
|
157
193
|
/** When a model can be used for multiple tasks, and we want to run a non-default task */
|
|
158
194
|
task?: string | InferenceTask;
|
|
195
|
+
/** To load default model if needed */
|
|
196
|
+
taskHint?: InferenceTask;
|
|
159
197
|
}
|
|
160
198
|
): AsyncGenerator<T>;
|
|
161
199
|
export type ImageClassificationArgs = BaseArgs & {
|
|
@@ -1020,6 +1058,8 @@ export class HfInference {
|
|
|
1020
1058
|
includeCredentials?: boolean;
|
|
1021
1059
|
/** When a model can be used for multiple tasks, and we want to run a non-default task */
|
|
1022
1060
|
task?: string | InferenceTask;
|
|
1061
|
+
/** To load default model if needed */
|
|
1062
|
+
taskHint?: InferenceTask;
|
|
1023
1063
|
}
|
|
1024
1064
|
): Promise<T>;
|
|
1025
1065
|
/**
|
|
@@ -1032,6 +1072,8 @@ export class HfInference {
|
|
|
1032
1072
|
includeCredentials?: boolean;
|
|
1033
1073
|
/** When a model can be used for multiple tasks, and we want to run a non-default task */
|
|
1034
1074
|
task?: string | InferenceTask;
|
|
1075
|
+
/** To load default model if needed */
|
|
1076
|
+
taskHint?: InferenceTask;
|
|
1035
1077
|
}
|
|
1036
1078
|
): AsyncGenerator<T>;
|
|
1037
1079
|
/**
|
|
@@ -1225,6 +1267,8 @@ export class HfInferenceEndpoint {
|
|
|
1225
1267
|
includeCredentials?: boolean;
|
|
1226
1268
|
/** When a model can be used for multiple tasks, and we want to run a non-default task */
|
|
1227
1269
|
task?: string | InferenceTask;
|
|
1270
|
+
/** To load default model if needed */
|
|
1271
|
+
taskHint?: InferenceTask;
|
|
1228
1272
|
}
|
|
1229
1273
|
): Promise<T>;
|
|
1230
1274
|
/**
|
|
@@ -1237,6 +1281,8 @@ export class HfInferenceEndpoint {
|
|
|
1237
1281
|
includeCredentials?: boolean;
|
|
1238
1282
|
/** When a model can be used for multiple tasks, and we want to run a non-default task */
|
|
1239
1283
|
task?: string | InferenceTask;
|
|
1284
|
+
/** To load default model if needed */
|
|
1285
|
+
taskHint?: InferenceTask;
|
|
1240
1286
|
}
|
|
1241
1287
|
): AsyncGenerator<T>;
|
|
1242
1288
|
/**
|
package/dist/index.js
CHANGED
|
@@ -97,15 +97,63 @@ function isUrl(modelOrUrl) {
|
|
|
97
97
|
return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
|
|
98
98
|
}
|
|
99
99
|
|
|
100
|
+
// src/lib/getDefaultTask.ts
|
|
101
|
+
var taskCache = /* @__PURE__ */ new Map();
|
|
102
|
+
var CACHE_DURATION = 10 * 60 * 1e3;
|
|
103
|
+
var MAX_CACHE_ITEMS = 1e3;
|
|
104
|
+
var HF_HUB_URL = "https://huggingface.co";
|
|
105
|
+
async function getDefaultTask(model, accessToken) {
|
|
106
|
+
if (isUrl(model)) {
|
|
107
|
+
return null;
|
|
108
|
+
}
|
|
109
|
+
const key = `${model}:${accessToken}`;
|
|
110
|
+
let cachedTask = taskCache.get(key);
|
|
111
|
+
if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
|
|
112
|
+
taskCache.delete(key);
|
|
113
|
+
cachedTask = void 0;
|
|
114
|
+
}
|
|
115
|
+
if (cachedTask === void 0) {
|
|
116
|
+
const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
|
|
117
|
+
headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
|
|
118
|
+
}).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
|
|
119
|
+
if (!modelTask) {
|
|
120
|
+
return null;
|
|
121
|
+
}
|
|
122
|
+
cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
|
|
123
|
+
taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
|
|
124
|
+
if (taskCache.size > MAX_CACHE_ITEMS) {
|
|
125
|
+
taskCache.delete(taskCache.keys().next().value);
|
|
126
|
+
}
|
|
127
|
+
}
|
|
128
|
+
return cachedTask.task;
|
|
129
|
+
}
|
|
130
|
+
|
|
100
131
|
// src/lib/makeRequestOptions.ts
|
|
101
132
|
var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
const {
|
|
133
|
+
var tasks = null;
|
|
134
|
+
async function makeRequestOptions(args, options) {
|
|
135
|
+
const { accessToken, model: _model, ...otherArgs } = args;
|
|
136
|
+
let { model } = args;
|
|
137
|
+
const { forceTask: task, includeCredentials, taskHint, ...otherOptions } = options ?? {};
|
|
105
138
|
const headers = {};
|
|
106
139
|
if (accessToken) {
|
|
107
140
|
headers["Authorization"] = `Bearer ${accessToken}`;
|
|
108
141
|
}
|
|
142
|
+
if (!model && !tasks && taskHint) {
|
|
143
|
+
const res = await fetch(`${HF_HUB_URL}/api/tasks`);
|
|
144
|
+
if (res.ok) {
|
|
145
|
+
tasks = await res.json();
|
|
146
|
+
}
|
|
147
|
+
}
|
|
148
|
+
if (!model && tasks && taskHint) {
|
|
149
|
+
const taskInfo = tasks[taskHint];
|
|
150
|
+
if (taskInfo) {
|
|
151
|
+
model = taskInfo.models[0].id;
|
|
152
|
+
}
|
|
153
|
+
}
|
|
154
|
+
if (!model) {
|
|
155
|
+
throw new Error("No model provided, and no default model found for this task");
|
|
156
|
+
}
|
|
109
157
|
const binary = "data" in args && !!args.data;
|
|
110
158
|
if (!binary) {
|
|
111
159
|
headers["Content-Type"] = "application/json";
|
|
@@ -143,7 +191,7 @@ function makeRequestOptions(args, options) {
|
|
|
143
191
|
|
|
144
192
|
// src/tasks/custom/request.ts
|
|
145
193
|
async function request(args, options) {
|
|
146
|
-
const { url, info } = makeRequestOptions(args, options);
|
|
194
|
+
const { url, info } = await makeRequestOptions(args, options);
|
|
147
195
|
const response = await (options?.fetch ?? fetch)(url, info);
|
|
148
196
|
if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
|
|
149
197
|
return request(args, {
|
|
@@ -267,7 +315,7 @@ function newMessage() {
|
|
|
267
315
|
|
|
268
316
|
// src/tasks/custom/streamingRequest.ts
|
|
269
317
|
async function* streamingRequest(args, options) {
|
|
270
|
-
const { url, info } = makeRequestOptions({ ...args, stream: true }, options);
|
|
318
|
+
const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
|
|
271
319
|
const response = await (options?.fetch ?? fetch)(url, info);
|
|
272
320
|
if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
|
|
273
321
|
return streamingRequest(args, {
|
|
@@ -340,7 +388,10 @@ var InferenceOutputError = class extends TypeError {
|
|
|
340
388
|
|
|
341
389
|
// src/tasks/audio/audioClassification.ts
|
|
342
390
|
async function audioClassification(args, options) {
|
|
343
|
-
const res = await request(args,
|
|
391
|
+
const res = await request(args, {
|
|
392
|
+
...options,
|
|
393
|
+
taskHint: "audio-classification"
|
|
394
|
+
});
|
|
344
395
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
|
|
345
396
|
if (!isValidOutput) {
|
|
346
397
|
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
@@ -350,7 +401,10 @@ async function audioClassification(args, options) {
|
|
|
350
401
|
|
|
351
402
|
// src/tasks/audio/automaticSpeechRecognition.ts
|
|
352
403
|
async function automaticSpeechRecognition(args, options) {
|
|
353
|
-
const res = await request(args,
|
|
404
|
+
const res = await request(args, {
|
|
405
|
+
...options,
|
|
406
|
+
taskHint: "automatic-speech-recognition"
|
|
407
|
+
});
|
|
354
408
|
const isValidOutput = typeof res?.text === "string";
|
|
355
409
|
if (!isValidOutput) {
|
|
356
410
|
throw new InferenceOutputError("Expected {text: string}");
|
|
@@ -360,7 +414,10 @@ async function automaticSpeechRecognition(args, options) {
|
|
|
360
414
|
|
|
361
415
|
// src/tasks/audio/textToSpeech.ts
|
|
362
416
|
async function textToSpeech(args, options) {
|
|
363
|
-
const res = await request(args,
|
|
417
|
+
const res = await request(args, {
|
|
418
|
+
...options,
|
|
419
|
+
taskHint: "text-to-speech"
|
|
420
|
+
});
|
|
364
421
|
const isValidOutput = res && res instanceof Blob;
|
|
365
422
|
if (!isValidOutput) {
|
|
366
423
|
throw new InferenceOutputError("Expected Blob");
|
|
@@ -370,7 +427,10 @@ async function textToSpeech(args, options) {
|
|
|
370
427
|
|
|
371
428
|
// src/tasks/audio/audioToAudio.ts
|
|
372
429
|
async function audioToAudio(args, options) {
|
|
373
|
-
const res = await request(args,
|
|
430
|
+
const res = await request(args, {
|
|
431
|
+
...options,
|
|
432
|
+
taskHint: "audio-to-audio"
|
|
433
|
+
});
|
|
374
434
|
const isValidOutput = Array.isArray(res) && res.every(
|
|
375
435
|
(x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string"
|
|
376
436
|
);
|
|
@@ -382,7 +442,10 @@ async function audioToAudio(args, options) {
|
|
|
382
442
|
|
|
383
443
|
// src/tasks/cv/imageClassification.ts
|
|
384
444
|
async function imageClassification(args, options) {
|
|
385
|
-
const res = await request(args,
|
|
445
|
+
const res = await request(args, {
|
|
446
|
+
...options,
|
|
447
|
+
taskHint: "image-classification"
|
|
448
|
+
});
|
|
386
449
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
|
|
387
450
|
if (!isValidOutput) {
|
|
388
451
|
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
@@ -392,7 +455,10 @@ async function imageClassification(args, options) {
|
|
|
392
455
|
|
|
393
456
|
// src/tasks/cv/imageSegmentation.ts
|
|
394
457
|
async function imageSegmentation(args, options) {
|
|
395
|
-
const res = await request(args,
|
|
458
|
+
const res = await request(args, {
|
|
459
|
+
...options,
|
|
460
|
+
taskHint: "image-segmentation"
|
|
461
|
+
});
|
|
396
462
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
|
|
397
463
|
if (!isValidOutput) {
|
|
398
464
|
throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
|
|
@@ -402,7 +468,10 @@ async function imageSegmentation(args, options) {
|
|
|
402
468
|
|
|
403
469
|
// src/tasks/cv/imageToText.ts
|
|
404
470
|
async function imageToText(args, options) {
|
|
405
|
-
const res = (await request(args,
|
|
471
|
+
const res = (await request(args, {
|
|
472
|
+
...options,
|
|
473
|
+
taskHint: "image-to-text"
|
|
474
|
+
}))?.[0];
|
|
406
475
|
if (typeof res?.generated_text !== "string") {
|
|
407
476
|
throw new InferenceOutputError("Expected {generated_text: string}");
|
|
408
477
|
}
|
|
@@ -411,7 +480,10 @@ async function imageToText(args, options) {
|
|
|
411
480
|
|
|
412
481
|
// src/tasks/cv/objectDetection.ts
|
|
413
482
|
async function objectDetection(args, options) {
|
|
414
|
-
const res = await request(args,
|
|
483
|
+
const res = await request(args, {
|
|
484
|
+
...options,
|
|
485
|
+
taskHint: "object-detection"
|
|
486
|
+
});
|
|
415
487
|
const isValidOutput = Array.isArray(res) && res.every(
|
|
416
488
|
(x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
|
|
417
489
|
);
|
|
@@ -425,7 +497,10 @@ async function objectDetection(args, options) {
|
|
|
425
497
|
|
|
426
498
|
// src/tasks/cv/textToImage.ts
|
|
427
499
|
async function textToImage(args, options) {
|
|
428
|
-
const res = await request(args,
|
|
500
|
+
const res = await request(args, {
|
|
501
|
+
...options,
|
|
502
|
+
taskHint: "text-to-image"
|
|
503
|
+
});
|
|
429
504
|
const isValidOutput = res && res instanceof Blob;
|
|
430
505
|
if (!isValidOutput) {
|
|
431
506
|
throw new InferenceOutputError("Expected Blob");
|
|
@@ -467,7 +542,10 @@ async function imageToImage(args, options) {
|
|
|
467
542
|
)
|
|
468
543
|
};
|
|
469
544
|
}
|
|
470
|
-
const res = await request(reqArgs,
|
|
545
|
+
const res = await request(reqArgs, {
|
|
546
|
+
...options,
|
|
547
|
+
taskHint: "image-to-image"
|
|
548
|
+
});
|
|
471
549
|
const isValidOutput = res && res instanceof Blob;
|
|
472
550
|
if (!isValidOutput) {
|
|
473
551
|
throw new InferenceOutputError("Expected Blob");
|
|
@@ -487,7 +565,10 @@ async function zeroShotImageClassification(args, options) {
|
|
|
487
565
|
)
|
|
488
566
|
}
|
|
489
567
|
};
|
|
490
|
-
const res = await request(reqArgs,
|
|
568
|
+
const res = await request(reqArgs, {
|
|
569
|
+
...options,
|
|
570
|
+
taskHint: "zero-shot-image-classification"
|
|
571
|
+
});
|
|
491
572
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
|
|
492
573
|
if (!isValidOutput) {
|
|
493
574
|
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
@@ -497,7 +578,7 @@ async function zeroShotImageClassification(args, options) {
|
|
|
497
578
|
|
|
498
579
|
// src/tasks/nlp/conversational.ts
|
|
499
580
|
async function conversational(args, options) {
|
|
500
|
-
const res = await request(args, options);
|
|
581
|
+
const res = await request(args, { ...options, taskHint: "conversational" });
|
|
501
582
|
const isValidOutput = Array.isArray(res.conversation.generated_responses) && res.conversation.generated_responses.every((x) => typeof x === "string") && Array.isArray(res.conversation.past_user_inputs) && res.conversation.past_user_inputs.every((x) => typeof x === "string") && typeof res.generated_text === "string" && Array.isArray(res.warnings) && res.warnings.every((x) => typeof x === "string");
|
|
502
583
|
if (!isValidOutput) {
|
|
503
584
|
throw new InferenceOutputError(
|
|
@@ -507,47 +588,14 @@ async function conversational(args, options) {
|
|
|
507
588
|
return res;
|
|
508
589
|
}
|
|
509
590
|
|
|
510
|
-
// src/lib/getDefaultTask.ts
|
|
511
|
-
var taskCache = /* @__PURE__ */ new Map();
|
|
512
|
-
var CACHE_DURATION = 10 * 60 * 1e3;
|
|
513
|
-
var MAX_CACHE_ITEMS = 1e3;
|
|
514
|
-
var HF_HUB_URL = "https://huggingface.co";
|
|
515
|
-
async function getDefaultTask(model, accessToken) {
|
|
516
|
-
if (isUrl(model)) {
|
|
517
|
-
return null;
|
|
518
|
-
}
|
|
519
|
-
const key = `${model}:${accessToken}`;
|
|
520
|
-
let cachedTask = taskCache.get(key);
|
|
521
|
-
if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
|
|
522
|
-
taskCache.delete(key);
|
|
523
|
-
cachedTask = void 0;
|
|
524
|
-
}
|
|
525
|
-
if (cachedTask === void 0) {
|
|
526
|
-
const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
|
|
527
|
-
headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
|
|
528
|
-
}).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
|
|
529
|
-
if (!modelTask) {
|
|
530
|
-
return null;
|
|
531
|
-
}
|
|
532
|
-
cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
|
|
533
|
-
taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
|
|
534
|
-
if (taskCache.size > MAX_CACHE_ITEMS) {
|
|
535
|
-
taskCache.delete(taskCache.keys().next().value);
|
|
536
|
-
}
|
|
537
|
-
}
|
|
538
|
-
return cachedTask.task;
|
|
539
|
-
}
|
|
540
|
-
|
|
541
591
|
// src/tasks/nlp/featureExtraction.ts
|
|
542
592
|
async function featureExtraction(args, options) {
|
|
543
|
-
const defaultTask = await getDefaultTask(args.model, args.accessToken);
|
|
544
|
-
const res = await request(
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
} : options
|
|
550
|
-
);
|
|
593
|
+
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : void 0;
|
|
594
|
+
const res = await request(args, {
|
|
595
|
+
...options,
|
|
596
|
+
taskHint: "feature-extraction",
|
|
597
|
+
...defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" }
|
|
598
|
+
});
|
|
551
599
|
let isValidOutput = true;
|
|
552
600
|
const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
|
|
553
601
|
if (curDepth > maxDepth)
|
|
@@ -567,7 +615,10 @@ async function featureExtraction(args, options) {
|
|
|
567
615
|
|
|
568
616
|
// src/tasks/nlp/fillMask.ts
|
|
569
617
|
async function fillMask(args, options) {
|
|
570
|
-
const res = await request(args,
|
|
618
|
+
const res = await request(args, {
|
|
619
|
+
...options,
|
|
620
|
+
taskHint: "fill-mask"
|
|
621
|
+
});
|
|
571
622
|
const isValidOutput = Array.isArray(res) && res.every(
|
|
572
623
|
(x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
|
|
573
624
|
);
|
|
@@ -581,7 +632,10 @@ async function fillMask(args, options) {
|
|
|
581
632
|
|
|
582
633
|
// src/tasks/nlp/questionAnswering.ts
|
|
583
634
|
async function questionAnswering(args, options) {
|
|
584
|
-
const res = await request(args,
|
|
635
|
+
const res = await request(args, {
|
|
636
|
+
...options,
|
|
637
|
+
taskHint: "question-answering"
|
|
638
|
+
});
|
|
585
639
|
const isValidOutput = typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";
|
|
586
640
|
if (!isValidOutput) {
|
|
587
641
|
throw new InferenceOutputError("Expected {answer: string, end: number, score: number, start: number}");
|
|
@@ -591,14 +645,12 @@ async function questionAnswering(args, options) {
|
|
|
591
645
|
|
|
592
646
|
// src/tasks/nlp/sentenceSimilarity.ts
|
|
593
647
|
async function sentenceSimilarity(args, options) {
|
|
594
|
-
const defaultTask = await getDefaultTask(args.model, args.accessToken);
|
|
595
|
-
const res = await request(
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
} : options
|
|
601
|
-
);
|
|
648
|
+
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : void 0;
|
|
649
|
+
const res = await request(args, {
|
|
650
|
+
...options,
|
|
651
|
+
taskHint: "sentence-similarity",
|
|
652
|
+
...defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }
|
|
653
|
+
});
|
|
602
654
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
|
|
603
655
|
if (!isValidOutput) {
|
|
604
656
|
throw new InferenceOutputError("Expected number[]");
|
|
@@ -608,7 +660,10 @@ async function sentenceSimilarity(args, options) {
|
|
|
608
660
|
|
|
609
661
|
// src/tasks/nlp/summarization.ts
|
|
610
662
|
async function summarization(args, options) {
|
|
611
|
-
const res = await request(args,
|
|
663
|
+
const res = await request(args, {
|
|
664
|
+
...options,
|
|
665
|
+
taskHint: "summarization"
|
|
666
|
+
});
|
|
612
667
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string");
|
|
613
668
|
if (!isValidOutput) {
|
|
614
669
|
throw new InferenceOutputError("Expected Array<{summary_text: string}>");
|
|
@@ -618,7 +673,10 @@ async function summarization(args, options) {
|
|
|
618
673
|
|
|
619
674
|
// src/tasks/nlp/tableQuestionAnswering.ts
|
|
620
675
|
async function tableQuestionAnswering(args, options) {
|
|
621
|
-
const res = await request(args,
|
|
676
|
+
const res = await request(args, {
|
|
677
|
+
...options,
|
|
678
|
+
taskHint: "table-question-answering"
|
|
679
|
+
});
|
|
622
680
|
const isValidOutput = typeof res?.aggregator === "string" && typeof res.answer === "string" && Array.isArray(res.cells) && res.cells.every((x) => typeof x === "string") && Array.isArray(res.coordinates) && res.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number"));
|
|
623
681
|
if (!isValidOutput) {
|
|
624
682
|
throw new InferenceOutputError(
|
|
@@ -630,7 +688,10 @@ async function tableQuestionAnswering(args, options) {
|
|
|
630
688
|
|
|
631
689
|
// src/tasks/nlp/textClassification.ts
|
|
632
690
|
async function textClassification(args, options) {
|
|
633
|
-
const res = (await request(args,
|
|
691
|
+
const res = (await request(args, {
|
|
692
|
+
...options,
|
|
693
|
+
taskHint: "text-classification"
|
|
694
|
+
}))?.[0];
|
|
634
695
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number");
|
|
635
696
|
if (!isValidOutput) {
|
|
636
697
|
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
@@ -640,7 +701,10 @@ async function textClassification(args, options) {
|
|
|
640
701
|
|
|
641
702
|
// src/tasks/nlp/textGeneration.ts
|
|
642
703
|
async function textGeneration(args, options) {
|
|
643
|
-
const res = await request(args,
|
|
704
|
+
const res = await request(args, {
|
|
705
|
+
...options,
|
|
706
|
+
taskHint: "text-generation"
|
|
707
|
+
});
|
|
644
708
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string");
|
|
645
709
|
if (!isValidOutput) {
|
|
646
710
|
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
|
|
@@ -650,7 +714,10 @@ async function textGeneration(args, options) {
|
|
|
650
714
|
|
|
651
715
|
// src/tasks/nlp/textGenerationStream.ts
|
|
652
716
|
async function* textGenerationStream(args, options) {
|
|
653
|
-
yield* streamingRequest(args,
|
|
717
|
+
yield* streamingRequest(args, {
|
|
718
|
+
...options,
|
|
719
|
+
taskHint: "text-generation"
|
|
720
|
+
});
|
|
654
721
|
}
|
|
655
722
|
|
|
656
723
|
// src/utils/toArray.ts
|
|
@@ -663,7 +730,12 @@ function toArray(obj) {
|
|
|
663
730
|
|
|
664
731
|
// src/tasks/nlp/tokenClassification.ts
|
|
665
732
|
async function tokenClassification(args, options) {
|
|
666
|
-
const res = toArray(
|
|
733
|
+
const res = toArray(
|
|
734
|
+
await request(args, {
|
|
735
|
+
...options,
|
|
736
|
+
taskHint: "token-classification"
|
|
737
|
+
})
|
|
738
|
+
);
|
|
667
739
|
const isValidOutput = Array.isArray(res) && res.every(
|
|
668
740
|
(x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
|
|
669
741
|
);
|
|
@@ -677,7 +749,10 @@ async function tokenClassification(args, options) {
|
|
|
677
749
|
|
|
678
750
|
// src/tasks/nlp/translation.ts
|
|
679
751
|
async function translation(args, options) {
|
|
680
|
-
const res = await request(args,
|
|
752
|
+
const res = await request(args, {
|
|
753
|
+
...options,
|
|
754
|
+
taskHint: "translation"
|
|
755
|
+
});
|
|
681
756
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string");
|
|
682
757
|
if (!isValidOutput) {
|
|
683
758
|
throw new InferenceOutputError("Expected type Array<{translation_text: string}>");
|
|
@@ -688,7 +763,10 @@ async function translation(args, options) {
|
|
|
688
763
|
// src/tasks/nlp/zeroShotClassification.ts
|
|
689
764
|
async function zeroShotClassification(args, options) {
|
|
690
765
|
const res = toArray(
|
|
691
|
-
await request(args,
|
|
766
|
+
await request(args, {
|
|
767
|
+
...options,
|
|
768
|
+
taskHint: "zero-shot-classification"
|
|
769
|
+
})
|
|
692
770
|
);
|
|
693
771
|
const isValidOutput = Array.isArray(res) && res.every(
|
|
694
772
|
(x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
|
|
@@ -714,7 +792,10 @@ async function documentQuestionAnswering(args, options) {
|
|
|
714
792
|
}
|
|
715
793
|
};
|
|
716
794
|
const res = toArray(
|
|
717
|
-
await request(reqArgs,
|
|
795
|
+
await request(reqArgs, {
|
|
796
|
+
...options,
|
|
797
|
+
taskHint: "document-question-answering"
|
|
798
|
+
})
|
|
718
799
|
)?.[0];
|
|
719
800
|
const isValidOutput = typeof res?.answer === "string" && (typeof res.end === "number" || typeof res.end === "undefined") && (typeof res.score === "number" || typeof res.score === "undefined") && (typeof res.start === "number" || typeof res.start === "undefined");
|
|
720
801
|
if (!isValidOutput) {
|
|
@@ -737,7 +818,10 @@ async function visualQuestionAnswering(args, options) {
|
|
|
737
818
|
)
|
|
738
819
|
}
|
|
739
820
|
};
|
|
740
|
-
const res = (await request(reqArgs,
|
|
821
|
+
const res = (await request(reqArgs, {
|
|
822
|
+
...options,
|
|
823
|
+
taskHint: "visual-question-answering"
|
|
824
|
+
}))?.[0];
|
|
741
825
|
const isValidOutput = typeof res?.answer === "string" && typeof res.score === "number";
|
|
742
826
|
if (!isValidOutput) {
|
|
743
827
|
throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
|
|
@@ -747,7 +831,10 @@ async function visualQuestionAnswering(args, options) {
|
|
|
747
831
|
|
|
748
832
|
// src/tasks/tabular/tabularRegression.ts
|
|
749
833
|
async function tabularRegression(args, options) {
|
|
750
|
-
const res = await request(args,
|
|
834
|
+
const res = await request(args, {
|
|
835
|
+
...options,
|
|
836
|
+
taskHint: "tabular-regression"
|
|
837
|
+
});
|
|
751
838
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
|
|
752
839
|
if (!isValidOutput) {
|
|
753
840
|
throw new InferenceOutputError("Expected number[]");
|
|
@@ -757,7 +844,10 @@ async function tabularRegression(args, options) {
|
|
|
757
844
|
|
|
758
845
|
// src/tasks/tabular/tabularClassification.ts
|
|
759
846
|
async function tabularClassification(args, options) {
|
|
760
|
-
const res = await request(args,
|
|
847
|
+
const res = await request(args, {
|
|
848
|
+
...options,
|
|
849
|
+
taskHint: "tabular-classification"
|
|
850
|
+
});
|
|
761
851
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
|
|
762
852
|
if (!isValidOutput) {
|
|
763
853
|
throw new InferenceOutputError("Expected number[]");
|