@huggingface/inference 3.9.2 → 3.11.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. package/README.md +9 -7
  2. package/dist/index.cjs +771 -646
  3. package/dist/index.js +770 -646
  4. package/dist/src/InferenceClient.d.ts +16 -17
  5. package/dist/src/InferenceClient.d.ts.map +1 -1
  6. package/dist/src/lib/getInferenceProviderMapping.d.ts +6 -2
  7. package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -1
  8. package/dist/src/lib/getProviderHelper.d.ts.map +1 -1
  9. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  10. package/dist/src/providers/consts.d.ts.map +1 -1
  11. package/dist/src/providers/ovhcloud.d.ts +38 -0
  12. package/dist/src/providers/ovhcloud.d.ts.map +1 -0
  13. package/dist/src/providers/providerHelper.d.ts +1 -1
  14. package/dist/src/providers/providerHelper.d.ts.map +1 -1
  15. package/dist/src/snippets/getInferenceSnippets.d.ts +1 -1
  16. package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
  17. package/dist/src/snippets/templates.exported.d.ts.map +1 -1
  18. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
  19. package/dist/src/tasks/audio/audioToAudio.d.ts.map +1 -1
  20. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  21. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  22. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  23. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  24. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
  25. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
  26. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
  27. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
  28. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
  29. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  30. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
  31. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
  32. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  33. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
  34. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  35. package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
  36. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  37. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
  38. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  39. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  40. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
  41. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  42. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
  43. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  44. package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
  45. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  46. package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
  47. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  48. package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
  49. package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
  50. package/dist/src/types.d.ts +7 -5
  51. package/dist/src/types.d.ts.map +1 -1
  52. package/dist/src/utils/typedEntries.d.ts +4 -0
  53. package/dist/src/utils/typedEntries.d.ts.map +1 -0
  54. package/package.json +3 -3
  55. package/src/InferenceClient.ts +32 -43
  56. package/src/lib/getInferenceProviderMapping.ts +68 -19
  57. package/src/lib/getProviderHelper.ts +5 -0
  58. package/src/lib/makeRequestOptions.ts +4 -3
  59. package/src/providers/consts.ts +1 -0
  60. package/src/providers/ovhcloud.ts +75 -0
  61. package/src/providers/providerHelper.ts +1 -1
  62. package/src/snippets/getInferenceSnippets.ts +5 -4
  63. package/src/snippets/templates.exported.ts +7 -3
  64. package/src/tasks/audio/audioClassification.ts +3 -1
  65. package/src/tasks/audio/audioToAudio.ts +4 -1
  66. package/src/tasks/audio/automaticSpeechRecognition.ts +3 -1
  67. package/src/tasks/audio/textToSpeech.ts +2 -1
  68. package/src/tasks/custom/request.ts +3 -1
  69. package/src/tasks/custom/streamingRequest.ts +3 -1
  70. package/src/tasks/cv/imageClassification.ts +3 -1
  71. package/src/tasks/cv/imageSegmentation.ts +3 -1
  72. package/src/tasks/cv/imageToImage.ts +3 -1
  73. package/src/tasks/cv/imageToText.ts +3 -1
  74. package/src/tasks/cv/objectDetection.ts +3 -1
  75. package/src/tasks/cv/textToImage.ts +2 -1
  76. package/src/tasks/cv/textToVideo.ts +2 -1
  77. package/src/tasks/cv/zeroShotImageClassification.ts +3 -1
  78. package/src/tasks/multimodal/documentQuestionAnswering.ts +3 -1
  79. package/src/tasks/multimodal/visualQuestionAnswering.ts +3 -1
  80. package/src/tasks/nlp/chatCompletion.ts +3 -1
  81. package/src/tasks/nlp/chatCompletionStream.ts +3 -1
  82. package/src/tasks/nlp/featureExtraction.ts +3 -1
  83. package/src/tasks/nlp/fillMask.ts +3 -1
  84. package/src/tasks/nlp/questionAnswering.ts +4 -1
  85. package/src/tasks/nlp/sentenceSimilarity.ts +3 -1
  86. package/src/tasks/nlp/summarization.ts +3 -1
  87. package/src/tasks/nlp/tableQuestionAnswering.ts +3 -1
  88. package/src/tasks/nlp/textClassification.ts +3 -1
  89. package/src/tasks/nlp/textGeneration.ts +3 -1
  90. package/src/tasks/nlp/textGenerationStream.ts +3 -1
  91. package/src/tasks/nlp/tokenClassification.ts +3 -1
  92. package/src/tasks/nlp/translation.ts +3 -1
  93. package/src/tasks/nlp/zeroShotClassification.ts +3 -1
  94. package/src/tasks/tabular/tabularClassification.ts +3 -1
  95. package/src/tasks/tabular/tabularRegression.ts +3 -1
  96. package/src/types.ts +9 -4
  97. package/src/utils/typedEntries.ts +5 -0
