@huggingface/inference 2.5.2 → 2.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. package/dist/index.d.ts +48 -2
  2. package/dist/index.js +168 -78
  3. package/dist/index.mjs +168 -78
  4. package/package.json +1 -1
  5. package/src/lib/getDefaultTask.ts +1 -1
  6. package/src/lib/makeRequestOptions.ts +34 -5
  7. package/src/tasks/audio/audioClassification.ts +4 -1
  8. package/src/tasks/audio/audioToAudio.ts +4 -1
  9. package/src/tasks/audio/automaticSpeechRecognition.ts +4 -1
  10. package/src/tasks/audio/textToSpeech.ts +4 -1
  11. package/src/tasks/custom/request.ts +3 -1
  12. package/src/tasks/custom/streamingRequest.ts +3 -1
  13. package/src/tasks/cv/imageClassification.ts +4 -1
  14. package/src/tasks/cv/imageSegmentation.ts +4 -1
  15. package/src/tasks/cv/imageToImage.ts +4 -1
  16. package/src/tasks/cv/imageToText.ts +6 -1
  17. package/src/tasks/cv/objectDetection.ts +4 -1
  18. package/src/tasks/cv/textToImage.ts +4 -1
  19. package/src/tasks/cv/zeroShotImageClassification.ts +4 -1
  20. package/src/tasks/multimodal/documentQuestionAnswering.ts +4 -1
  21. package/src/tasks/multimodal/visualQuestionAnswering.ts +6 -1
  22. package/src/tasks/nlp/conversational.ts +1 -1
  23. package/src/tasks/nlp/featureExtraction.ts +7 -10
  24. package/src/tasks/nlp/fillMask.ts +4 -1
  25. package/src/tasks/nlp/questionAnswering.ts +4 -1
  26. package/src/tasks/nlp/sentenceSimilarity.ts +6 -10
  27. package/src/tasks/nlp/summarization.ts +4 -1
  28. package/src/tasks/nlp/tableQuestionAnswering.ts +4 -1
  29. package/src/tasks/nlp/textClassification.ts +6 -1
  30. package/src/tasks/nlp/textGeneration.ts +4 -1
  31. package/src/tasks/nlp/textGenerationStream.ts +4 -1
  32. package/src/tasks/nlp/tokenClassification.ts +6 -1
  33. package/src/tasks/nlp/translation.ts +4 -1
  34. package/src/tasks/nlp/zeroShotClassification.ts +4 -1
  35. package/src/tasks/tabular/tabularClassification.ts +4 -1
  36. package/src/tasks/tabular/tabularRegression.ts +4 -1
  37. package/src/types.ts +36 -2
package/dist/index.mjs CHANGED
@@ -45,15 +45,63 @@ function isUrl(modelOrUrl) {
45
45
  return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
46
46
  }
47
47
 
48
+ // src/lib/getDefaultTask.ts
49
+ var taskCache = /* @__PURE__ */ new Map();
50
+ var CACHE_DURATION = 10 * 60 * 1e3;
51
+ var MAX_CACHE_ITEMS = 1e3;
52
+ var HF_HUB_URL = "https://huggingface.co";
53
+ async function getDefaultTask(model, accessToken) {
54
+ if (isUrl(model)) {
55
+ return null;
56
+ }
57
+ const key = `${model}:${accessToken}`;
58
+ let cachedTask = taskCache.get(key);
59
+ if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
60
+ taskCache.delete(key);
61
+ cachedTask = void 0;
62
+ }
63
+ if (cachedTask === void 0) {
64
+ const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
65
+ headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
66
+ }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
67
+ if (!modelTask) {
68
+ return null;
69
+ }
70
+ cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
71
+ taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
72
+ if (taskCache.size > MAX_CACHE_ITEMS) {
73
+ taskCache.delete(taskCache.keys().next().value);
74
+ }
75
+ }
76
+ return cachedTask.task;
77
+ }
78
+
48
79
  // src/lib/makeRequestOptions.ts
49
80
  var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
