@huggingface/inference 2.5.2 → 2.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. package/dist/index.d.ts +48 -2
  2. package/dist/index.js +169 -79
  3. package/dist/index.mjs +169 -79
  4. package/package.json +1 -1
  5. package/src/lib/getDefaultTask.ts +1 -1
  6. package/src/lib/makeRequestOptions.ts +34 -5
  7. package/src/tasks/audio/audioClassification.ts +4 -1
  8. package/src/tasks/audio/audioToAudio.ts +4 -1
  9. package/src/tasks/audio/automaticSpeechRecognition.ts +4 -1
  10. package/src/tasks/audio/textToSpeech.ts +4 -1
  11. package/src/tasks/custom/request.ts +3 -1
  12. package/src/tasks/custom/streamingRequest.ts +3 -1
  13. package/src/tasks/cv/imageClassification.ts +4 -1
  14. package/src/tasks/cv/imageSegmentation.ts +4 -1
  15. package/src/tasks/cv/imageToImage.ts +4 -1
  16. package/src/tasks/cv/imageToText.ts +6 -1
  17. package/src/tasks/cv/objectDetection.ts +4 -1
  18. package/src/tasks/cv/textToImage.ts +4 -1
  19. package/src/tasks/cv/zeroShotImageClassification.ts +4 -1
  20. package/src/tasks/multimodal/documentQuestionAnswering.ts +4 -1
  21. package/src/tasks/multimodal/visualQuestionAnswering.ts +6 -1
  22. package/src/tasks/nlp/conversational.ts +3 -3
  23. package/src/tasks/nlp/featureExtraction.ts +7 -10
  24. package/src/tasks/nlp/fillMask.ts +4 -1
  25. package/src/tasks/nlp/questionAnswering.ts +4 -1
  26. package/src/tasks/nlp/sentenceSimilarity.ts +6 -10
  27. package/src/tasks/nlp/summarization.ts +4 -1
  28. package/src/tasks/nlp/tableQuestionAnswering.ts +4 -1
  29. package/src/tasks/nlp/textClassification.ts +6 -1
  30. package/src/tasks/nlp/textGeneration.ts +4 -1
  31. package/src/tasks/nlp/textGenerationStream.ts +4 -1
  32. package/src/tasks/nlp/tokenClassification.ts +6 -1
  33. package/src/tasks/nlp/translation.ts +4 -1
  34. package/src/tasks/nlp/zeroShotClassification.ts +4 -1
  35. package/src/tasks/tabular/tabularClassification.ts +4 -1
  36. package/src/tasks/tabular/tabularRegression.ts +4 -1
  37. package/src/types.ts +36 -2
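The headline change in this release is in the request plumbing: `model` becomes optional on `BaseArgs`, and every task helper now passes a `taskHint` so that, when no model is given, a recommended model for the task is fetched from huggingface.co/api/tasks. A minimal usage sketch (assuming Node 18+ with top-level await and a token in the HF_TOKEN environment variable; the model id and inputs are illustrative, not taken from this diff):

import { HfInference } from "@huggingface/inference";

const hf = new HfInference(process.env.HF_TOKEN);

// 2.5.x required an explicit model on every call:
const explicit = await hf.textClassification({
  model: "distilbert-base-uncased-finetuned-sst-2-english",
  inputs: "I like you.",
});

// 2.6.x also accepts calls without a model: the task's recommended model
// is resolved from https://huggingface.co/api/tasks (cached per process)
// and used instead.
const implicit = await hf.textClassification({
  inputs: "I like you.",
});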
package/dist/index.d.ts CHANGED
@@ -26,7 +26,39 @@ export interface Options {
   fetch?: typeof fetch;
 }
 
-export type InferenceTask = "text-classification" | "feature-extraction" | "sentence-similarity";
+export type InferenceTask =
+  | "audio-classification"
+  | "audio-to-audio"
+  | "automatic-speech-recognition"
+  | "conversational"
+  | "depth-estimation"
+  | "document-question-answering"
+  | "feature-extraction"
+  | "fill-mask"
+  | "image-classification"
+  | "image-segmentation"
+  | "image-to-image"
+  | "image-to-text"
+  | "object-detection"
+  | "video-classification"
+  | "question-answering"
+  | "reinforcement-learning"
+  | "sentence-similarity"
+  | "summarization"
+  | "table-question-answering"
+  | "tabular-classification"
+  | "tabular-regression"
+  | "text-classification"
+  | "text-generation"
+  | "text-to-image"
+  | "text-to-speech"
+  | "text-to-video"
+  | "token-classification"
+  | "translation"
+  | "unconditional-image-generation"
+  | "visual-question-answering"
+  | "zero-shot-classification"
+  | "zero-shot-image-classification";
 
 export interface BaseArgs {
   /**
@@ -37,8 +69,10 @@ export interface BaseArgs {
   accessToken?: string;
   /**
    * The model to use. Can be a full URL for HF inference endpoints.
+   *
+   * If not specified, will call huggingface.co/api/tasks to get the default model for the task.
    */
-  model: string;
+  model?: string;
 }
 
 export type RequestArgs = BaseArgs &
@@ -144,6 +178,8 @@ export function request<T>(
     includeCredentials?: boolean;
     /** When a model can be used for multiple tasks, and we want to run a non-default task */
     task?: string | InferenceTask;
+    /** To load default model if needed */
+    taskHint?: InferenceTask;
   }
 ): Promise<T>;
 /**
@@ -156,6 +192,8 @@ export function streamingRequest<T>(
     includeCredentials?: boolean;
     /** When a model can be used for multiple tasks, and we want to run a non-default task */
     task?: string | InferenceTask;
+    /** To load default model if needed */
+    taskHint?: InferenceTask;
   }
 ): AsyncGenerator<T>;
 export type ImageClassificationArgs = BaseArgs & {
@@ -1020,6 +1058,8 @@ export class HfInference {
     includeCredentials?: boolean;
     /** When a model can be used for multiple tasks, and we want to run a non-default task */
     task?: string | InferenceTask;
+    /** To load default model if needed */
+    taskHint?: InferenceTask;
   }
 ): Promise<T>;
 /**
@@ -1032,6 +1072,8 @@ export class HfInference {
     includeCredentials?: boolean;
     /** When a model can be used for multiple tasks, and we want to run a non-default task */
     task?: string | InferenceTask;
+    /** To load default model if needed */
+    taskHint?: InferenceTask;
   }
 ): AsyncGenerator<T>;
 /**
@@ -1225,6 +1267,8 @@ export class HfInferenceEndpoint {
     includeCredentials?: boolean;
     /** When a model can be used for multiple tasks, and we want to run a non-default task */
     task?: string | InferenceTask;
+    /** To load default model if needed */
+    taskHint?: InferenceTask;
   }
 ): Promise<T>;
 /**
@@ -1237,6 +1281,8 @@ export class HfInferenceEndpoint {
     includeCredentials?: boolean;
     /** When a model can be used for multiple tasks, and we want to run a non-default task */
     task?: string | InferenceTask;
+    /** To load default model if needed */
+    taskHint?: InferenceTask;
   }
 ): AsyncGenerator<T>;
 /**
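As the declaration changes above show, `taskHint` is also accepted by the lower-level `request` and `streamingRequest` helpers, where it only guides default-model resolution (unlike `task`, which forces a non-default pipeline). A rough sketch of a direct `request` call under the new typings (the prompt and the expected output shape are assumptions for illustration):

import { request } from "@huggingface/inference";

// With no `model` given, the hint below lets the library pick the
// recommended text-generation model before building the request.
const output = await request<Array<{ generated_text: string }>>(
  { inputs: "Hello, my name is" },
  { taskHint: "text-generation" }
);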
package/dist/index.js CHANGED
@@ -97,15 +97,63 @@ function isUrl(modelOrUrl) {
   return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
 }
 
+// src/lib/getDefaultTask.ts
+var taskCache = /* @__PURE__ */ new Map();
+var CACHE_DURATION = 10 * 60 * 1e3;
+var MAX_CACHE_ITEMS = 1e3;
+var HF_HUB_URL = "https://huggingface.co";
+async function getDefaultTask(model, accessToken) {
+  if (isUrl(model)) {
+    return null;
+  }
+  const key = `${model}:${accessToken}`;
+  let cachedTask = taskCache.get(key);
+  if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
+    taskCache.delete(key);
+    cachedTask = void 0;
+  }
+  if (cachedTask === void 0) {
+    const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
+      headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
+    }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
+    if (!modelTask) {
+      return null;
+    }
+    cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
+    taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
+    if (taskCache.size > MAX_CACHE_ITEMS) {
+      taskCache.delete(taskCache.keys().next().value);
+    }
+  }
+  return cachedTask.task;
+}
+
 // src/lib/makeRequestOptions.ts
 var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
-function makeRequestOptions(args, options) {
-  const { model, accessToken, ...otherArgs } = args;
-  const { task, includeCredentials, ...otherOptions } = options ?? {};
+var tasks = null;
+async function makeRequestOptions(args, options) {
+  const { accessToken, model: _model, ...otherArgs } = args;
+  let { model } = args;
+  const { forceTask: task, includeCredentials, taskHint, ...otherOptions } = options ?? {};
   const headers = {};
   if (accessToken) {
     headers["Authorization"] = `Bearer ${accessToken}`;
   }
+  if (!model && !tasks && taskHint) {
+    const res = await fetch(`${HF_HUB_URL}/api/tasks`);
+    if (res.ok) {
+      tasks = await res.json();
+    }
+  }
+  if (!model && tasks && taskHint) {
+    const taskInfo = tasks[taskHint];
+    if (taskInfo) {
+      model = taskInfo.models[0].id;
+    }
+  }
+  if (!model) {
+    throw new Error("No model provided, and no default model found for this task");
+  }
   const binary = "data" in args && !!args.data;
   if (!binary) {
     headers["Content-Type"] = "application/json";
@@ -143,7 +191,7 @@ function makeRequestOptions(args, options) {
 
 // src/tasks/custom/request.ts
 async function request(args, options) {
-  const { url, info } = makeRequestOptions(args, options);
+  const { url, info } = await makeRequestOptions(args, options);
   const response = await (options?.fetch ?? fetch)(url, info);
   if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
     return request(args, {
@@ -267,7 +315,7 @@ function newMessage() {
 
 // src/tasks/custom/streamingRequest.ts
 async function* streamingRequest(args, options) {
-  const { url, info } = makeRequestOptions({ ...args, stream: true }, options);
+  const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
   const response = await (options?.fetch ?? fetch)(url, info);
   if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
     return streamingRequest(args, {
@@ -340,7 +388,10 @@ var InferenceOutputError = class extends TypeError {
 
 // src/tasks/audio/audioClassification.ts
 async function audioClassification(args, options) {
-  const res = await request(args, options);
+  const res = await request(args, {
+    ...options,
+    taskHint: "audio-classification"
+  });
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
@@ -350,7 +401,10 @@ async function audioClassification(args, options) {
 
 // src/tasks/audio/automaticSpeechRecognition.ts
 async function automaticSpeechRecognition(args, options) {
-  const res = await request(args, options);
+  const res = await request(args, {
+    ...options,
+    taskHint: "automatic-speech-recognition"
+  });
   const isValidOutput = typeof res?.text === "string";
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected {text: string}");
@@ -360,7 +414,10 @@ async function automaticSpeechRecognition(args, options) {
 
 // src/tasks/audio/textToSpeech.ts
 async function textToSpeech(args, options) {
-  const res = await request(args, options);
+  const res = await request(args, {
+    ...options,
+    taskHint: "text-to-speech"
+  });
   const isValidOutput = res && res instanceof Blob;
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Blob");
@@ -370,7 +427,10 @@ async function textToSpeech(args, options) {
 
 // src/tasks/audio/audioToAudio.ts
 async function audioToAudio(args, options) {
-  const res = await request(args, options);
+  const res = await request(args, {
+    ...options,
+    taskHint: "audio-to-audio"
+  });
   const isValidOutput = Array.isArray(res) && res.every(
     (x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string"
   );
@@ -382,7 +442,10 @@ async function audioToAudio(args, options) {
 
 // src/tasks/cv/imageClassification.ts
 async function imageClassification(args, options) {
-  const res = await request(args, options);
+  const res = await request(args, {
+    ...options,
+    taskHint: "image-classification"
+  });
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
@@ -392,7 +455,10 @@ async function imageClassification(args, options) {
 
 // src/tasks/cv/imageSegmentation.ts
 async function imageSegmentation(args, options) {
-  const res = await request(args, options);
+  const res = await request(args, {
+    ...options,
+    taskHint: "image-segmentation"
+  });
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
@@ -402,7 +468,10 @@ async function imageSegmentation(args, options) {
 
 // src/tasks/cv/imageToText.ts
 async function imageToText(args, options) {
-  const res = (await request(args, options))?.[0];
+  const res = (await request(args, {
+    ...options,
+    taskHint: "image-to-text"
+  }))?.[0];
   if (typeof res?.generated_text !== "string") {
     throw new InferenceOutputError("Expected {generated_text: string}");
   }
@@ -411,7 +480,10 @@ async function imageToText(args, options) {
 
 // src/tasks/cv/objectDetection.ts
 async function objectDetection(args, options) {
-  const res = await request(args, options);
+  const res = await request(args, {
+    ...options,
+    taskHint: "object-detection"
+  });
   const isValidOutput = Array.isArray(res) && res.every(
     (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
   );
@@ -425,7 +497,10 @@ async function objectDetection(args, options) {
 
 // src/tasks/cv/textToImage.ts
 async function textToImage(args, options) {
-  const res = await request(args, options);
+  const res = await request(args, {
+    ...options,
+    taskHint: "text-to-image"
+  });
   const isValidOutput = res && res instanceof Blob;
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Blob");
@@ -467,7 +542,10 @@ async function imageToImage(args, options) {
       )
     };
   }
-  const res = await request(reqArgs, options);
+  const res = await request(reqArgs, {
+    ...options,
+    taskHint: "image-to-image"
+  });
   const isValidOutput = res && res instanceof Blob;
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Blob");
@@ -487,7 +565,10 @@ async function zeroShotImageClassification(args, options) {
       )
     }
   };
-  const res = await request(reqArgs, options);
+  const res = await request(reqArgs, {
+    ...options,
+    taskHint: "zero-shot-image-classification"
+  });
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
@@ -497,8 +578,8 @@ async function zeroShotImageClassification(args, options) {
 
 // src/tasks/nlp/conversational.ts
 async function conversational(args, options) {
-  const res = await request(args, options);
-  const isValidOutput = Array.isArray(res.conversation.generated_responses) && res.conversation.generated_responses.every((x) => typeof x === "string") && Array.isArray(res.conversation.past_user_inputs) && res.conversation.past_user_inputs.every((x) => typeof x === "string") && typeof res.generated_text === "string" && Array.isArray(res.warnings) && res.warnings.every((x) => typeof x === "string");
+  const res = await request(args, { ...options, taskHint: "conversational" });
+  const isValidOutput = Array.isArray(res.conversation.generated_responses) && res.conversation.generated_responses.every((x) => typeof x === "string") && Array.isArray(res.conversation.past_user_inputs) && res.conversation.past_user_inputs.every((x) => typeof x === "string") && typeof res.generated_text === "string" && (typeof res.warnings === "undefined" || Array.isArray(res.warnings) && res.warnings.every((x) => typeof x === "string"));
   if (!isValidOutput) {
     throw new InferenceOutputError(
       "Expected {conversation: {generated_responses: string[], past_user_inputs: string[]}, generated_text: string, warnings: string[]}"
@@ -507,47 +588,14 @@ async function conversational(args, options) {
   return res;
 }
 
-// src/lib/getDefaultTask.ts
-var taskCache = /* @__PURE__ */ new Map();
-var CACHE_DURATION = 10 * 60 * 1e3;
-var MAX_CACHE_ITEMS = 1e3;
-var HF_HUB_URL = "https://huggingface.co";
-async function getDefaultTask(model, accessToken) {
-  if (isUrl(model)) {
-    return null;
-  }
-  const key = `${model}:${accessToken}`;
-  let cachedTask = taskCache.get(key);
-  if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
-    taskCache.delete(key);
-    cachedTask = void 0;
-  }
-  if (cachedTask === void 0) {
-    const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
-      headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
-    }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
-    if (!modelTask) {
-      return null;
-    }
-    cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
-    taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() });
-    if (taskCache.size > MAX_CACHE_ITEMS) {
-      taskCache.delete(taskCache.keys().next().value);
-    }
-  }
-  return cachedTask.task;
-}
-
 // src/tasks/nlp/featureExtraction.ts
 async function featureExtraction(args, options) {
-  const defaultTask = await getDefaultTask(args.model, args.accessToken);
-  const res = await request(
-    args,
-    defaultTask === "sentence-similarity" ? {
-      ...options,
-      task: "feature-extraction"
-    } : options
-  );
+  const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : void 0;
+  const res = await request(args, {
+    ...options,
+    taskHint: "feature-extraction",
+    ...defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" }
+  });
   let isValidOutput = true;
   const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
     if (curDepth > maxDepth)
@@ -567,7 +615,10 @@ async function featureExtraction(args, options) {
 
 // src/tasks/nlp/fillMask.ts
 async function fillMask(args, options) {
-  const res = await request(args, options);
+  const res = await request(args, {
+    ...options,
+    taskHint: "fill-mask"
+  });
   const isValidOutput = Array.isArray(res) && res.every(
     (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
   );
@@ -581,7 +632,10 @@ async function fillMask(args, options) {
 
 // src/tasks/nlp/questionAnswering.ts
 async function questionAnswering(args, options) {
-  const res = await request(args, options);
+  const res = await request(args, {
+    ...options,
+    taskHint: "question-answering"
+  });
   const isValidOutput = typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected {answer: string, end: number, score: number, start: number}");
@@ -591,14 +645,12 @@ async function questionAnswering(args, options) {
 
 // src/tasks/nlp/sentenceSimilarity.ts
 async function sentenceSimilarity(args, options) {
-  const defaultTask = await getDefaultTask(args.model, args.accessToken);
-  const res = await request(
-    args,
-    defaultTask === "feature-extraction" ? {
-      ...options,
-      task: "sentence-similarity"
-    } : options
-  );
+  const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : void 0;
+  const res = await request(args, {
+    ...options,
+    taskHint: "sentence-similarity",
+    ...defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }
+  });
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected number[]");
@@ -608,7 +660,10 @@ async function sentenceSimilarity(args, options) {
 
 // src/tasks/nlp/summarization.ts
 async function summarization(args, options) {
-  const res = await request(args, options);
+  const res = await request(args, {
+    ...options,
+    taskHint: "summarization"
+  });
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string");
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Array<{summary_text: string}>");
@@ -618,7 +673,10 @@ async function summarization(args, options) {
 
 // src/tasks/nlp/tableQuestionAnswering.ts
 async function tableQuestionAnswering(args, options) {
-  const res = await request(args, options);
+  const res = await request(args, {
+    ...options,
+    taskHint: "table-question-answering"
+  });
   const isValidOutput = typeof res?.aggregator === "string" && typeof res.answer === "string" && Array.isArray(res.cells) && res.cells.every((x) => typeof x === "string") && Array.isArray(res.coordinates) && res.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number"));
   if (!isValidOutput) {
     throw new InferenceOutputError(
@@ -630,7 +688,10 @@ async function tableQuestionAnswering(args, options) {
 
 // src/tasks/nlp/textClassification.ts
 async function textClassification(args, options) {
-  const res = (await request(args, options))?.[0];
+  const res = (await request(args, {
+    ...options,
+    taskHint: "text-classification"
+  }))?.[0];
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number");
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
@@ -640,7 +701,10 @@ async function textClassification(args, options) {
 
 // src/tasks/nlp/textGeneration.ts
 async function textGeneration(args, options) {
-  const res = await request(args, options);
+  const res = await request(args, {
+    ...options,
+    taskHint: "text-generation"
+  });
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string");
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Array<{generated_text: string}>");
@@ -650,7 +714,10 @@ async function textGeneration(args, options) {
 
 // src/tasks/nlp/textGenerationStream.ts
 async function* textGenerationStream(args, options) {
-  yield* streamingRequest(args, options);
+  yield* streamingRequest(args, {
+    ...options,
+    taskHint: "text-generation"
+  });
 }
 
 // src/utils/toArray.ts
@@ -663,7 +730,12 @@ function toArray(obj) {
 
 // src/tasks/nlp/tokenClassification.ts
 async function tokenClassification(args, options) {
-  const res = toArray(await request(args, options));
+  const res = toArray(
+    await request(args, {
+      ...options,
+      taskHint: "token-classification"
+    })
+  );
   const isValidOutput = Array.isArray(res) && res.every(
     (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
   );
@@ -677,7 +749,10 @@ async function tokenClassification(args, options) {
 
 // src/tasks/nlp/translation.ts
 async function translation(args, options) {
-  const res = await request(args, options);
+  const res = await request(args, {
+    ...options,
+    taskHint: "translation"
+  });
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string");
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected type Array<{translation_text: string}>");
@@ -688,7 +763,10 @@ async function translation(args, options) {
 // src/tasks/nlp/zeroShotClassification.ts
 async function zeroShotClassification(args, options) {
   const res = toArray(
-    await request(args, options)
+    await request(args, {
+      ...options,
+      taskHint: "zero-shot-classification"
+    })
   );
   const isValidOutput = Array.isArray(res) && res.every(
     (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
@@ -714,7 +792,10 @@ async function documentQuestionAnswering(args, options) {
     }
   };
   const res = toArray(
-    await request(reqArgs, options)
+    await request(reqArgs, {
+      ...options,
+      taskHint: "document-question-answering"
+    })
   )?.[0];
   const isValidOutput = typeof res?.answer === "string" && (typeof res.end === "number" || typeof res.end === "undefined") && (typeof res.score === "number" || typeof res.score === "undefined") && (typeof res.start === "number" || typeof res.start === "undefined");
   if (!isValidOutput) {
@@ -737,7 +818,10 @@ async function visualQuestionAnswering(args, options) {
       )
     }
   };
-  const res = (await request(reqArgs, options))?.[0];
+  const res = (await request(reqArgs, {
+    ...options,
+    taskHint: "visual-question-answering"
+  }))?.[0];
   const isValidOutput = typeof res?.answer === "string" && typeof res.score === "number";
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
@@ -747,7 +831,10 @@ async function visualQuestionAnswering(args, options) {
 
 // src/tasks/tabular/tabularRegression.ts
 async function tabularRegression(args, options) {
-  const res = await request(args, options);
+  const res = await request(args, {
+    ...options,
+    taskHint: "tabular-regression"
+  });
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected number[]");
@@ -757,7 +844,10 @@ async function tabularRegression(args, options) {
 
 // src/tasks/tabular/tabularClassification.ts
 async function tabularClassification(args, options) {
-  const res = await request(args, options);
+  const res = await request(args, {
+    ...options,
+    taskHint: "tabular-classification"
+  });
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected number[]");
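Note also how featureExtraction and sentenceSimilarity were reworked above: getDefaultTask (now defined before makeRequestOptions) looks up a model's pipeline_tag on the Hub, and when that tag disagrees with the helper being called, the request is pinned to the intended task through the new forceTask option. A rough usage sketch with an arbitrary sentence-transformers model (the model id and input sentence are illustrative, not taken from this diff):

import { HfInference } from "@huggingface/inference";

const hf = new HfInference(process.env.HF_TOKEN);

// This model is tagged sentence-similarity on the Hub, so featureExtraction()
// detects that via getDefaultTask() and issues the request with
// forceTask: "feature-extraction" rather than the model's default pipeline.
const embedding = await hf.featureExtraction({
  model: "sentence-transformers/all-MiniLM-L6-v2",
  inputs: "That is a happy person",
});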