@huggingface/inference 3.10.0 → 3.12.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. package/dist/index.cjs +713 -643
  2. package/dist/index.js +712 -643
  3. package/dist/src/InferenceClient.d.ts +16 -17
  4. package/dist/src/InferenceClient.d.ts.map +1 -1
  5. package/dist/src/lib/getInferenceProviderMapping.d.ts +5 -1
  6. package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -1
  7. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  8. package/dist/src/providers/providerHelper.d.ts +1 -1
  9. package/dist/src/providers/providerHelper.d.ts.map +1 -1
  10. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
  11. package/dist/src/tasks/audio/audioToAudio.d.ts.map +1 -1
  12. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  13. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  14. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  15. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  16. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
  17. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
  18. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
  19. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
  20. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
  21. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  22. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
  23. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
  24. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  25. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
  26. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  27. package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
  28. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  29. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
  30. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  31. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  32. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
  33. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  34. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
  35. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  36. package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
  37. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  38. package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
  39. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  40. package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
  41. package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
  42. package/dist/src/types.d.ts +6 -4
  43. package/dist/src/types.d.ts.map +1 -1
  44. package/dist/src/utils/typedEntries.d.ts +4 -0
  45. package/dist/src/utils/typedEntries.d.ts.map +1 -0
  46. package/package.json +3 -3
  47. package/src/InferenceClient.ts +32 -43
  48. package/src/lib/getInferenceProviderMapping.ts +68 -19
  49. package/src/lib/makeRequestOptions.ts +4 -3
  50. package/src/providers/hf-inference.ts +1 -1
  51. package/src/providers/providerHelper.ts +1 -1
  52. package/src/snippets/getInferenceSnippets.ts +1 -1
  53. package/src/tasks/audio/audioClassification.ts +3 -1
  54. package/src/tasks/audio/audioToAudio.ts +4 -1
  55. package/src/tasks/audio/automaticSpeechRecognition.ts +3 -1
  56. package/src/tasks/audio/textToSpeech.ts +2 -1
  57. package/src/tasks/custom/request.ts +3 -1
  58. package/src/tasks/custom/streamingRequest.ts +3 -1
  59. package/src/tasks/cv/imageClassification.ts +3 -1
  60. package/src/tasks/cv/imageSegmentation.ts +3 -1
  61. package/src/tasks/cv/imageToImage.ts +3 -1
  62. package/src/tasks/cv/imageToText.ts +3 -1
  63. package/src/tasks/cv/objectDetection.ts +3 -1
  64. package/src/tasks/cv/textToImage.ts +2 -1
  65. package/src/tasks/cv/textToVideo.ts +2 -1
  66. package/src/tasks/cv/zeroShotImageClassification.ts +3 -1
  67. package/src/tasks/multimodal/documentQuestionAnswering.ts +3 -1
  68. package/src/tasks/multimodal/visualQuestionAnswering.ts +3 -1
  69. package/src/tasks/nlp/chatCompletion.ts +3 -1
  70. package/src/tasks/nlp/chatCompletionStream.ts +3 -1
  71. package/src/tasks/nlp/featureExtraction.ts +3 -1
  72. package/src/tasks/nlp/fillMask.ts +3 -1
  73. package/src/tasks/nlp/questionAnswering.ts +4 -1
  74. package/src/tasks/nlp/sentenceSimilarity.ts +3 -1
  75. package/src/tasks/nlp/summarization.ts +3 -1
  76. package/src/tasks/nlp/tableQuestionAnswering.ts +3 -1
  77. package/src/tasks/nlp/textClassification.ts +3 -1
  78. package/src/tasks/nlp/textGeneration.ts +3 -1
  79. package/src/tasks/nlp/textGenerationStream.ts +3 -1
  80. package/src/tasks/nlp/tokenClassification.ts +3 -1
  81. package/src/tasks/nlp/translation.ts +3 -1
  82. package/src/tasks/nlp/zeroShotClassification.ts +3 -1
  83. package/src/tasks/tabular/tabularClassification.ts +3 -1
  84. package/src/tasks/tabular/tabularRegression.ts +3 -1
  85. package/src/types.ts +8 -4
  86. package/src/utils/typedEntries.ts +5 -0