@@ -1,4 +1,5 @@
1
1
  import type { ObjectDetectionInput, ObjectDetectionOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -11,7 +12,8 @@ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImage
11
12
  * Recommended model: facebook/detr-resnet-50
12
13
  */
13
14
  export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> {
14
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
15
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
16
+ const providerHelper = getProviderHelper(provider, "object-detection");
15
17
  const payload = preparePayload(args);
16
18
  const { data: res } = await innerRequest<ObjectDetectionOutput>(payload, providerHelper, {
17
19
  ...options,
@@ -1,4 +1,5 @@
1
1
  import type { TextToImageInput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import { makeRequestOptions } from "../../lib/makeRequestOptions";
4
5
  import type { BaseArgs, Options } from "../../types";
@@ -23,7 +24,7 @@ export async function textToImage(
23
24
  options?: TextToImageOptions & { outputType?: undefined | "blob" }
24
25
  ): Promise<Blob>;
25
26
  export async function textToImage(args: TextToImageArgs, options?: TextToImageOptions): Promise<Blob | string> {
26
- const provider = args.provider ?? "hf-inference";
27
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
27
28
  const providerHelper = getProviderHelper(provider, "text-to-image");
28
29
  const { data: res } = await innerRequest<Record<string, unknown>>(args, providerHelper, {
29
30
  ...options,
@@ -1,4 +1,5 @@
1
1
  import type { TextToVideoInput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import { makeRequestOptions } from "../../lib/makeRequestOptions";
4
5
  import type { FalAiQueueOutput } from "../../providers/fal-ai";
@@ -12,7 +13,7 @@ export type TextToVideoArgs = BaseArgs & TextToVideoInput;
12
13
  export type TextToVideoOutput = Blob;
13
14
 
14
15
  export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise<TextToVideoOutput> {
15
- const provider = args.provider ?? "hf-inference";
16
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
16
17
  const providerHelper = getProviderHelper(provider, "text-to-video");
17
18
  const { data: response } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(
18
19
  args,
@@ -1,4 +1,5 @@
1
1
  import type { ZeroShotImageClassificationInput, ZeroShotImageClassificationOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options, RequestArgs } from "../../types";
4
5
  import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -44,7 +45,8 @@ export async function zeroShotImageClassification(
44
45
  args: ZeroShotImageClassificationArgs,
45
46
  options?: Options
46
47
  ): Promise<ZeroShotImageClassificationOutput> {
47
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
48
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
49
+ const providerHelper = getProviderHelper(provider, "zero-shot-image-classification");
48
50
  const payload = await preparePayload(args);
49
51
  const { data: res } = await innerRequest<ZeroShotImageClassificationOutput>(payload, providerHelper, {
50
52
  ...options,
@@ -3,6 +3,7 @@ import type {
3
3
  DocumentQuestionAnsweringInputData,
4
4
  DocumentQuestionAnsweringOutput,
5
5
  } from "@huggingface/tasks";
6
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
6
7
  import { getProviderHelper } from "../../lib/getProviderHelper";
7
8
  import type { BaseArgs, Options, RequestArgs } from "../../types";
8
9
  import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -19,7 +20,8 @@ export async function documentQuestionAnswering(
19
20
  args: DocumentQuestionAnsweringArgs,
20
21
  options?: Options
21
22
  ): Promise<DocumentQuestionAnsweringOutput[number]> {
22
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "document-question-answering");
23
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
24
+ const providerHelper = getProviderHelper(provider, "document-question-answering");
23
25
  const reqArgs: RequestArgs = {
24
26
  ...args,
25
27
  inputs: {
@@ -3,6 +3,7 @@ import type {
3
3
  VisualQuestionAnsweringInputData,
4
4
  VisualQuestionAnsweringOutput,
5
5
  } from "@huggingface/tasks";
6
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
6
7
  import { getProviderHelper } from "../../lib/getProviderHelper";
7
8
  import type { BaseArgs, Options, RequestArgs } from "../../types";
8
9
  import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -19,7 +20,8 @@ export async function visualQuestionAnswering(
19
20
  args: VisualQuestionAnsweringArgs,
20
21
  options?: Options
21
22
  ): Promise<VisualQuestionAnsweringOutput[number]> {
22
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "visual-question-answering");
23
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
24
+ const providerHelper = getProviderHelper(provider, "visual-question-answering");
23
25
  const reqArgs: RequestArgs = {
24
26
  ...args,
25
27
  inputs: {
@@ -1,4 +1,5 @@
1
1
  import type { ChatCompletionInput, ChatCompletionOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -10,7 +11,8 @@ export async function chatCompletion(
10
11
  args: BaseArgs & ChatCompletionInput,
11
12
  options?: Options
12
13
  ): Promise<ChatCompletionOutput> {
13
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
14
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
15
+ const providerHelper = getProviderHelper(provider, "conversational");
14
16
  const { data: response } = await innerRequest<ChatCompletionOutput>(args, providerHelper, {
15
17
  ...options,
16
18
  task: "conversational",
@@ -1,4 +1,5 @@
1
1
  import type { ChatCompletionInput, ChatCompletionStreamOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerStreamingRequest } from "../../utils/request";
@@ -10,7 +11,8 @@ export async function* chatCompletionStream(
10
11
  args: BaseArgs & ChatCompletionInput,
11
12
  options?: Options
12
13
  ): AsyncGenerator<ChatCompletionStreamOutput> {
13
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
14
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
15
+ const providerHelper = getProviderHelper(provider, "conversational");
14
16
  yield* innerStreamingRequest<ChatCompletionStreamOutput>(args, providerHelper, {
15
17
  ...options,
16
18
  task: "conversational",
@@ -1,4 +1,5 @@
1
1
  import type { FeatureExtractionInput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -22,7 +23,8 @@ export async function featureExtraction(
22
23
  args: FeatureExtractionArgs,
23
24
  options?: Options
24
25
  ): Promise<FeatureExtractionOutput> {
25
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "feature-extraction");
26
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
27
+ const providerHelper = getProviderHelper(provider, "feature-extraction");
26
28
  const { data: res } = await innerRequest<FeatureExtractionOutput>(args, providerHelper, {
27
29
  ...options,
28
30
  task: "feature-extraction",
@@ -1,4 +1,5 @@
1
1
  import type { FillMaskInput, FillMaskOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -9,7 +10,8 @@ export type FillMaskArgs = BaseArgs & FillMaskInput;
9
10
  * Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
10
11
  */
11
12
  export async function fillMask(args: FillMaskArgs, options?: Options): Promise<FillMaskOutput> {
12
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "fill-mask");
13
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
14
+ const providerHelper = getProviderHelper(provider, "fill-mask");
13
15
  const { data: res } = await innerRequest<FillMaskOutput>(args, providerHelper, {
14
16
  ...options,
15
17
  task: "fill-mask",
@@ -1,4 +1,6 @@
1
1
  import type { QuestionAnsweringInput, QuestionAnsweringOutput } from "@huggingface/tasks";
2
+
3
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
4
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
5
  import type { BaseArgs, Options } from "../../types";
4
6
  import { innerRequest } from "../../utils/request";
@@ -12,7 +14,8 @@ export async function questionAnswering(
12
14
  args: QuestionAnsweringArgs,
13
15
  options?: Options
14
16
  ): Promise<QuestionAnsweringOutput[number]> {
15
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering");
17
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
18
+ const providerHelper = getProviderHelper(provider, "question-answering");
16
19
  const { data: res } = await innerRequest<QuestionAnsweringOutput | QuestionAnsweringOutput[number]>(
17
20
  args,
18
21
  providerHelper,
@@ -1,4 +1,5 @@
1
1
  import type { SentenceSimilarityInput, SentenceSimilarityOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function sentenceSimilarity(
12
13
  args: SentenceSimilarityArgs,
13
14
  options?: Options
14
15
  ): Promise<SentenceSimilarityOutput> {
15
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "sentence-similarity");
16
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
17
+ const providerHelper = getProviderHelper(provider, "sentence-similarity");
16
18
  const { data: res } = await innerRequest<SentenceSimilarityOutput>(args, providerHelper, {
17
19
  ...options,
18
20
  task: "sentence-similarity",
@@ -1,4 +1,5 @@
1
1
  import type { SummarizationInput, SummarizationOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -9,7 +10,8 @@ export type SummarizationArgs = BaseArgs & SummarizationInput;
9
10
  * This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.
10
11
  */
11
12
  export async function summarization(args: SummarizationArgs, options?: Options): Promise<SummarizationOutput> {
12
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization");
13
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
14
+ const providerHelper = getProviderHelper(provider, "summarization");
13
15
  const { data: res } = await innerRequest<SummarizationOutput[]>(args, providerHelper, {
14
16
  ...options,
15
17
  task: "summarization",
@@ -1,4 +1,5 @@
1
1
  import type { TableQuestionAnsweringInput, TableQuestionAnsweringOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function tableQuestionAnswering(
12
13
  args: TableQuestionAnsweringArgs,
13
14
  options?: Options
14
15
  ): Promise<TableQuestionAnsweringOutput[number]> {
15
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
16
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
17
+ const providerHelper = getProviderHelper(provider, "table-question-answering");
16
18
  const { data: res } = await innerRequest<TableQuestionAnsweringOutput | TableQuestionAnsweringOutput[number]>(
17
19
  args,
18
20
  providerHelper,
@@ -1,4 +1,5 @@
1
1
  import type { TextClassificationInput, TextClassificationOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function textClassification(
12
13
  args: TextClassificationArgs,
13
14
  options?: Options
14
15
  ): Promise<TextClassificationOutput> {
15
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
16
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
17
+ const providerHelper = getProviderHelper(provider, "text-classification");
16
18
  const { data: res } = await innerRequest<TextClassificationOutput>(args, providerHelper, {
17
19
  ...options,
18
20
  task: "text-classification",
@@ -1,4 +1,5 @@
1
1
  import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { HyperbolicTextCompletionOutput } from "../../providers/hyperbolic";
4
5
  import type { BaseArgs, Options } from "../../types";
@@ -13,7 +14,8 @@ export async function textGeneration(
13
14
  args: BaseArgs & TextGenerationInput,
14
15
  options?: Options
15
16
  ): Promise<TextGenerationOutput> {
16
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
17
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
18
+ const providerHelper = getProviderHelper(provider, "text-generation");
17
19
  const { data: response } = await innerRequest<
18
20
  HyperbolicTextCompletionOutput | TextGenerationOutput | TextGenerationOutput[]
19
21
  >(args, providerHelper, {
@@ -1,4 +1,5 @@
1
1
  import type { TextGenerationInput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerStreamingRequest } from "../../utils/request";
@@ -90,7 +91,8 @@ export async function* textGenerationStream(
90
91
  args: BaseArgs & TextGenerationInput,
91
92
  options?: Options
92
93
  ): AsyncGenerator<TextGenerationStreamOutput> {
93
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
94
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
95
+ const providerHelper = getProviderHelper(provider, "text-generation");
94
96
  yield* innerStreamingRequest<TextGenerationStreamOutput>(args, providerHelper, {
95
97
  ...options,
96
98
  task: "text-generation",
@@ -1,4 +1,5 @@
1
1
  import type { TokenClassificationInput, TokenClassificationOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function tokenClassification(
12
13
  args: TokenClassificationArgs,
13
14
  options?: Options
14
15
  ): Promise<TokenClassificationOutput> {
15
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "token-classification");
16
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
17
+ const providerHelper = getProviderHelper(provider, "token-classification");
16
18
  const { data: res } = await innerRequest<TokenClassificationOutput[number] | TokenClassificationOutput>(
17
19
  args,
18
20
  providerHelper,
@@ -1,4 +1,5 @@
1
1
  import type { TranslationInput, TranslationOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -8,7 +9,8 @@ export type TranslationArgs = BaseArgs & TranslationInput;
8
9
  * This task is well known to translate text from one language to another. Recommended model: Helsinki-NLP/opus-mt-ru-en.
9
10
  */
10
11
  export async function translation(args: TranslationArgs, options?: Options): Promise<TranslationOutput> {
11
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation");
12
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
13
+ const providerHelper = getProviderHelper(provider, "translation");
12
14
  const { data: res } = await innerRequest<TranslationOutput>(args, providerHelper, {
13
15
  ...options,
14
16
  task: "translation",
@@ -1,4 +1,5 @@
1
1
  import type { ZeroShotClassificationInput, ZeroShotClassificationOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function zeroShotClassification(
12
13
  args: ZeroShotClassificationArgs,
13
14
  options?: Options
14
15
  ): Promise<ZeroShotClassificationOutput> {
15
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification");
16
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
17
+ const providerHelper = getProviderHelper(provider, "zero-shot-classification");
16
18
  const { data: res } = await innerRequest<ZeroShotClassificationOutput[number] | ZeroShotClassificationOutput>(
17
19
  args,
18
20
  providerHelper,
@@ -1,3 +1,4 @@
1
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
1
2
  import { getProviderHelper } from "../../lib/getProviderHelper";
2
3
  import type { BaseArgs, Options } from "../../types";
3
4
  import { innerRequest } from "../../utils/request";
@@ -25,7 +26,8 @@ export async function tabularClassification(
25
26
  args: TabularClassificationArgs,
26
27
  options?: Options
27
28
  ): Promise<TabularClassificationOutput> {
28
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification");
29
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
30
+ const providerHelper = getProviderHelper(provider, "tabular-classification");
29
31
  const { data: res } = await innerRequest<TabularClassificationOutput>(args, providerHelper, {
30
32
  ...options,
31
33
  task: "tabular-classification",
@@ -1,3 +1,4 @@
1
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
1
2
  import { getProviderHelper } from "../../lib/getProviderHelper";
2
3
  import type { BaseArgs, Options } from "../../types";
3
4
  import { innerRequest } from "../../utils/request";
@@ -25,7 +26,8 @@ export async function tabularRegression(
25
26
  args: TabularRegressionArgs,
26
27
  options?: Options
27
28
  ): Promise<TabularRegressionOutput> {
28
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression");
29
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
30
+ const providerHelper = getProviderHelper(provider, "tabular-regression");
29
31
  const { data: res } = await innerRequest<TabularRegressionOutput>(args, providerHelper, {
30
32
  ...options,
31
33
  task: "tabular-regression",
package/src/types.ts CHANGED
@@ -51,13 +51,18 @@ export const INFERENCE_PROVIDERS = [
51
51
  "novita",
52
52
  "nscale",
53
53
  "openai",
54
+ "ovhcloud",
54
55
  "replicate",
55
56
  "sambanova",
56
57
  "together",
57
58
  ] as const;
58
59
 
60
+ export const PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"] as const;
61
+
59
62
  export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
60
63
 
64
+ export type InferenceProviderOrPolicy = (typeof PROVIDERS_OR_POLICIES)[number];
65
+
61
66
  export interface BaseArgs {
62
67
  /**
63
68
  * The access token to use. Without it, you'll get rate-limited quickly.
@@ -79,18 +84,18 @@ export interface BaseArgs {
79
84
  model?: ModelId;
80
85
 
81
86
  /**
82
- * The URL of the endpoint to use. If not specified, will call huggingface.co/api/tasks to get the default endpoint for the task.
87
+ * The URL of the endpoint to use.
83
88
  *
84
- * If specified, will use this URL instead of the default one.
89
+ * If not specified, will call the default router.huggingface.co Inference Providers endpoint.
85
90
  */
86
91
  endpointUrl?: string;
87
92
 
88
93
  /**
89
94
  * Set an Inference provider to run this model on.
90
95
  *
91
- * Defaults to the first provider in your user settings that is compatible with this model.
96
+ * Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
92
97
  */
93
- provider?: InferenceProvider;
98
+ provider?: InferenceProviderOrPolicy;
94
99
  }
95
100
 
96
101
  export type RequestArgs = BaseArgs &
@@ -0,0 +1,5 @@
1
+ export function typedEntries<T extends { [s: string]: T[keyof T] } | ArrayLike<T[keyof T]>>(
2
+ obj: T
3
+ ): [keyof T, T[keyof T]][] {
4
+ return Object.entries(obj) as [keyof T, T[keyof T]][];
5
+ }