50
- function makeRequestOptions(args, options) {
51
- const { model, accessToken, ...otherArgs } = args;
52
- const { task, includeCredentials, ...otherOptions } = options ?? {};
81
+ var tasks = null;
82
+ async function makeRequestOptions(args, options) {
83
+ const { accessToken, model: _model, ...otherArgs } = args;
84
+ let { model } = args;
85
+ const { forceTask: task, includeCredentials, taskHint, ...otherOptions } = options ?? {};
53
86
  const headers = {};
54
87
  if (accessToken) {
55
88
  headers["Authorization"] = `Bearer ${accessToken}`;
56
89
  }
90
+ if (!model && !tasks && taskHint) {
91
+ const res = await fetch(`${HF_HUB_URL}/api/tasks`);
92
+ if (res.ok) {
93
+ tasks = await res.json();
94
+ }
95
+ }
96
+ if (!model && tasks && taskHint) {
97
+ const taskInfo = tasks[taskHint];
98
+ if (taskInfo) {
99
+ model = taskInfo.models[0].id;
100
+ }
101
+ }
102
+ if (!model) {
103
+ throw new Error("No model provided, and no default model found for this task");
104
+ }
57
105
  const binary = "data" in args && !!args.data;
58
106
  if (!binary) {
59
107
  headers["Content-Type"] = "application/json";
@@ -91,7 +139,7 @@ function makeRequestOptions(args, options) {
91
139
 
92
140
  // src/tasks/custom/request.ts
93
141
  async function request(args, options) {
94
- const { url, info } = makeRequestOptions(args, options);
142
+ const { url, info } = await makeRequestOptions(args, options);
95
143
  const response = await (options?.fetch ?? fetch)(url, info);
96
144
  if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
97
145
  return request(args, {
@@ -215,7 +263,7 @@ function newMessage() {
215
263
 
216
264
  // src/tasks/custom/streamingRequest.ts
217
265
  async function* streamingRequest(args, options) {
218
- const { url, info } = makeRequestOptions({ ...args, stream: true }, options);
266
+ const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
219
267
  const response = await (options?.fetch ?? fetch)(url, info);
220
268
  if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
221
269
  return streamingRequest(args, {
@@ -288,7 +336,10 @@ var InferenceOutputError = class extends TypeError {
288
336
 
289
337
  // src/tasks/audio/audioClassification.ts
290
338
  async function audioClassification(args, options) {
291
- const res = await request(args, options);
339
+ const res = await request(args, {
340
+ ...options,
341
+ taskHint: "audio-classification"
342
+ });
292
343
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
293
344
  if (!isValidOutput) {
294
345
  throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
@@ -298,7 +349,10 @@ async function audioClassification(args, options) {
298
349
 
299
350
  // src/tasks/audio/automaticSpeechRecognition.ts
300
351
  async function automaticSpeechRecognition(args, options) {
301
- const res = await request(args, options);
352
+ const res = await request(args, {
353
+ ...options,
354
+ taskHint: "automatic-speech-recognition"
355
+ });
302
356
  const isValidOutput = typeof res?.text === "string";
303
357
  if (!isValidOutput) {
304
358
  throw new InferenceOutputError("Expected {text: string}");
@@ -308,7 +362,10 @@ async function automaticSpeechRecognition(args, options) {
308
362
 
309
363
  // src/tasks/audio/textToSpeech.ts
310
364
  async function textToSpeech(args, options) {
311
- const res = await request(args, options);
365
+ const res = await request(args, {
366
+ ...options,
367
+ taskHint: "text-to-speech"
368
+ });
312
369
  const isValidOutput = res && res instanceof Blob;
313
370
  if (!isValidOutput) {
314
371
  throw new InferenceOutputError("Expected Blob");
@@ -318,7 +375,10 @@ async function textToSpeech(args, options) {
318
375
 
319
376
  // src/tasks/audio/audioToAudio.ts
320
377
  async function audioToAudio(args, options) {
321
- const res = await request(args, options);
378
+ const res = await request(args, {
379
+ ...options,
380
+ taskHint: "audio-to-audio"
381
+ });
322
382
  const isValidOutput = Array.isArray(res) && res.every(
323
383
  (x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string"
324
384
  );
@@ -330,7 +390,10 @@ async function audioToAudio(args, options) {
330
390
 
331
391
  // src/tasks/cv/imageClassification.ts
332
392
  async function imageClassification(args, options) {
333
- const res = await request(args, options);
393
+ const res = await request(args, {
394
+ ...options,
395
+ taskHint: "image-classification"
396
+ });
334
397
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
335
398
  if (!isValidOutput) {
336
399
  throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
@@ -340,7 +403,10 @@ async function imageClassification(args, options) {
340
403
 
341
404
  // src/tasks/cv/imageSegmentation.ts
342
405
  async function imageSegmentation(args, options) {
343
- const res = await request(args, options);
406
+ const res = await request(args, {
407
+ ...options,
408
+ taskHint: "image-segmentation"
409
+ });
344
410
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
345
411
  if (!isValidOutput) {
346
412
  throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
@@ -350,7 +416,10 @@ async function imageSegmentation(args, options) {
350
416
 
351
417
  // src/tasks/cv/imageToText.ts
352
418
  async function imageToText(args, options) {
353
- const res = (await request(args, options))?.[0];
419
+ const res = (await request(args, {
420
+ ...options,
421
+ taskHint: "image-to-text"
422
+ }))?.[0];
354
423
  if (typeof res?.generated_text !== "string") {
355
424
  throw new InferenceOutputError("Expected {generated_text: string}");
356
425
  }
@@ -359,7 +428,10 @@ async function imageToText(args, options) {
359
428
 
360
429
  // src/tasks/cv/objectDetection.ts
361
430
  async function objectDetection(args, options) {
362
- const res = await request(args, options);
431
+ const res = await request(args, {
432
+ ...options,
433
+ taskHint: "object-detection"
434
+ });
363
435
  const isValidOutput = Array.isArray(res) && res.every(
364
436
  (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
365
437
  );
@@ -373,7 +445,10 @@ async function objectDetection(args, options) {
373
445
 
374
446
  // src/tasks/cv/textToImage.ts
375
447
  async function textToImage(args, options) {
376
- const res = await request(args, options);
448
+ const res = await request(args, {
449
+ ...options,
450
+ taskHint: "text-to-image"
451
+ });
377
452
  const isValidOutput = res && res instanceof Blob;
378
453
  if (!isValidOutput) {
379
454
  throw new InferenceOutputError("Expected Blob");
@@ -415,7 +490,10 @@ async function imageToImage(args, options) {
415
490
  )
416
491
  };
417
492
  }
418
- const res = await request(reqArgs, options);
493
+ const res = await request(reqArgs, {
494
+ ...options,
495
+ taskHint: "image-to-image"
496
+ });
419
497
  const isValidOutput = res && res instanceof Blob;
420
498
  if (!isValidOutput) {
421
499
  throw new InferenceOutputError("Expected Blob");
@@ -435,7 +513,10 @@ async function zeroShotImageClassification(args, options) {
435
513
  )
436
514
  }
437
515
  };
438
- const res = await request(reqArgs, options);
516
+ const res = await request(reqArgs, {
517
+ ...options,
518
+ taskHint: "zero-shot-image-classification"
519
+ });
439
520
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
440
521
  if (!isValidOutput) {
441
522
  throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
@@ -445,7 +526,7 @@ async function zeroShotImageClassification(args, options) {
445
526
 
446
527
  // src/tasks/nlp/conversational.ts
447
528
  async function conversational(args, options) {
448
- const res = await request(args, options);
529
+ const res = await request(args, { ...options, taskHint: "conversational" });
449
530
  const isValidOutput = Array.isArray(res.conversation.generated_responses) && res.conversation.generated_responses.every((x) => typeof x === "string") && Array.isArray(res.conversation.past_user_inputs) && res.conversation.past_user_inputs.every((x) => typeof x === "string") && typeof res.generated_text === "string" && Array.isArray(res.warnings) && res.warnings.every((x) => typeof x === "string");
450
531
  if (!isValidOutput) {
451
532
  throw new InferenceOutputError(
@@ -455,47 +536,14 @@ async function conversational(args, options) {
455
536
  return res;
456
537
  }
457
538
 
458
- // src/lib/getDefaultTask.ts
459
- var taskCache = /* @__PURE__ */ new Map();
460
- var CACHE_DURATION = 10 * 60 * 1e3;
461
- var MAX_CACHE_ITEMS = 1e3;
462
- var HF_HUB_URL = "https://huggingface.co";
463
- async function getDefaultTask(model, accessToken) {
464
- if (isUrl(model)) {
465
- return null;
466
- }
467
- const key = `${model}:${accessToken}`;
468
- let cachedTask = taskCache.get(key);
469
- if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
470
- taskCache.delete(key);
471
- cachedTask = void 0;
472
- }
473
- if (cachedTask === void 0) {
474
- const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
475
- headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
476
- }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
477
- if (!modelTask) {
478
- return null;
479
- }
480
- cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
481
- taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
482
- if (taskCache.size > MAX_CACHE_ITEMS) {
483
- taskCache.delete(taskCache.keys().next().value);
484
- }
485
- }
486
- return cachedTask.task;
487
- }
488
-
489
539
  // src/tasks/nlp/featureExtraction.ts
490
540
  async function featureExtraction(args, options) {
491
- const defaultTask = await getDefaultTask(args.model, args.accessToken);
492
- const res = await request(
493
- args,
494
- defaultTask === "sentence-similarity" ? {
495
- ...options,
496
- task: "feature-extraction"
497
- } : options
498
- );
541
+ const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : void 0;
542
+ const res = await request(args, {
543
+ ...options,
544
+ taskHint: "feature-extraction",
545
+ ...defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" }
546
+ });
499
547
  let isValidOutput = true;
500
548
  const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
501
549
  if (curDepth > maxDepth)
@@ -515,7 +563,10 @@ async function featureExtraction(args, options) {
515
563
 
516
564
  // src/tasks/nlp/fillMask.ts
517
565
  async function fillMask(args, options) {
518
- const res = await request(args, options);
566
+ const res = await request(args, {
567
+ ...options,
568
+ taskHint: "fill-mask"
569
+ });
519
570
  const isValidOutput = Array.isArray(res) && res.every(
520
571
  (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
521
572
  );
@@ -529,7 +580,10 @@ async function fillMask(args, options) {
529
580
 
530
581
  // src/tasks/nlp/questionAnswering.ts
531
582
  async function questionAnswering(args, options) {
532
- const res = await request(args, options);
583
+ const res = await request(args, {
584
+ ...options,
585
+ taskHint: "question-answering"
586
+ });
533
587
  const isValidOutput = typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";
534
588
  if (!isValidOutput) {
535
589
  throw new InferenceOutputError("Expected {answer: string, end: number, score: number, start: number}");
@@ -539,14 +593,12 @@ async function questionAnswering(args, options) {
539
593
 
540
594
  // src/tasks/nlp/sentenceSimilarity.ts
541
595
  async function sentenceSimilarity(args, options) {
542
- const defaultTask = await getDefaultTask(args.model, args.accessToken);
543
- const res = await request(
544
- args,
545
- defaultTask === "feature-extraction" ? {
546
- ...options,
547
- task: "sentence-similarity"
548
- } : options
549
- );
596
+ const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : void 0;
597
+ const res = await request(args, {
598
+ ...options,
599
+ taskHint: "sentence-similarity",
600
+ ...defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }
601
+ });
550
602
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
551
603
  if (!isValidOutput) {
552
604
  throw new InferenceOutputError("Expected number[]");
@@ -556,7 +608,10 @@ async function sentenceSimilarity(args, options) {
556
608
 
557
609
  // src/tasks/nlp/summarization.ts
558
610
  async function summarization(args, options) {
559
- const res = await request(args, options);
611
+ const res = await request(args, {
612
+ ...options,
613
+ taskHint: "summarization"
614
+ });
560
615
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string");
561
616
  if (!isValidOutput) {
562
617
  throw new InferenceOutputError("Expected Array<{summary_text: string}>");
@@ -566,7 +621,10 @@ async function summarization(args, options) {
566
621
 
567
622
  // src/tasks/nlp/tableQuestionAnswering.ts
568
623
  async function tableQuestionAnswering(args, options) {
569
- const res = await request(args, options);
624
+ const res = await request(args, {
625
+ ...options,
626
+ taskHint: "table-question-answering"
627
+ });
570
628
  const isValidOutput = typeof res?.aggregator === "string" && typeof res.answer === "string" && Array.isArray(res.cells) && res.cells.every((x) => typeof x === "string") && Array.isArray(res.coordinates) && res.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number"));
571
629
  if (!isValidOutput) {
572
630
  throw new InferenceOutputError(
@@ -578,7 +636,10 @@ async function tableQuestionAnswering(args, options) {
578
636
 
579
637
  // src/tasks/nlp/textClassification.ts
580
638
  async function textClassification(args, options) {
581
- const res = (await request(args, options))?.[0];
639
+ const res = (await request(args, {
640
+ ...options,
641
+ taskHint: "text-classification"
642
+ }))?.[0];
582
643
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number");
583
644
  if (!isValidOutput) {
584
645
  throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
@@ -588,7 +649,10 @@ async function textClassification(args, options) {
588
649
 
589
650
  // src/tasks/nlp/textGeneration.ts
590
651
  async function textGeneration(args, options) {
591
- const res = await request(args, options);
652
+ const res = await request(args, {
653
+ ...options,
654
+ taskHint: "text-generation"
655
+ });
592
656
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string");
593
657
  if (!isValidOutput) {
594
658
  throw new InferenceOutputError("Expected Array<{generated_text: string}>");
@@ -598,7 +662,10 @@ async function textGeneration(args, options) {
598
662
 
599
663
  // src/tasks/nlp/textGenerationStream.ts
600
664
  async function* textGenerationStream(args, options) {
601
- yield* streamingRequest(args, options);
665
+ yield* streamingRequest(args, {
666
+ ...options,
667
+ taskHint: "text-generation"
668
+ });
602
669
  }
603
670
 
604
671
  // src/utils/toArray.ts
@@ -611,7 +678,12 @@ function toArray(obj) {
611
678
 
612
679
  // src/tasks/nlp/tokenClassification.ts
613
680
  async function tokenClassification(args, options) {
614
- const res = toArray(await request(args, options));
681
+ const res = toArray(
682
+ await request(args, {
683
+ ...options,
684
+ taskHint: "token-classification"
685
+ })
686
+ );
615
687
  const isValidOutput = Array.isArray(res) && res.every(
616
688
  (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
617
689
  );
@@ -625,7 +697,10 @@ async function tokenClassification(args, options) {
625
697
 
626
698
  // src/tasks/nlp/translation.ts
627
699
  async function translation(args, options) {
628
- const res = await request(args, options);
700
+ const res = await request(args, {
701
+ ...options,
702
+ taskHint: "translation"
703
+ });
629
704
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string");
630
705
  if (!isValidOutput) {
631
706
  throw new InferenceOutputError("Expected type Array<{translation_text: string}>");
@@ -636,7 +711,10 @@ async function translation(args, options) {
636
711
  // src/tasks/nlp/zeroShotClassification.ts
637
712
  async function zeroShotClassification(args, options) {
638
713
  const res = toArray(
639
- await request(args, options)
714
+ await request(args, {
715
+ ...options,
716
+ taskHint: "zero-shot-classification"
717
+ })
640
718
  );
641
719
  const isValidOutput = Array.isArray(res) && res.every(
642
720
  (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
@@ -662,7 +740,10 @@ async function documentQuestionAnswering(args, options) {
662
740
  }
663
741
  };
664
742
  const res = toArray(
665
- await request(reqArgs, options)
743
+ await request(reqArgs, {
744
+ ...options,
745
+ taskHint: "document-question-answering"
746
+ })
666
747
  )?.[0];
667
748
  const isValidOutput = typeof res?.answer === "string" && (typeof res.end === "number" || typeof res.end === "undefined") && (typeof res.score === "number" || typeof res.score === "undefined") && (typeof res.start === "number" || typeof res.start === "undefined");
668
749
  if (!isValidOutput) {
@@ -685,7 +766,10 @@ async function visualQuestionAnswering(args, options) {
685
766
  )
686
767
  }
687
768
  };
688
- const res = (await request(reqArgs, options))?.[0];
769
+ const res = (await request(reqArgs, {
770
+ ...options,
771
+ taskHint: "visual-question-answering"
772
+ }))?.[0];
689
773
  const isValidOutput = typeof res?.answer === "string" && typeof res.score === "number";
690
774
  if (!isValidOutput) {
691
775
  throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
@@ -695,7 +779,10 @@ async function visualQuestionAnswering(args, options) {
695
779
 
696
780
  // src/tasks/tabular/tabularRegression.ts
697
781
  async function tabularRegression(args, options) {
698
- const res = await request(args, options);
782
+ const res = await request(args, {
783
+ ...options,
784
+ taskHint: "tabular-regression"
785
+ });
699
786
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
700
787
  if (!isValidOutput) {
701
788
  throw new InferenceOutputError("Expected number[]");
@@ -705,7 +792,10 @@ async function tabularRegression(args, options) {
705
792
 
706
793
  // src/tasks/tabular/tabularClassification.ts
707
794
  async function tabularClassification(args, options) {
708
- const res = await request(args, options);
795
+ const res = await request(args, {
796
+ ...options,
797
+ taskHint: "tabular-classification"
798
+ });
709
799
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
710
800
  if (!isValidOutput) {
711
801
  throw new InferenceOutputError("Expected number[]");
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@huggingface/inference",
3
- "version": "2.5.2",
3
+ "version": "2.6.0",
4
4
  "packageManager": "pnpm@8.3.1",
5
5
  "license": "MIT",
6
6
  "author": "Tim Mikeladze <tim.mikeladze@gmail.com>",
@@ -8,7 +8,7 @@ import { isUrl } from "./isUrl";
8
8
  const taskCache = new Map<string, { task: string; date: Date }>();
9
9
  const CACHE_DURATION = 10 * 60 * 1000;
10
10
  const MAX_CACHE_ITEMS = 1000;
11
- const HF_HUB_URL = "https://huggingface.co";
11
+ export const HF_HUB_URL = "https://huggingface.co";
12
12
 
13
13
  /**
14
14
  * Get the default task. Use a LRU cache of 1000 items with 10 minutes expiration
@@ -1,12 +1,18 @@
1
1
  import type { InferenceTask, Options, RequestArgs } from "../types";
2
+ import { HF_HUB_URL } from "./getDefaultTask";
2
3
  import { isUrl } from "./isUrl";
3
4
 
4
5
  const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
5
6
 
7
+ /**
8
+ * Loaded from huggingface.co/api/tasks if needed
9
+ */
10
+ let tasks: Record<string, { models: { id: string }[] }> | null = null;
11
+
6
12
  /**
7
13
  * Helper that prepares request arguments
8
14
  */
9
- export function makeRequestOptions(
15
+ export async function makeRequestOptions(
10
16
  args: RequestArgs & {
11
17
  data?: Blob | ArrayBuffer;
12
18
  stream?: boolean;
@@ -15,17 +21,40 @@ export function makeRequestOptions(
15
21
  /** For internal HF use, which is why it's not exposed in {@link Options} */
16
22
  includeCredentials?: boolean;
17
23
  /** When a model can be used for multiple tasks, and we want to run a non-default task */
18
- task?: string | InferenceTask;
24
+ forceTask?: string | InferenceTask;
25
+ /** To load default model if needed */
26
+ taskHint?: InferenceTask;
19
27
  }
20
- ): { url: string; info: RequestInit } {
21
- const { model, accessToken, ...otherArgs } = args;
22
- const { task, includeCredentials, ...otherOptions } = options ?? {};
28
+ ): Promise<{ url: string; info: RequestInit }> {
29
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
30
+ const { accessToken, model: _model, ...otherArgs } = args;
31
+ let { model } = args;
32
+ const { forceTask: task, includeCredentials, taskHint, ...otherOptions } = options ?? {};
23
33
 
24
34
  const headers: Record<string, string> = {};
25
35
  if (accessToken) {
26
36
  headers["Authorization"] = `Bearer ${accessToken}`;
27
37
  }
28
38
 
39
+ if (!model && !tasks && taskHint) {
40
+ const res = await fetch(`${HF_HUB_URL}/api/tasks`);
41
+
42
+ if (res.ok) {
43
+ tasks = await res.json();
44
+ }
45
+ }
46
+
47
+ if (!model && tasks && taskHint) {
48
+ const taskInfo = tasks[taskHint];
49
+ if (taskInfo) {
50
+ model = taskInfo.models[0].id;
51
+ }
52
+ }
53
+
54
+ if (!model) {
55
+ throw new Error("No model provided, and no default model found for this task");
56
+ }
57
+
29
58
  const binary = "data" in args && !!args.data;
30
59
 
31
60
  if (!binary) {
@@ -31,7 +31,10 @@ export async function audioClassification(
31
31
  args: AudioClassificationArgs,
32
32
  options?: Options
33
33
  ): Promise<AudioClassificationReturn> {
34
- const res = await request<AudioClassificationReturn>(args, options);
34
+ const res = await request<AudioClassificationReturn>(args, {
35
+ ...options,
36
+ taskHint: "audio-classification",
37
+ });
35
38
  const isValidOutput =
36
39
  Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
37
40
  if (!isValidOutput) {
@@ -33,7 +33,10 @@ export type AudioToAudioReturn = AudioToAudioOutputValue[];
33
33
  * Example model: speechbrain/sepformer-wham does audio source separation.
34
34
  */
35
35
  export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioReturn> {
36
- const res = await request<AudioToAudioReturn>(args, options);
36
+ const res = await request<AudioToAudioReturn>(args, {
37
+ ...options,
38
+ taskHint: "audio-to-audio",
39
+ });
37
40
  const isValidOutput =
38
41
  Array.isArray(res) &&
39
42
  res.every(
@@ -24,7 +24,10 @@ export async function automaticSpeechRecognition(
24
24
  args: AutomaticSpeechRecognitionArgs,
25
25
  options?: Options
26
26
  ): Promise<AutomaticSpeechRecognitionOutput> {
27
- const res = await request<AutomaticSpeechRecognitionOutput>(args, options);
27
+ const res = await request<AutomaticSpeechRecognitionOutput>(args, {
28
+ ...options,
29
+ taskHint: "automatic-speech-recognition",
30
+ });
28
31
  const isValidOutput = typeof res?.text === "string";
29
32
  if (!isValidOutput) {
30
33
  throw new InferenceOutputError("Expected {text: string}");
@@ -16,7 +16,10 @@ export type TextToSpeechOutput = Blob;
16
16
  * Recommended model: espnet/kan-bayashi_ljspeech_vits
17
17
  */
18
18
  export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<TextToSpeechOutput> {
19
- const res = await request<TextToSpeechOutput>(args, options);
19
+ const res = await request<TextToSpeechOutput>(args, {
20
+ ...options,
21
+ taskHint: "text-to-speech",
22
+ });
20
23
  const isValidOutput = res && res instanceof Blob;
21
24
  if (!isValidOutput) {
22
25
  throw new InferenceOutputError("Expected Blob");
@@ -11,9 +11,11 @@ export async function request<T>(
11
11
  includeCredentials?: boolean;
12
12
  /** When a model can be used for multiple tasks, and we want to run a non-default task */
13
13
  task?: string | InferenceTask;
14
+ /** To load default model if needed */
15
+ taskHint?: InferenceTask;
14
16
  }
15
17
  ): Promise<T> {
16
- const { url, info } = makeRequestOptions(args, options);
18
+ const { url, info } = await makeRequestOptions(args, options);
17
19
  const response = await (options?.fetch ?? fetch)(url, info);
18
20
 
19
21
  if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
@@ -13,9 +13,11 @@ export async function* streamingRequest<T>(
13
13
  includeCredentials?: boolean;
14
14
  /** When a model can be used for multiple tasks, and we want to run a non-default task */
15
15
  task?: string | InferenceTask;
16
+ /** To load default model if needed */
17
+ taskHint?: InferenceTask;
16
18
  }
17
19
  ): AsyncGenerator<T> {
18
- const { url, info } = makeRequestOptions({ ...args, stream: true }, options);
20
+ const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
19
21
  const response = await (options?.fetch ?? fetch)(url, info);
20
22
 
21
23
  if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
@@ -30,7 +30,10 @@ export async function imageClassification(
30
30
  args: ImageClassificationArgs,
31
31
  options?: Options
32
32
  ): Promise<ImageClassificationOutput> {
33
- const res = await request<ImageClassificationOutput>(args, options);
33
+ const res = await request<ImageClassificationOutput>(args, {
34
+ ...options,
35
+ taskHint: "image-classification",
36
+ });
34
37
  const isValidOutput =
35
38
  Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
36
39
  if (!isValidOutput) {
@@ -34,7 +34,10 @@ export async function imageSegmentation(
34
34
  args: ImageSegmentationArgs,
35
35
  options?: Options
36
36
  ): Promise<ImageSegmentationOutput> {
37
- const res = await request<ImageSegmentationOutput>(args, options);
37
+ const res = await request<ImageSegmentationOutput>(args, {
38
+ ...options,
39
+ taskHint: "image-segmentation",
40
+ });
38
41
  const isValidOutput =
39
42
  Array.isArray(res) &&
40
43
  res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
@@ -74,7 +74,10 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P
74
74
  ),
75
75
  };
76
76
  }
77
- const res = await request<ImageToImageOutput>(reqArgs, options);
77
+ const res = await request<ImageToImageOutput>(reqArgs, {
78
+ ...options,
79
+ taskHint: "image-to-image",
80
+ });
78
81
  const isValidOutput = res && res instanceof Blob;
79
82
  if (!isValidOutput) {
80
83
  throw new InferenceOutputError("Expected Blob");