@huggingface/inference 2.5.2 → 2.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.d.ts +48 -2
- package/dist/index.js +168 -78
- package/dist/index.mjs +168 -78
- package/package.json +1 -1
- package/src/lib/getDefaultTask.ts +1 -1
- package/src/lib/makeRequestOptions.ts +34 -5
- package/src/tasks/audio/audioClassification.ts +4 -1
- package/src/tasks/audio/audioToAudio.ts +4 -1
- package/src/tasks/audio/automaticSpeechRecognition.ts +4 -1
- package/src/tasks/audio/textToSpeech.ts +4 -1
- package/src/tasks/custom/request.ts +3 -1
- package/src/tasks/custom/streamingRequest.ts +3 -1
- package/src/tasks/cv/imageClassification.ts +4 -1
- package/src/tasks/cv/imageSegmentation.ts +4 -1
- package/src/tasks/cv/imageToImage.ts +4 -1
- package/src/tasks/cv/imageToText.ts +6 -1
- package/src/tasks/cv/objectDetection.ts +4 -1
- package/src/tasks/cv/textToImage.ts +4 -1
- package/src/tasks/cv/zeroShotImageClassification.ts +4 -1
- package/src/tasks/multimodal/documentQuestionAnswering.ts +4 -1
- package/src/tasks/multimodal/visualQuestionAnswering.ts +6 -1
- package/src/tasks/nlp/conversational.ts +1 -1
- package/src/tasks/nlp/featureExtraction.ts +7 -10
- package/src/tasks/nlp/fillMask.ts +4 -1
- package/src/tasks/nlp/questionAnswering.ts +4 -1
- package/src/tasks/nlp/sentenceSimilarity.ts +6 -10
- package/src/tasks/nlp/summarization.ts +4 -1
- package/src/tasks/nlp/tableQuestionAnswering.ts +4 -1
- package/src/tasks/nlp/textClassification.ts +6 -1
- package/src/tasks/nlp/textGeneration.ts +4 -1
- package/src/tasks/nlp/textGenerationStream.ts +4 -1
- package/src/tasks/nlp/tokenClassification.ts +6 -1
- package/src/tasks/nlp/translation.ts +4 -1
- package/src/tasks/nlp/zeroShotClassification.ts +4 -1
- package/src/tasks/tabular/tabularClassification.ts +4 -1
- package/src/tasks/tabular/tabularRegression.ts +4 -1
- package/src/types.ts +36 -2
package/dist/index.mjs
CHANGED
|
@@ -45,15 +45,63 @@ function isUrl(modelOrUrl) {
|
|
|
45
45
|
return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
|
|
46
46
|
}
|
|
47
47
|
|
|
48
|
+
// src/lib/getDefaultTask.ts
|
|
49
|
+
var taskCache = /* @__PURE__ */ new Map();
|
|
50
|
+
var CACHE_DURATION = 10 * 60 * 1e3;
|
|
51
|
+
var MAX_CACHE_ITEMS = 1e3;
|
|
52
|
+
var HF_HUB_URL = "https://huggingface.co";
|
|
53
|
+
async function getDefaultTask(model, accessToken) {
|
|
54
|
+
if (isUrl(model)) {
|
|
55
|
+
return null;
|
|
56
|
+
}
|
|
57
|
+
const key = `${model}:${accessToken}`;
|
|
58
|
+
let cachedTask = taskCache.get(key);
|
|
59
|
+
if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
|
|
60
|
+
taskCache.delete(key);
|
|
61
|
+
cachedTask = void 0;
|
|
62
|
+
}
|
|
63
|
+
if (cachedTask === void 0) {
|
|
64
|
+
const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
|
|
65
|
+
headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
|
|
66
|
+
}).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
|
|
67
|
+
if (!modelTask) {
|
|
68
|
+
return null;
|
|
69
|
+
}
|
|
70
|
+
cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
|
|
71
|
+
taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
|
|
72
|
+
if (taskCache.size > MAX_CACHE_ITEMS) {
|
|
73
|
+
taskCache.delete(taskCache.keys().next().value);
|
|
74
|
+
}
|
|
75
|
+
}
|
|
76
|
+
return cachedTask.task;
|
|
77
|
+
}
|
|
78
|
+
|
|
48
79
|
// src/lib/makeRequestOptions.ts
|
|
49
80
|
var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
const {
|
|
81
|
+
var tasks = null;
|
|
82
|
+
async function makeRequestOptions(args, options) {
|
|
83
|
+
const { accessToken, model: _model, ...otherArgs } = args;
|
|
84
|
+
let { model } = args;
|
|
85
|
+
const { forceTask: task, includeCredentials, taskHint, ...otherOptions } = options ?? {};
|
|
53
86
|
const headers = {};
|
|
54
87
|
if (accessToken) {
|
|
55
88
|
headers["Authorization"] = `Bearer ${accessToken}`;
|
|
56
89
|
}
|
|
90
|
+
if (!model && !tasks && taskHint) {
|
|
91
|
+
const res = await fetch(`${HF_HUB_URL}/api/tasks`);
|
|
92
|
+
if (res.ok) {
|
|
93
|
+
tasks = await res.json();
|
|
94
|
+
}
|
|
95
|
+
}
|
|
96
|
+
if (!model && tasks && taskHint) {
|
|
97
|
+
const taskInfo = tasks[taskHint];
|
|
98
|
+
if (taskInfo) {
|
|
99
|
+
model = taskInfo.models[0].id;
|
|
100
|
+
}
|
|
101
|
+
}
|
|
102
|
+
if (!model) {
|
|
103
|
+
throw new Error("No model provided, and no default model found for this task");
|
|
104
|
+
}
|
|
57
105
|
const binary = "data" in args && !!args.data;
|
|
58
106
|
if (!binary) {
|
|
59
107
|
headers["Content-Type"] = "application/json";
|
|
@@ -91,7 +139,7 @@ function makeRequestOptions(args, options) {
|
|
|
91
139
|
|
|
92
140
|
// src/tasks/custom/request.ts
|
|
93
141
|
async function request(args, options) {
|
|
94
|
-
const { url, info } = makeRequestOptions(args, options);
|
|
142
|
+
const { url, info } = await makeRequestOptions(args, options);
|
|
95
143
|
const response = await (options?.fetch ?? fetch)(url, info);
|
|
96
144
|
if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
|
|
97
145
|
return request(args, {
|
|
@@ -215,7 +263,7 @@ function newMessage() {
|
|
|
215
263
|
|
|
216
264
|
// src/tasks/custom/streamingRequest.ts
|
|
217
265
|
async function* streamingRequest(args, options) {
|
|
218
|
-
const { url, info } = makeRequestOptions({ ...args, stream: true }, options);
|
|
266
|
+
const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
|
|
219
267
|
const response = await (options?.fetch ?? fetch)(url, info);
|
|
220
268
|
if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
|
|
221
269
|
return streamingRequest(args, {
|
|
@@ -288,7 +336,10 @@ var InferenceOutputError = class extends TypeError {
|
|
|
288
336
|
|
|
289
337
|
// src/tasks/audio/audioClassification.ts
|
|
290
338
|
async function audioClassification(args, options) {
|
|
291
|
-
const res = await request(args,
|
|
339
|
+
const res = await request(args, {
|
|
340
|
+
...options,
|
|
341
|
+
taskHint: "audio-classification"
|
|
342
|
+
});
|
|
292
343
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
|
|
293
344
|
if (!isValidOutput) {
|
|
294
345
|
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
@@ -298,7 +349,10 @@ async function audioClassification(args, options) {
|
|
|
298
349
|
|
|
299
350
|
// src/tasks/audio/automaticSpeechRecognition.ts
|
|
300
351
|
async function automaticSpeechRecognition(args, options) {
|
|
301
|
-
const res = await request(args,
|
|
352
|
+
const res = await request(args, {
|
|
353
|
+
...options,
|
|
354
|
+
taskHint: "automatic-speech-recognition"
|
|
355
|
+
});
|
|
302
356
|
const isValidOutput = typeof res?.text === "string";
|
|
303
357
|
if (!isValidOutput) {
|
|
304
358
|
throw new InferenceOutputError("Expected {text: string}");
|
|
@@ -308,7 +362,10 @@ async function automaticSpeechRecognition(args, options) {
|
|
|
308
362
|
|
|
309
363
|
// src/tasks/audio/textToSpeech.ts
|
|
310
364
|
async function textToSpeech(args, options) {
|
|
311
|
-
const res = await request(args,
|
|
365
|
+
const res = await request(args, {
|
|
366
|
+
...options,
|
|
367
|
+
taskHint: "text-to-speech"
|
|
368
|
+
});
|
|
312
369
|
const isValidOutput = res && res instanceof Blob;
|
|
313
370
|
if (!isValidOutput) {
|
|
314
371
|
throw new InferenceOutputError("Expected Blob");
|
|
@@ -318,7 +375,10 @@ async function textToSpeech(args, options) {
|
|
|
318
375
|
|
|
319
376
|
// src/tasks/audio/audioToAudio.ts
|
|
320
377
|
async function audioToAudio(args, options) {
|
|
321
|
-
const res = await request(args,
|
|
378
|
+
const res = await request(args, {
|
|
379
|
+
...options,
|
|
380
|
+
taskHint: "audio-to-audio"
|
|
381
|
+
});
|
|
322
382
|
const isValidOutput = Array.isArray(res) && res.every(
|
|
323
383
|
(x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string"
|
|
324
384
|
);
|
|
@@ -330,7 +390,10 @@ async function audioToAudio(args, options) {
|
|
|
330
390
|
|
|
331
391
|
// src/tasks/cv/imageClassification.ts
|
|
332
392
|
async function imageClassification(args, options) {
|
|
333
|
-
const res = await request(args,
|
|
393
|
+
const res = await request(args, {
|
|
394
|
+
...options,
|
|
395
|
+
taskHint: "image-classification"
|
|
396
|
+
});
|
|
334
397
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
|
|
335
398
|
if (!isValidOutput) {
|
|
336
399
|
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
@@ -340,7 +403,10 @@ async function imageClassification(args, options) {
|
|
|
340
403
|
|
|
341
404
|
// src/tasks/cv/imageSegmentation.ts
|
|
342
405
|
async function imageSegmentation(args, options) {
|
|
343
|
-
const res = await request(args,
|
|
406
|
+
const res = await request(args, {
|
|
407
|
+
...options,
|
|
408
|
+
taskHint: "image-segmentation"
|
|
409
|
+
});
|
|
344
410
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
|
|
345
411
|
if (!isValidOutput) {
|
|
346
412
|
throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
|
|
@@ -350,7 +416,10 @@ async function imageSegmentation(args, options) {
|
|
|
350
416
|
|
|
351
417
|
// src/tasks/cv/imageToText.ts
|
|
352
418
|
async function imageToText(args, options) {
|
|
353
|
-
const res = (await request(args,
|
|
419
|
+
const res = (await request(args, {
|
|
420
|
+
...options,
|
|
421
|
+
taskHint: "image-to-text"
|
|
422
|
+
}))?.[0];
|
|
354
423
|
if (typeof res?.generated_text !== "string") {
|
|
355
424
|
throw new InferenceOutputError("Expected {generated_text: string}");
|
|
356
425
|
}
|
|
@@ -359,7 +428,10 @@ async function imageToText(args, options) {
|
|
|
359
428
|
|
|
360
429
|
// src/tasks/cv/objectDetection.ts
|
|
361
430
|
async function objectDetection(args, options) {
|
|
362
|
-
const res = await request(args,
|
|
431
|
+
const res = await request(args, {
|
|
432
|
+
...options,
|
|
433
|
+
taskHint: "object-detection"
|
|
434
|
+
});
|
|
363
435
|
const isValidOutput = Array.isArray(res) && res.every(
|
|
364
436
|
(x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
|
|
365
437
|
);
|
|
@@ -373,7 +445,10 @@ async function objectDetection(args, options) {
|
|
|
373
445
|
|
|
374
446
|
// src/tasks/cv/textToImage.ts
|
|
375
447
|
async function textToImage(args, options) {
|
|
376
|
-
const res = await request(args,
|
|
448
|
+
const res = await request(args, {
|
|
449
|
+
...options,
|
|
450
|
+
taskHint: "text-to-image"
|
|
451
|
+
});
|
|
377
452
|
const isValidOutput = res && res instanceof Blob;
|
|
378
453
|
if (!isValidOutput) {
|
|
379
454
|
throw new InferenceOutputError("Expected Blob");
|
|
@@ -415,7 +490,10 @@ async function imageToImage(args, options) {
|
|
|
415
490
|
)
|
|
416
491
|
};
|
|
417
492
|
}
|
|
418
|
-
const res = await request(reqArgs,
|
|
493
|
+
const res = await request(reqArgs, {
|
|
494
|
+
...options,
|
|
495
|
+
taskHint: "image-to-image"
|
|
496
|
+
});
|
|
419
497
|
const isValidOutput = res && res instanceof Blob;
|
|
420
498
|
if (!isValidOutput) {
|
|
421
499
|
throw new InferenceOutputError("Expected Blob");
|
|
@@ -435,7 +513,10 @@ async function zeroShotImageClassification(args, options) {
|
|
|
435
513
|
)
|
|
436
514
|
}
|
|
437
515
|
};
|
|
438
|
-
const res = await request(reqArgs,
|
|
516
|
+
const res = await request(reqArgs, {
|
|
517
|
+
...options,
|
|
518
|
+
taskHint: "zero-shot-image-classification"
|
|
519
|
+
});
|
|
439
520
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
|
|
440
521
|
if (!isValidOutput) {
|
|
441
522
|
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
@@ -445,7 +526,7 @@ async function zeroShotImageClassification(args, options) {
|
|
|
445
526
|
|
|
446
527
|
// src/tasks/nlp/conversational.ts
|
|
447
528
|
async function conversational(args, options) {
|
|
448
|
-
const res = await request(args, options);
|
|
529
|
+
const res = await request(args, { ...options, taskHint: "conversational" });
|
|
449
530
|
const isValidOutput = Array.isArray(res.conversation.generated_responses) && res.conversation.generated_responses.every((x) => typeof x === "string") && Array.isArray(res.conversation.past_user_inputs) && res.conversation.past_user_inputs.every((x) => typeof x === "string") && typeof res.generated_text === "string" && Array.isArray(res.warnings) && res.warnings.every((x) => typeof x === "string");
|
|
450
531
|
if (!isValidOutput) {
|
|
451
532
|
throw new InferenceOutputError(
|
|
@@ -455,47 +536,14 @@ async function conversational(args, options) {
|
|
|
455
536
|
return res;
|
|
456
537
|
}
|
|
457
538
|
|
|
458
|
-
// src/lib/getDefaultTask.ts
|
|
459
|
-
var taskCache = /* @__PURE__ */ new Map();
|
|
460
|
-
var CACHE_DURATION = 10 * 60 * 1e3;
|
|
461
|
-
var MAX_CACHE_ITEMS = 1e3;
|
|
462
|
-
var HF_HUB_URL = "https://huggingface.co";
|
|
463
|
-
async function getDefaultTask(model, accessToken) {
|
|
464
|
-
if (isUrl(model)) {
|
|
465
|
-
return null;
|
|
466
|
-
}
|
|
467
|
-
const key = `${model}:${accessToken}`;
|
|
468
|
-
let cachedTask = taskCache.get(key);
|
|
469
|
-
if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
|
|
470
|
-
taskCache.delete(key);
|
|
471
|
-
cachedTask = void 0;
|
|
472
|
-
}
|
|
473
|
-
if (cachedTask === void 0) {
|
|
474
|
-
const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
|
|
475
|
-
headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
|
|
476
|
-
}).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
|
|
477
|
-
if (!modelTask) {
|
|
478
|
-
return null;
|
|
479
|
-
}
|
|
480
|
-
cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
|
|
481
|
-
taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
|
|
482
|
-
if (taskCache.size > MAX_CACHE_ITEMS) {
|
|
483
|
-
taskCache.delete(taskCache.keys().next().value);
|
|
484
|
-
}
|
|
485
|
-
}
|
|
486
|
-
return cachedTask.task;
|
|
487
|
-
}
|
|
488
|
-
|
|
489
539
|
// src/tasks/nlp/featureExtraction.ts
|
|
490
540
|
async function featureExtraction(args, options) {
|
|
491
|
-
const defaultTask = await getDefaultTask(args.model, args.accessToken);
|
|
492
|
-
const res = await request(
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
} : options
|
|
498
|
-
);
|
|
541
|
+
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : void 0;
|
|
542
|
+
const res = await request(args, {
|
|
543
|
+
...options,
|
|
544
|
+
taskHint: "feature-extraction",
|
|
545
|
+
...defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" }
|
|
546
|
+
});
|
|
499
547
|
let isValidOutput = true;
|
|
500
548
|
const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
|
|
501
549
|
if (curDepth > maxDepth)
|
|
@@ -515,7 +563,10 @@ async function featureExtraction(args, options) {
|
|
|
515
563
|
|
|
516
564
|
// src/tasks/nlp/fillMask.ts
|
|
517
565
|
async function fillMask(args, options) {
|
|
518
|
-
const res = await request(args,
|
|
566
|
+
const res = await request(args, {
|
|
567
|
+
...options,
|
|
568
|
+
taskHint: "fill-mask"
|
|
569
|
+
});
|
|
519
570
|
const isValidOutput = Array.isArray(res) && res.every(
|
|
520
571
|
(x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
|
|
521
572
|
);
|
|
@@ -529,7 +580,10 @@ async function fillMask(args, options) {
|
|
|
529
580
|
|
|
530
581
|
// src/tasks/nlp/questionAnswering.ts
|
|
531
582
|
async function questionAnswering(args, options) {
|
|
532
|
-
const res = await request(args,
|
|
583
|
+
const res = await request(args, {
|
|
584
|
+
...options,
|
|
585
|
+
taskHint: "question-answering"
|
|
586
|
+
});
|
|
533
587
|
const isValidOutput = typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";
|
|
534
588
|
if (!isValidOutput) {
|
|
535
589
|
throw new InferenceOutputError("Expected {answer: string, end: number, score: number, start: number}");
|
|
@@ -539,14 +593,12 @@ async function questionAnswering(args, options) {
|
|
|
539
593
|
|
|
540
594
|
// src/tasks/nlp/sentenceSimilarity.ts
|
|
541
595
|
async function sentenceSimilarity(args, options) {
|
|
542
|
-
const defaultTask = await getDefaultTask(args.model, args.accessToken);
|
|
543
|
-
const res = await request(
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
} : options
|
|
549
|
-
);
|
|
596
|
+
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : void 0;
|
|
597
|
+
const res = await request(args, {
|
|
598
|
+
...options,
|
|
599
|
+
taskHint: "sentence-similarity",
|
|
600
|
+
...defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }
|
|
601
|
+
});
|
|
550
602
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
|
|
551
603
|
if (!isValidOutput) {
|
|
552
604
|
throw new InferenceOutputError("Expected number[]");
|
|
@@ -556,7 +608,10 @@ async function sentenceSimilarity(args, options) {
|
|
|
556
608
|
|
|
557
609
|
// src/tasks/nlp/summarization.ts
|
|
558
610
|
async function summarization(args, options) {
|
|
559
|
-
const res = await request(args,
|
|
611
|
+
const res = await request(args, {
|
|
612
|
+
...options,
|
|
613
|
+
taskHint: "summarization"
|
|
614
|
+
});
|
|
560
615
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string");
|
|
561
616
|
if (!isValidOutput) {
|
|
562
617
|
throw new InferenceOutputError("Expected Array<{summary_text: string}>");
|
|
@@ -566,7 +621,10 @@ async function summarization(args, options) {
|
|
|
566
621
|
|
|
567
622
|
// src/tasks/nlp/tableQuestionAnswering.ts
|
|
568
623
|
async function tableQuestionAnswering(args, options) {
|
|
569
|
-
const res = await request(args,
|
|
624
|
+
const res = await request(args, {
|
|
625
|
+
...options,
|
|
626
|
+
taskHint: "table-question-answering"
|
|
627
|
+
});
|
|
570
628
|
const isValidOutput = typeof res?.aggregator === "string" && typeof res.answer === "string" && Array.isArray(res.cells) && res.cells.every((x) => typeof x === "string") && Array.isArray(res.coordinates) && res.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number"));
|
|
571
629
|
if (!isValidOutput) {
|
|
572
630
|
throw new InferenceOutputError(
|
|
@@ -578,7 +636,10 @@ async function tableQuestionAnswering(args, options) {
|
|
|
578
636
|
|
|
579
637
|
// src/tasks/nlp/textClassification.ts
|
|
580
638
|
async function textClassification(args, options) {
|
|
581
|
-
const res = (await request(args,
|
|
639
|
+
const res = (await request(args, {
|
|
640
|
+
...options,
|
|
641
|
+
taskHint: "text-classification"
|
|
642
|
+
}))?.[0];
|
|
582
643
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number");
|
|
583
644
|
if (!isValidOutput) {
|
|
584
645
|
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
@@ -588,7 +649,10 @@ async function textClassification(args, options) {
|
|
|
588
649
|
|
|
589
650
|
// src/tasks/nlp/textGeneration.ts
|
|
590
651
|
async function textGeneration(args, options) {
|
|
591
|
-
const res = await request(args,
|
|
652
|
+
const res = await request(args, {
|
|
653
|
+
...options,
|
|
654
|
+
taskHint: "text-generation"
|
|
655
|
+
});
|
|
592
656
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string");
|
|
593
657
|
if (!isValidOutput) {
|
|
594
658
|
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
|
|
@@ -598,7 +662,10 @@ async function textGeneration(args, options) {
|
|
|
598
662
|
|
|
599
663
|
// src/tasks/nlp/textGenerationStream.ts
|
|
600
664
|
async function* textGenerationStream(args, options) {
|
|
601
|
-
yield* streamingRequest(args,
|
|
665
|
+
yield* streamingRequest(args, {
|
|
666
|
+
...options,
|
|
667
|
+
taskHint: "text-generation"
|
|
668
|
+
});
|
|
602
669
|
}
|
|
603
670
|
|
|
604
671
|
// src/utils/toArray.ts
|
|
@@ -611,7 +678,12 @@ function toArray(obj) {
|
|
|
611
678
|
|
|
612
679
|
// src/tasks/nlp/tokenClassification.ts
|
|
613
680
|
async function tokenClassification(args, options) {
|
|
614
|
-
const res = toArray(
|
|
681
|
+
const res = toArray(
|
|
682
|
+
await request(args, {
|
|
683
|
+
...options,
|
|
684
|
+
taskHint: "token-classification"
|
|
685
|
+
})
|
|
686
|
+
);
|
|
615
687
|
const isValidOutput = Array.isArray(res) && res.every(
|
|
616
688
|
(x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
|
|
617
689
|
);
|
|
@@ -625,7 +697,10 @@ async function tokenClassification(args, options) {
|
|
|
625
697
|
|
|
626
698
|
// src/tasks/nlp/translation.ts
|
|
627
699
|
async function translation(args, options) {
|
|
628
|
-
const res = await request(args,
|
|
700
|
+
const res = await request(args, {
|
|
701
|
+
...options,
|
|
702
|
+
taskHint: "translation"
|
|
703
|
+
});
|
|
629
704
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string");
|
|
630
705
|
if (!isValidOutput) {
|
|
631
706
|
throw new InferenceOutputError("Expected type Array<{translation_text: string}>");
|
|
@@ -636,7 +711,10 @@ async function translation(args, options) {
|
|
|
636
711
|
// src/tasks/nlp/zeroShotClassification.ts
|
|
637
712
|
async function zeroShotClassification(args, options) {
|
|
638
713
|
const res = toArray(
|
|
639
|
-
await request(args,
|
|
714
|
+
await request(args, {
|
|
715
|
+
...options,
|
|
716
|
+
taskHint: "zero-shot-classification"
|
|
717
|
+
})
|
|
640
718
|
);
|
|
641
719
|
const isValidOutput = Array.isArray(res) && res.every(
|
|
642
720
|
(x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
|
|
@@ -662,7 +740,10 @@ async function documentQuestionAnswering(args, options) {
|
|
|
662
740
|
}
|
|
663
741
|
};
|
|
664
742
|
const res = toArray(
|
|
665
|
-
await request(reqArgs,
|
|
743
|
+
await request(reqArgs, {
|
|
744
|
+
...options,
|
|
745
|
+
taskHint: "document-question-answering"
|
|
746
|
+
})
|
|
666
747
|
)?.[0];
|
|
667
748
|
const isValidOutput = typeof res?.answer === "string" && (typeof res.end === "number" || typeof res.end === "undefined") && (typeof res.score === "number" || typeof res.score === "undefined") && (typeof res.start === "number" || typeof res.start === "undefined");
|
|
668
749
|
if (!isValidOutput) {
|
|
@@ -685,7 +766,10 @@ async function visualQuestionAnswering(args, options) {
|
|
|
685
766
|
)
|
|
686
767
|
}
|
|
687
768
|
};
|
|
688
|
-
const res = (await request(reqArgs,
|
|
769
|
+
const res = (await request(reqArgs, {
|
|
770
|
+
...options,
|
|
771
|
+
taskHint: "visual-question-answering"
|
|
772
|
+
}))?.[0];
|
|
689
773
|
const isValidOutput = typeof res?.answer === "string" && typeof res.score === "number";
|
|
690
774
|
if (!isValidOutput) {
|
|
691
775
|
throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
|
|
@@ -695,7 +779,10 @@ async function visualQuestionAnswering(args, options) {
|
|
|
695
779
|
|
|
696
780
|
// src/tasks/tabular/tabularRegression.ts
|
|
697
781
|
async function tabularRegression(args, options) {
|
|
698
|
-
const res = await request(args,
|
|
782
|
+
const res = await request(args, {
|
|
783
|
+
...options,
|
|
784
|
+
taskHint: "tabular-regression"
|
|
785
|
+
});
|
|
699
786
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
|
|
700
787
|
if (!isValidOutput) {
|
|
701
788
|
throw new InferenceOutputError("Expected number[]");
|
|
@@ -705,7 +792,10 @@ async function tabularRegression(args, options) {
|
|
|
705
792
|
|
|
706
793
|
// src/tasks/tabular/tabularClassification.ts
|
|
707
794
|
async function tabularClassification(args, options) {
|
|
708
|
-
const res = await request(args,
|
|
795
|
+
const res = await request(args, {
|
|
796
|
+
...options,
|
|
797
|
+
taskHint: "tabular-classification"
|
|
798
|
+
});
|
|
709
799
|
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
|
|
710
800
|
if (!isValidOutput) {
|
|
711
801
|
throw new InferenceOutputError("Expected number[]");
|
package/package.json
CHANGED
|
@@ -8,7 +8,7 @@ import { isUrl } from "./isUrl";
|
|
|
8
8
|
const taskCache = new Map<string, { task: string; date: Date }>();
|
|
9
9
|
const CACHE_DURATION = 10 * 60 * 1000;
|
|
10
10
|
const MAX_CACHE_ITEMS = 1000;
|
|
11
|
-
const HF_HUB_URL = "https://huggingface.co";
|
|
11
|
+
export const HF_HUB_URL = "https://huggingface.co";
|
|
12
12
|
|
|
13
13
|
/**
|
|
14
14
|
* Get the default task. Use a LRU cache of 1000 items with 10 minutes expiration
|
|
@@ -1,12 +1,18 @@
|
|
|
1
1
|
import type { InferenceTask, Options, RequestArgs } from "../types";
|
|
2
|
+
import { HF_HUB_URL } from "./getDefaultTask";
|
|
2
3
|
import { isUrl } from "./isUrl";
|
|
3
4
|
|
|
4
5
|
const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
|
|
5
6
|
|
|
7
|
+
/**
|
|
8
|
+
* Loaded from huggingface.co/api/tasks if needed
|
|
9
|
+
*/
|
|
10
|
+
let tasks: Record<string, { models: { id: string }[] }> | null = null;
|
|
11
|
+
|
|
6
12
|
/**
|
|
7
13
|
* Helper that prepares request arguments
|
|
8
14
|
*/
|
|
9
|
-
export function makeRequestOptions(
|
|
15
|
+
export async function makeRequestOptions(
|
|
10
16
|
args: RequestArgs & {
|
|
11
17
|
data?: Blob | ArrayBuffer;
|
|
12
18
|
stream?: boolean;
|
|
@@ -15,17 +21,40 @@ export function makeRequestOptions(
|
|
|
15
21
|
/** For internal HF use, which is why it's not exposed in {@link Options} */
|
|
16
22
|
includeCredentials?: boolean;
|
|
17
23
|
/** When a model can be used for multiple tasks, and we want to run a non-default task */
|
|
18
|
-
|
|
24
|
+
forceTask?: string | InferenceTask;
|
|
25
|
+
/** To load default model if needed */
|
|
26
|
+
taskHint?: InferenceTask;
|
|
19
27
|
}
|
|
20
|
-
): { url: string; info: RequestInit } {
|
|
21
|
-
|
|
22
|
-
const {
|
|
28
|
+
): Promise<{ url: string; info: RequestInit }> {
|
|
29
|
+
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
|
30
|
+
const { accessToken, model: _model, ...otherArgs } = args;
|
|
31
|
+
let { model } = args;
|
|
32
|
+
const { forceTask: task, includeCredentials, taskHint, ...otherOptions } = options ?? {};
|
|
23
33
|
|
|
24
34
|
const headers: Record<string, string> = {};
|
|
25
35
|
if (accessToken) {
|
|
26
36
|
headers["Authorization"] = `Bearer ${accessToken}`;
|
|
27
37
|
}
|
|
28
38
|
|
|
39
|
+
if (!model && !tasks && taskHint) {
|
|
40
|
+
const res = await fetch(`${HF_HUB_URL}/api/tasks`);
|
|
41
|
+
|
|
42
|
+
if (res.ok) {
|
|
43
|
+
tasks = await res.json();
|
|
44
|
+
}
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
if (!model && tasks && taskHint) {
|
|
48
|
+
const taskInfo = tasks[taskHint];
|
|
49
|
+
if (taskInfo) {
|
|
50
|
+
model = taskInfo.models[0].id;
|
|
51
|
+
}
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
if (!model) {
|
|
55
|
+
throw new Error("No model provided, and no default model found for this task");
|
|
56
|
+
}
|
|
57
|
+
|
|
29
58
|
const binary = "data" in args && !!args.data;
|
|
30
59
|
|
|
31
60
|
if (!binary) {
|
|
@@ -31,7 +31,10 @@ export async function audioClassification(
|
|
|
31
31
|
args: AudioClassificationArgs,
|
|
32
32
|
options?: Options
|
|
33
33
|
): Promise<AudioClassificationReturn> {
|
|
34
|
-
const res = await request<AudioClassificationReturn>(args,
|
|
34
|
+
const res = await request<AudioClassificationReturn>(args, {
|
|
35
|
+
...options,
|
|
36
|
+
taskHint: "audio-classification",
|
|
37
|
+
});
|
|
35
38
|
const isValidOutput =
|
|
36
39
|
Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
|
|
37
40
|
if (!isValidOutput) {
|
|
@@ -33,7 +33,10 @@ export type AudioToAudioReturn = AudioToAudioOutputValue[];
|
|
|
33
33
|
* Example model: speechbrain/sepformer-wham does audio source separation.
|
|
34
34
|
*/
|
|
35
35
|
export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioReturn> {
|
|
36
|
-
const res = await request<AudioToAudioReturn>(args,
|
|
36
|
+
const res = await request<AudioToAudioReturn>(args, {
|
|
37
|
+
...options,
|
|
38
|
+
taskHint: "audio-to-audio",
|
|
39
|
+
});
|
|
37
40
|
const isValidOutput =
|
|
38
41
|
Array.isArray(res) &&
|
|
39
42
|
res.every(
|
|
@@ -24,7 +24,10 @@ export async function automaticSpeechRecognition(
|
|
|
24
24
|
args: AutomaticSpeechRecognitionArgs,
|
|
25
25
|
options?: Options
|
|
26
26
|
): Promise<AutomaticSpeechRecognitionOutput> {
|
|
27
|
-
const res = await request<AutomaticSpeechRecognitionOutput>(args,
|
|
27
|
+
const res = await request<AutomaticSpeechRecognitionOutput>(args, {
|
|
28
|
+
...options,
|
|
29
|
+
taskHint: "automatic-speech-recognition",
|
|
30
|
+
});
|
|
28
31
|
const isValidOutput = typeof res?.text === "string";
|
|
29
32
|
if (!isValidOutput) {
|
|
30
33
|
throw new InferenceOutputError("Expected {text: string}");
|
|
@@ -16,7 +16,10 @@ export type TextToSpeechOutput = Blob;
|
|
|
16
16
|
* Recommended model: espnet/kan-bayashi_ljspeech_vits
|
|
17
17
|
*/
|
|
18
18
|
export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<TextToSpeechOutput> {
|
|
19
|
-
const res = await request<TextToSpeechOutput>(args,
|
|
19
|
+
const res = await request<TextToSpeechOutput>(args, {
|
|
20
|
+
...options,
|
|
21
|
+
taskHint: "text-to-speech",
|
|
22
|
+
});
|
|
20
23
|
const isValidOutput = res && res instanceof Blob;
|
|
21
24
|
if (!isValidOutput) {
|
|
22
25
|
throw new InferenceOutputError("Expected Blob");
|
|
@@ -11,9 +11,11 @@ export async function request<T>(
|
|
|
11
11
|
includeCredentials?: boolean;
|
|
12
12
|
/** When a model can be used for multiple tasks, and we want to run a non-default task */
|
|
13
13
|
task?: string | InferenceTask;
|
|
14
|
+
/** To load default model if needed */
|
|
15
|
+
taskHint?: InferenceTask;
|
|
14
16
|
}
|
|
15
17
|
): Promise<T> {
|
|
16
|
-
const { url, info } = makeRequestOptions(args, options);
|
|
18
|
+
const { url, info } = await makeRequestOptions(args, options);
|
|
17
19
|
const response = await (options?.fetch ?? fetch)(url, info);
|
|
18
20
|
|
|
19
21
|
if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
|
|
@@ -13,9 +13,11 @@ export async function* streamingRequest<T>(
|
|
|
13
13
|
includeCredentials?: boolean;
|
|
14
14
|
/** When a model can be used for multiple tasks, and we want to run a non-default task */
|
|
15
15
|
task?: string | InferenceTask;
|
|
16
|
+
/** To load default model if needed */
|
|
17
|
+
taskHint?: InferenceTask;
|
|
16
18
|
}
|
|
17
19
|
): AsyncGenerator<T> {
|
|
18
|
-
const { url, info } = makeRequestOptions({ ...args, stream: true }, options);
|
|
20
|
+
const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
|
|
19
21
|
const response = await (options?.fetch ?? fetch)(url, info);
|
|
20
22
|
|
|
21
23
|
if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
|
|
@@ -30,7 +30,10 @@ export async function imageClassification(
|
|
|
30
30
|
args: ImageClassificationArgs,
|
|
31
31
|
options?: Options
|
|
32
32
|
): Promise<ImageClassificationOutput> {
|
|
33
|
-
const res = await request<ImageClassificationOutput>(args,
|
|
33
|
+
const res = await request<ImageClassificationOutput>(args, {
|
|
34
|
+
...options,
|
|
35
|
+
taskHint: "image-classification",
|
|
36
|
+
});
|
|
34
37
|
const isValidOutput =
|
|
35
38
|
Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
|
|
36
39
|
if (!isValidOutput) {
|
|
@@ -34,7 +34,10 @@ export async function imageSegmentation(
|
|
|
34
34
|
args: ImageSegmentationArgs,
|
|
35
35
|
options?: Options
|
|
36
36
|
): Promise<ImageSegmentationOutput> {
|
|
37
|
-
const res = await request<ImageSegmentationOutput>(args,
|
|
37
|
+
const res = await request<ImageSegmentationOutput>(args, {
|
|
38
|
+
...options,
|
|
39
|
+
taskHint: "image-segmentation",
|
|
40
|
+
});
|
|
38
41
|
const isValidOutput =
|
|
39
42
|
Array.isArray(res) &&
|
|
40
43
|
res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
|
|
@@ -74,7 +74,10 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P
|
|
|
74
74
|
),
|
|
75
75
|
};
|
|
76
76
|
}
|
|
77
|
-
const res = await request<ImageToImageOutput>(reqArgs,
|
|
77
|
+
const res = await request<ImageToImageOutput>(reqArgs, {
|
|
78
|
+
...options,
|
|
79
|
+
taskHint: "image-to-image",
|
|
80
|
+
});
|
|
78
81
|
const isValidOutput = res && res instanceof Blob;
|
|
79
82
|
if (!isValidOutput) {
|
|
80
83
|
throw new InferenceOutputError("Expected Blob");
|