@@ -1,8 +1,8 @@
1
1
  import type { WidgetType } from "@huggingface/tasks";
2
- import type { InferenceProvider, ModelId } from "../types";
3
2
  import { HF_HUB_URL } from "../config";
4
3
  import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts";
5
4
  import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference";
5
+ import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../types";
6
6
  import { typedInclude } from "../utils/typedInclude";
7
7
 
8
8
  export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
@@ -20,44 +20,62 @@ export interface InferenceProviderModelMapping {
20
20
  task: WidgetType;
21
21
  }
22
22
 
23
- export async function getInferenceProviderMapping(
24
- params: {
25
- accessToken?: string;
26
- modelId: ModelId;
27
- provider: InferenceProvider;
28
- task: WidgetType;
29
- },
30
- options: {
23
+ export async function fetchInferenceProviderMappingForModel(
24
+ modelId: ModelId,
25
+ accessToken?: string,
26
+ options?: {
31
27
  fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
32
28
  }
33
- ): Promise<InferenceProviderModelMapping | null> {
34
- if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
35
- return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
36
- }
29
+ ): Promise<InferenceProviderMapping> {
37
30
  let inferenceProviderMapping: InferenceProviderMapping | null;
38
- if (inferenceProviderMappingCache.has(params.modelId)) {
31
+ if (inferenceProviderMappingCache.has(modelId)) {
39
32
  // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
40
- inferenceProviderMapping = inferenceProviderMappingCache.get(params.modelId)!;
33
+ inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
41
34
  } else {
42
35
  const resp = await (options?.fetch ?? fetch)(
43
- `${HF_HUB_URL}/api/models/${params.modelId}?expand[]=inferenceProviderMapping`,
36
+ `${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
44
37
  {
45
- headers: params.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${params.accessToken}` } : {},
38
+ headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
46
39
  }
47
40
  );
48
41
  if (resp.status === 404) {
49
- throw new Error(`Model ${params.modelId} does not exist`);
42
+ throw new Error(`Model ${modelId} does not exist`);
50
43
  }
51
44
  inferenceProviderMapping = await resp
52
45
  .json()
53
46
  .then((json) => json.inferenceProviderMapping)
54
47
  .catch(() => null);
48
+
49
+ if (inferenceProviderMapping) {
50
+ inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
51
+ }
55
52
  }
56
53
 
57
54
  if (!inferenceProviderMapping) {
58
- throw new Error(`We have not been able to find inference provider information for model ${params.modelId}.`);
55
+ throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
59
56
  }
57
+ return inferenceProviderMapping;
58
+ }
60
59
 
60
+ export async function getInferenceProviderMapping(
61
+ params: {
62
+ accessToken?: string;
63
+ modelId: ModelId;
64
+ provider: InferenceProvider;
65
+ task: WidgetType;
66
+ },
67
+ options: {
68
+ fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
69
+ }
70
+ ): Promise<InferenceProviderModelMapping | null> {
71
+ if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
72
+ return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
73
+ }
74
+ const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
75
+ params.modelId,
76
+ params.accessToken,
77
+ options
78
+ );
61
79
  const providerMapping = inferenceProviderMapping[params.provider];
62
80
  if (providerMapping) {
63
81
  const equivalentTasks =
@@ -78,3 +96,34 @@ export async function getInferenceProviderMapping(
78
96
  }
79
97
  return null;
80
98
  }
99
+
100
+ export async function resolveProvider(
101
+ provider?: InferenceProviderOrPolicy,
102
+ modelId?: string,
103
+ endpointUrl?: string
104
+ ): Promise<InferenceProvider> {
105
+ if (endpointUrl) {
106
+ if (provider) {
107
+ throw new Error("Specifying both endpointUrl and provider is not supported.");
108
+ }
109
+ /// Defaulting to hf-inference helpers / API
110
+ return "hf-inference";
111
+ }
112
+ if (!provider) {
113
+ console.log(
114
+ "Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
115
+ );
116
+ provider = "auto";
117
+ }
118
+ if (provider === "auto") {
119
+ if (!modelId) {
120
+ throw new Error("Specifying a model is required when provider is 'auto'");
121
+ }
122
+ const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
123
+ provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider | undefined;
124
+ }
125
+ if (!provider) {
126
+ throw new Error(`No Inference Provider available for model ${modelId}.`);
127
+ }
128
+ return provider;
129
+ }
@@ -27,8 +27,8 @@ export async function makeRequestOptions(
27
27
  task?: InferenceTask;
28
28
  }
29
29
  ): Promise<{ url: string; info: RequestInit }> {
30
- const { provider: maybeProvider, model: maybeModel } = args;
31
- const provider = maybeProvider ?? "hf-inference";
30
+ const { model: maybeModel } = args;
31
+ const provider = providerHelper.provider;
32
32
  const { task } = options ?? {};
33
33
 
34
34
  // Validate inputs
@@ -113,8 +113,9 @@ export function makeRequestOptionsFromResolvedModel(
113
113
  ): { url: string; info: RequestInit } {
114
114
  const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
115
115
  void model;
116
+ void maybeProvider;
116
117
 
117
- const provider = maybeProvider ?? "hf-inference";
118
+ const provider = providerHelper.provider;
118
119
 
119
120
  const { includeCredentials, task, signal, billTo } = options ?? {};
120
121
  const authMethod = (() => {
@@ -106,7 +106,7 @@ export class HFInferenceTask extends TaskProviderHelper {
106
106
  makeRoute(params: UrlParams): string {
107
107
  if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
108
108
  // when deployed on hf-inference, those two tasks are automatically compatible with one another.
109
- return `pipeline/${params.task}/${params.model}`;
109
+ return `models/${params.model}/pipeline/${params.task}`;
110
110
  }
111
111
  return `models/${params.model}`;
112
112
  }
@@ -56,7 +56,7 @@ import { toArray } from "../utils/toArray";
56
56
  */
57
57
  export abstract class TaskProviderHelper {
58
58
  constructor(
59
- private provider: InferenceProvider,
59
+ readonly provider: InferenceProvider,
60
60
  private baseUrl: string,
61
61
  readonly clientSideRoutingOnly: boolean = false
62
62
  ) {}
@@ -272,7 +272,7 @@ const prepareConversationalInput = (
272
272
  return {
273
273
  messages: opts?.messages ?? getModelInputSnippet(model),
274
274
  ...(opts?.temperature ? { temperature: opts?.temperature } : undefined),
275
- max_tokens: opts?.max_tokens ?? 512,
275
+ ...(opts?.max_tokens ? { max_tokens: opts?.max_tokens } : undefined),
276
276
  ...(opts?.top_p ? { top_p: opts?.top_p } : undefined),
277
277
  };
278
278
  };
@@ -1,4 +1,5 @@
1
1
  import type { AudioClassificationInput, AudioClassificationOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -15,7 +16,8 @@ export async function audioClassification(
15
16
  args: AudioClassificationArgs,
16
17
  options?: Options
17
18
  ): Promise<AudioClassificationOutput> {
18
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
19
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
20
+ const providerHelper = getProviderHelper(provider, "audio-classification");
19
21
  const payload = preparePayload(args);
20
22
  const { data: res } = await innerRequest<AudioClassificationOutput>(payload, providerHelper, {
21
23
  ...options,
@@ -1,3 +1,4 @@
1
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
1
2
  import { getProviderHelper } from "../../lib/getProviderHelper";
2
3
  import type { BaseArgs, Options } from "../../types";
3
4
  import { innerRequest } from "../../utils/request";
@@ -36,7 +37,9 @@ export interface AudioToAudioOutput {
36
37
  * Example model: speechbrain/sepformer-wham does audio source separation.
37
38
  */
38
39
  export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
39
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
40
+ const model = "inputs" in args ? args.model : undefined;
41
+ const provider = await resolveProvider(args.provider, model);
42
+ const providerHelper = getProviderHelper(provider, "audio-to-audio");
40
43
  const payload = preparePayload(args);
41
44
  const { data: res } = await innerRequest<AudioToAudioOutput>(payload, providerHelper, {
42
45
  ...options,
@@ -1,4 +1,5 @@
1
1
  import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
4
5
  import { FAL_AI_SUPPORTED_BLOB_TYPES } from "../../providers/fal-ai";
@@ -18,7 +19,8 @@ export async function automaticSpeechRecognition(
18
19
  args: AutomaticSpeechRecognitionArgs,
19
20
  options?: Options
20
21
  ): Promise<AutomaticSpeechRecognitionOutput> {
21
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
22
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
23
+ const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
22
24
  const payload = await buildPayload(args);
23
25
  const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, providerHelper, {
24
26
  ...options,
@@ -1,4 +1,5 @@
1
1
  import type { TextToSpeechInput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -12,7 +13,7 @@ interface OutputUrlTextToSpeechGeneration {
12
13
  * Recommended model: espnet/kan-bayashi_ljspeech_vits
13
14
  */
14
15
  export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
15
- const provider = args.provider ?? "hf-inference";
16
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
16
17
  const providerHelper = getProviderHelper(provider, "text-to-speech");
17
18
  const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, providerHelper, {
18
19
  ...options,
@@ -1,3 +1,4 @@
1
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
1
2
  import { getProviderHelper } from "../../lib/getProviderHelper";
2
3
  import type { InferenceTask, Options, RequestArgs } from "../../types";
3
4
  import { innerRequest } from "../../utils/request";
@@ -16,7 +17,8 @@ export async function request<T>(
16
17
  console.warn(
17
18
  "The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
18
19
  );
19
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
20
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
21
+ const providerHelper = getProviderHelper(provider, options?.task);
20
22
  const result = await innerRequest<T>(args, providerHelper, options);
21
23
  return result.data;
22
24
  }
@@ -1,3 +1,4 @@
1
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
1
2
  import { getProviderHelper } from "../../lib/getProviderHelper";
2
3
  import type { InferenceTask, Options, RequestArgs } from "../../types";
3
4
  import { innerStreamingRequest } from "../../utils/request";
@@ -16,6 +17,7 @@ export async function* streamingRequest<T>(
16
17
  console.warn(
17
18
  "The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
18
19
  );
19
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
20
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
21
+ const providerHelper = getProviderHelper(provider, options?.task);
20
22
  yield* innerStreamingRequest(args, providerHelper, options);
21
23
  }
@@ -1,4 +1,5 @@
1
1
  import type { ImageClassificationInput, ImageClassificationOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -14,7 +15,8 @@ export async function imageClassification(
14
15
  args: ImageClassificationArgs,
15
16
  options?: Options
16
17
  ): Promise<ImageClassificationOutput> {
17
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
18
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
19
+ const providerHelper = getProviderHelper(provider, "image-classification");
18
20
  const payload = preparePayload(args);
19
21
  const { data: res } = await innerRequest<ImageClassificationOutput>(payload, providerHelper, {
20
22
  ...options,
@@ -1,4 +1,5 @@
1
1
  import type { ImageSegmentationInput, ImageSegmentationOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -14,7 +15,8 @@ export async function imageSegmentation(
14
15
  args: ImageSegmentationArgs,
15
16
  options?: Options
16
17
  ): Promise<ImageSegmentationOutput> {
17
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
18
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
19
+ const providerHelper = getProviderHelper(provider, "image-segmentation");
18
20
  const payload = preparePayload(args);
19
21
  const { data: res } = await innerRequest<ImageSegmentationOutput>(payload, providerHelper, {
20
22
  ...options,
@@ -1,4 +1,5 @@
1
1
  import type { ImageToImageInput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options, RequestArgs } from "../../types";
4
5
  import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -11,7 +12,8 @@ export type ImageToImageArgs = BaseArgs & ImageToImageInput;
11
12
  * Recommended model: lllyasviel/sd-controlnet-depth
12
13
  */
13
14
  export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
14
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
15
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
16
+ const providerHelper = getProviderHelper(provider, "image-to-image");
15
17
  let reqArgs: RequestArgs;
16
18
  if (!args.parameters) {
17
19
  reqArgs = {
@@ -1,4 +1,5 @@
1
1
  import type { ImageToTextInput, ImageToTextOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -10,7 +11,8 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput);
10
11
  * This task reads some image input and outputs the text caption.
11
12
  */
12
13
  export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
13
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
14
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
15
+ const providerHelper = getProviderHelper(provider, "image-to-text");
14
16
  const payload = preparePayload(args);
15
17
  const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, providerHelper, {
16
18
  ...options,
@@ -1,4 +1,5 @@
1
1
  import type { ObjectDetectionInput, ObjectDetectionOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -11,7 +12,8 @@ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImage
11
12
  * Recommended model: facebook/detr-resnet-50
12
13
  */
13
14
  export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> {
14
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
15
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
16
+ const providerHelper = getProviderHelper(provider, "object-detection");
15
17
  const payload = preparePayload(args);
16
18
  const { data: res } = await innerRequest<ObjectDetectionOutput>(payload, providerHelper, {
17
19
  ...options,
@@ -1,4 +1,5 @@
1
1
  import type { TextToImageInput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import { makeRequestOptions } from "../../lib/makeRequestOptions";
4
5
  import type { BaseArgs, Options } from "../../types";
@@ -23,7 +24,7 @@ export async function textToImage(
23
24
  options?: TextToImageOptions & { outputType?: undefined | "blob" }
24
25
  ): Promise<Blob>;
25
26
  export async function textToImage(args: TextToImageArgs, options?: TextToImageOptions): Promise<Blob | string> {
26
- const provider = args.provider ?? "hf-inference";
27
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
27
28
  const providerHelper = getProviderHelper(provider, "text-to-image");
28
29
  const { data: res } = await innerRequest<Record<string, unknown>>(args, providerHelper, {
29
30
  ...options,
@@ -1,4 +1,5 @@
1
1
  import type { TextToVideoInput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import { makeRequestOptions } from "../../lib/makeRequestOptions";
4
5
  import type { FalAiQueueOutput } from "../../providers/fal-ai";
@@ -12,7 +13,7 @@ export type TextToVideoArgs = BaseArgs & TextToVideoInput;
12
13
  export type TextToVideoOutput = Blob;
13
14
 
14
15
  export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise<TextToVideoOutput> {
15
- const provider = args.provider ?? "hf-inference";
16
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
16
17
  const providerHelper = getProviderHelper(provider, "text-to-video");
17
18
  const { data: response } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(
18
19
  args,
@@ -1,4 +1,5 @@
1
1
  import type { ZeroShotImageClassificationInput, ZeroShotImageClassificationOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options, RequestArgs } from "../../types";
4
5
  import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -44,7 +45,8 @@ export async function zeroShotImageClassification(
44
45
  args: ZeroShotImageClassificationArgs,
45
46
  options?: Options
46
47
  ): Promise<ZeroShotImageClassificationOutput> {
47
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
48
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
49
+ const providerHelper = getProviderHelper(provider, "zero-shot-image-classification");
48
50
  const payload = await preparePayload(args);
49
51
  const { data: res } = await innerRequest<ZeroShotImageClassificationOutput>(payload, providerHelper, {
50
52
  ...options,
@@ -3,6 +3,7 @@ import type {
3
3
  DocumentQuestionAnsweringInputData,
4
4
  DocumentQuestionAnsweringOutput,
5
5
  } from "@huggingface/tasks";
6
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
6
7
  import { getProviderHelper } from "../../lib/getProviderHelper";
7
8
  import type { BaseArgs, Options, RequestArgs } from "../../types";
8
9
  import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -19,7 +20,8 @@ export async function documentQuestionAnswering(
19
20
  args: DocumentQuestionAnsweringArgs,
20
21
  options?: Options
21
22
  ): Promise<DocumentQuestionAnsweringOutput[number]> {
22
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "document-question-answering");
23
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
24
+ const providerHelper = getProviderHelper(provider, "document-question-answering");
23
25
  const reqArgs: RequestArgs = {
24
26
  ...args,
25
27
  inputs: {
@@ -3,6 +3,7 @@ import type {
3
3
  VisualQuestionAnsweringInputData,
4
4
  VisualQuestionAnsweringOutput,
5
5
  } from "@huggingface/tasks";
6
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
6
7
  import { getProviderHelper } from "../../lib/getProviderHelper";
7
8
  import type { BaseArgs, Options, RequestArgs } from "../../types";
8
9
  import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -19,7 +20,8 @@ export async function visualQuestionAnswering(
19
20
  args: VisualQuestionAnsweringArgs,
20
21
  options?: Options
21
22
  ): Promise<VisualQuestionAnsweringOutput[number]> {
22
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "visual-question-answering");
23
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
24
+ const providerHelper = getProviderHelper(provider, "visual-question-answering");
23
25
  const reqArgs: RequestArgs = {
24
26
  ...args,
25
27
  inputs: {
@@ -1,4 +1,5 @@
1
1
  import type { ChatCompletionInput, ChatCompletionOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -10,7 +11,8 @@ export async function chatCompletion(
10
11
  args: BaseArgs & ChatCompletionInput,
11
12
  options?: Options
12
13
  ): Promise<ChatCompletionOutput> {
13
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
14
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
15
+ const providerHelper = getProviderHelper(provider, "conversational");
14
16
  const { data: response } = await innerRequest<ChatCompletionOutput>(args, providerHelper, {
15
17
  ...options,
16
18
  task: "conversational",
@@ -1,4 +1,5 @@
1
1
  import type { ChatCompletionInput, ChatCompletionStreamOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerStreamingRequest } from "../../utils/request";
@@ -10,7 +11,8 @@ export async function* chatCompletionStream(
10
11
  args: BaseArgs & ChatCompletionInput,
11
12
  options?: Options
12
13
  ): AsyncGenerator<ChatCompletionStreamOutput> {
13
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
14
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
15
+ const providerHelper = getProviderHelper(provider, "conversational");
14
16
  yield* innerStreamingRequest<ChatCompletionStreamOutput>(args, providerHelper, {
15
17
  ...options,
16
18
  task: "conversational",
@@ -1,4 +1,5 @@
1
1
  import type { FeatureExtractionInput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -22,7 +23,8 @@ export async function featureExtraction(
22
23
  args: FeatureExtractionArgs,
23
24
  options?: Options
24
25
  ): Promise<FeatureExtractionOutput> {
25
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "feature-extraction");
26
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
27
+ const providerHelper = getProviderHelper(provider, "feature-extraction");
26
28
  const { data: res } = await innerRequest<FeatureExtractionOutput>(args, providerHelper, {
27
29
  ...options,
28
30
  task: "feature-extraction",
@@ -1,4 +1,5 @@
1
1
  import type { FillMaskInput, FillMaskOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -9,7 +10,8 @@ export type FillMaskArgs = BaseArgs & FillMaskInput;
9
10
  * Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
10
11
  */
11
12
  export async function fillMask(args: FillMaskArgs, options?: Options): Promise<FillMaskOutput> {
12
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "fill-mask");
13
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
14
+ const providerHelper = getProviderHelper(provider, "fill-mask");
13
15
  const { data: res } = await innerRequest<FillMaskOutput>(args, providerHelper, {
14
16
  ...options,
15
17
  task: "fill-mask",
@@ -1,4 +1,6 @@
1
1
  import type { QuestionAnsweringInput, QuestionAnsweringOutput } from "@huggingface/tasks";
2
+
3
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
4
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
5
  import type { BaseArgs, Options } from "../../types";
4
6
  import { innerRequest } from "../../utils/request";
@@ -12,7 +14,8 @@ export async function questionAnswering(
12
14
  args: QuestionAnsweringArgs,
13
15
  options?: Options
14
16
  ): Promise<QuestionAnsweringOutput[number]> {
15
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering");
17
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
18
+ const providerHelper = getProviderHelper(provider, "question-answering");
16
19
  const { data: res } = await innerRequest<QuestionAnsweringOutput | QuestionAnsweringOutput[number]>(
17
20
  args,
18
21
  providerHelper,
@@ -1,4 +1,5 @@
1
1
  import type { SentenceSimilarityInput, SentenceSimilarityOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function sentenceSimilarity(
12
13
  args: SentenceSimilarityArgs,
13
14
  options?: Options
14
15
  ): Promise<SentenceSimilarityOutput> {
15
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "sentence-similarity");
16
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
17
+ const providerHelper = getProviderHelper(provider, "sentence-similarity");
16
18
  const { data: res } = await innerRequest<SentenceSimilarityOutput>(args, providerHelper, {
17
19
  ...options,
18
20
  task: "sentence-similarity",
@@ -1,4 +1,5 @@
1
1
  import type { SummarizationInput, SummarizationOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -9,7 +10,8 @@ export type SummarizationArgs = BaseArgs & SummarizationInput;
9
10
  * This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.
10
11
  */
11
12
  export async function summarization(args: SummarizationArgs, options?: Options): Promise<SummarizationOutput> {
12
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization");
13
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
14
+ const providerHelper = getProviderHelper(provider, "summarization");
13
15
  const { data: res } = await innerRequest<SummarizationOutput[]>(args, providerHelper, {
14
16
  ...options,
15
17
  task: "summarization",
@@ -1,4 +1,5 @@
1
1
  import type { TableQuestionAnsweringInput, TableQuestionAnsweringOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function tableQuestionAnswering(
12
13
  args: TableQuestionAnsweringArgs,
13
14
  options?: Options
14
15
  ): Promise<TableQuestionAnsweringOutput[number]> {
15
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
16
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
17
+ const providerHelper = getProviderHelper(provider, "table-question-answering");
16
18
  const { data: res } = await innerRequest<TableQuestionAnsweringOutput | TableQuestionAnsweringOutput[number]>(
17
19
  args,
18
20
  providerHelper,
@@ -1,4 +1,5 @@
1
1
  import type { TextClassificationInput, TextClassificationOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function textClassification(
12
13
  args: TextClassificationArgs,
13
14
  options?: Options
14
15
  ): Promise<TextClassificationOutput> {
15
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
16
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
17
+ const providerHelper = getProviderHelper(provider, "text-classification");
16
18
  const { data: res } = await innerRequest<TextClassificationOutput>(args, providerHelper, {
17
19
  ...options,
18
20
  task: "text-classification",
@@ -1,4 +1,5 @@
1
1
  import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { HyperbolicTextCompletionOutput } from "../../providers/hyperbolic";
4
5
  import type { BaseArgs, Options } from "../../types";
@@ -13,7 +14,8 @@ export async function textGeneration(
13
14
  args: BaseArgs & TextGenerationInput,
14
15
  options?: Options
15
16
  ): Promise<TextGenerationOutput> {
16
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
17
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
18
+ const providerHelper = getProviderHelper(provider, "text-generation");
17
19
  const { data: response } = await innerRequest<
18
20
  HyperbolicTextCompletionOutput | TextGenerationOutput | TextGenerationOutput[]
19
21
  >(args, providerHelper, {