@huggingface/inference 3.7.0 → 3.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. package/dist/index.cjs +1369 -941
  2. package/dist/index.js +1371 -943
  3. package/dist/src/lib/getInferenceProviderMapping.d.ts +21 -0
  4. package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -0
  5. package/dist/src/lib/getProviderHelper.d.ts +37 -0
  6. package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
  7. package/dist/src/lib/makeRequestOptions.d.ts +5 -5
  8. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  9. package/dist/src/providers/black-forest-labs.d.ts +14 -18
  10. package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
  11. package/dist/src/providers/cerebras.d.ts +4 -2
  12. package/dist/src/providers/cerebras.d.ts.map +1 -1
  13. package/dist/src/providers/cohere.d.ts +5 -2
  14. package/dist/src/providers/cohere.d.ts.map +1 -1
  15. package/dist/src/providers/consts.d.ts +2 -3
  16. package/dist/src/providers/consts.d.ts.map +1 -1
  17. package/dist/src/providers/fal-ai.d.ts +50 -3
  18. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  19. package/dist/src/providers/fireworks-ai.d.ts +5 -2
  20. package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
  21. package/dist/src/providers/hf-inference.d.ts +126 -2
  22. package/dist/src/providers/hf-inference.d.ts.map +1 -1
  23. package/dist/src/providers/hyperbolic.d.ts +31 -2
  24. package/dist/src/providers/hyperbolic.d.ts.map +1 -1
  25. package/dist/src/providers/nebius.d.ts +20 -18
  26. package/dist/src/providers/nebius.d.ts.map +1 -1
  27. package/dist/src/providers/novita.d.ts +21 -18
  28. package/dist/src/providers/novita.d.ts.map +1 -1
  29. package/dist/src/providers/openai.d.ts +4 -2
  30. package/dist/src/providers/openai.d.ts.map +1 -1
  31. package/dist/src/providers/providerHelper.d.ts +182 -0
  32. package/dist/src/providers/providerHelper.d.ts.map +1 -0
  33. package/dist/src/providers/replicate.d.ts +23 -19
  34. package/dist/src/providers/replicate.d.ts.map +1 -1
  35. package/dist/src/providers/sambanova.d.ts +4 -2
  36. package/dist/src/providers/sambanova.d.ts.map +1 -1
  37. package/dist/src/providers/together.d.ts +32 -2
  38. package/dist/src/providers/together.d.ts.map +1 -1
  39. package/dist/src/snippets/getInferenceSnippets.d.ts +2 -1
  40. package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
  41. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
  42. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  43. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  44. package/dist/src/tasks/audio/utils.d.ts +2 -1
  45. package/dist/src/tasks/audio/utils.d.ts.map +1 -1
  46. package/dist/src/tasks/custom/request.d.ts +0 -2
  47. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  48. package/dist/src/tasks/custom/streamingRequest.d.ts +0 -2
  49. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  50. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
  51. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
  52. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
  53. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
  54. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
  55. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  56. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
  57. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
  58. package/dist/src/tasks/index.d.ts +6 -6
  59. package/dist/src/tasks/index.d.ts.map +1 -1
  60. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  61. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
  62. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  63. package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
  64. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  65. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
  66. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  67. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  68. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
  69. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  70. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
  71. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  72. package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
  73. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  74. package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
  75. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  76. package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
  77. package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
  78. package/dist/src/types.d.ts +5 -13
  79. package/dist/src/types.d.ts.map +1 -1
  80. package/dist/src/utils/request.d.ts +3 -2
  81. package/dist/src/utils/request.d.ts.map +1 -1
  82. package/package.json +3 -3
  83. package/src/lib/getInferenceProviderMapping.ts +96 -0
  84. package/src/lib/getProviderHelper.ts +270 -0
  85. package/src/lib/makeRequestOptions.ts +78 -97
  86. package/src/providers/black-forest-labs.ts +73 -22
  87. package/src/providers/cerebras.ts +6 -27
  88. package/src/providers/cohere.ts +9 -28
  89. package/src/providers/consts.ts +5 -2
  90. package/src/providers/fal-ai.ts +224 -77
  91. package/src/providers/fireworks-ai.ts +8 -29
  92. package/src/providers/hf-inference.ts +557 -34
  93. package/src/providers/hyperbolic.ts +107 -29
  94. package/src/providers/nebius.ts +65 -29
  95. package/src/providers/novita.ts +68 -32
  96. package/src/providers/openai.ts +6 -32
  97. package/src/providers/providerHelper.ts +354 -0
  98. package/src/providers/replicate.ts +124 -34
  99. package/src/providers/sambanova.ts +5 -30
  100. package/src/providers/together.ts +92 -28
  101. package/src/snippets/getInferenceSnippets.ts +39 -14
  102. package/src/snippets/templates.exported.ts +25 -25
  103. package/src/tasks/audio/audioClassification.ts +5 -8
  104. package/src/tasks/audio/audioToAudio.ts +4 -27
  105. package/src/tasks/audio/automaticSpeechRecognition.ts +5 -4
  106. package/src/tasks/audio/textToSpeech.ts +5 -29
  107. package/src/tasks/audio/utils.ts +2 -1
  108. package/src/tasks/custom/request.ts +3 -3
  109. package/src/tasks/custom/streamingRequest.ts +4 -3
  110. package/src/tasks/cv/imageClassification.ts +4 -8
  111. package/src/tasks/cv/imageSegmentation.ts +4 -9
  112. package/src/tasks/cv/imageToImage.ts +4 -7
  113. package/src/tasks/cv/imageToText.ts +4 -7
  114. package/src/tasks/cv/objectDetection.ts +4 -19
  115. package/src/tasks/cv/textToImage.ts +9 -137
  116. package/src/tasks/cv/textToVideo.ts +17 -64
  117. package/src/tasks/cv/zeroShotImageClassification.ts +4 -8
  118. package/src/tasks/index.ts +6 -6
  119. package/src/tasks/multimodal/documentQuestionAnswering.ts +4 -19
  120. package/src/tasks/multimodal/visualQuestionAnswering.ts +4 -12
  121. package/src/tasks/nlp/chatCompletion.ts +5 -20
  122. package/src/tasks/nlp/chatCompletionStream.ts +4 -3
  123. package/src/tasks/nlp/featureExtraction.ts +4 -19
  124. package/src/tasks/nlp/fillMask.ts +4 -17
  125. package/src/tasks/nlp/questionAnswering.ts +11 -26
  126. package/src/tasks/nlp/sentenceSimilarity.ts +4 -8
  127. package/src/tasks/nlp/summarization.ts +4 -7
  128. package/src/tasks/nlp/tableQuestionAnswering.ts +10 -30
  129. package/src/tasks/nlp/textClassification.ts +4 -9
  130. package/src/tasks/nlp/textGeneration.ts +11 -79
  131. package/src/tasks/nlp/textGenerationStream.ts +3 -1
  132. package/src/tasks/nlp/tokenClassification.ts +11 -23
  133. package/src/tasks/nlp/translation.ts +4 -7
  134. package/src/tasks/nlp/zeroShotClassification.ts +11 -21
  135. package/src/tasks/tabular/tabularClassification.ts +4 -7
  136. package/src/tasks/tabular/tabularRegression.ts +4 -7
  137. package/src/types.ts +5 -14
  138. package/src/utils/request.ts +7 -4
  139. package/dist/src/lib/getProviderModelId.d.ts +0 -10
  140. package/dist/src/lib/getProviderModelId.d.ts.map +0 -1
  141. package/src/lib/getProviderModelId.ts +0 -74
@@ -1,5 +1,5 @@
1
1
  import type { TableQuestionAnsweringInput, TableQuestionAnsweringOutput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
3
  import type { BaseArgs, Options } from "../../types";
4
4
  import { innerRequest } from "../../utils/request";
5
5
 
@@ -12,34 +12,14 @@ export async function tableQuestionAnswering(
12
12
  args: TableQuestionAnsweringArgs,
13
13
  options?: Options
14
14
  ): Promise<TableQuestionAnsweringOutput[number]> {
15
- const { data: res } = await innerRequest<TableQuestionAnsweringOutput | TableQuestionAnsweringOutput[number]>(args, {
16
- ...options,
17
- task: "table-question-answering",
18
- });
19
- const isValidOutput = Array.isArray(res) ? res.every((elem) => validate(elem)) : validate(res);
20
- if (!isValidOutput) {
21
- throw new InferenceOutputError(
22
- "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
23
- );
24
- }
25
- return Array.isArray(res) ? res[0] : res;
26
- }
27
-
28
- function validate(elem: unknown): elem is TableQuestionAnsweringOutput[number] {
29
- return (
30
- typeof elem === "object" &&
31
- !!elem &&
32
- "aggregator" in elem &&
33
- typeof elem.aggregator === "string" &&
34
- "answer" in elem &&
35
- typeof elem.answer === "string" &&
36
- "cells" in elem &&
37
- Array.isArray(elem.cells) &&
38
- elem.cells.every((x: unknown): x is string => typeof x === "string") &&
39
- "coordinates" in elem &&
40
- Array.isArray(elem.coordinates) &&
41
- elem.coordinates.every(
42
- (coord: unknown): coord is number[] => Array.isArray(coord) && coord.every((x) => typeof x === "number")
43
- )
15
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
16
+ const { data: res } = await innerRequest<TableQuestionAnsweringOutput | TableQuestionAnsweringOutput[number]>(
17
+ args,
18
+ providerHelper,
19
+ {
20
+ ...options,
21
+ task: "table-question-answering",
22
+ }
44
23
  );
24
+ return providerHelper.getResponse(res);
45
25
  }
@@ -1,5 +1,5 @@
1
1
  import type { TextClassificationInput, TextClassificationOutput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
3
  import type { BaseArgs, Options } from "../../types";
4
4
  import { innerRequest } from "../../utils/request";
5
5
 
@@ -12,15 +12,10 @@ export async function textClassification(
12
12
  args: TextClassificationArgs,
13
13
  options?: Options
14
14
  ): Promise<TextClassificationOutput> {
15
- const { data: res } = await innerRequest<TextClassificationOutput>(args, {
15
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
16
+ const { data: res } = await innerRequest<TextClassificationOutput>(args, providerHelper, {
16
17
  ...options,
17
18
  task: "text-classification",
18
19
  });
19
- const output = res?.[0];
20
- const isValidOutput =
21
- Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number");
22
- if (!isValidOutput) {
23
- throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
24
- }
25
- return output;
20
+ return providerHelper.getResponse(res);
26
21
  }
@@ -1,33 +1,11 @@
1
- import type {
2
- ChatCompletionOutput,
3
- TextGenerationInput,
4
- TextGenerationOutput,
5
- TextGenerationOutputFinishReason,
6
- } from "@huggingface/tasks";
7
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
1
+ import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
+ import type { HyperbolicTextCompletionOutput } from "../../providers/hyperbolic";
8
4
  import type { BaseArgs, Options } from "../../types";
9
- import { omit } from "../../utils/omit";
10
5
  import { innerRequest } from "../../utils/request";
11
- import { toArray } from "../../utils/toArray";
12
6
 
13
7
  export type { TextGenerationInput, TextGenerationOutput };
14
8
 
15
- interface TogeteherTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
16
- choices: Array<{
17
- text: string;
18
- finish_reason: TextGenerationOutputFinishReason;
19
- seed: number;
20
- logprobs: unknown;
21
- index: number;
22
- }>;
23
- }
24
-
25
- interface HyperbolicTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
26
- choices: Array<{
27
- message: { content: string };
28
- }>;
29
- }
30
-
31
9
  /**
32
10
  * Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
33
11
  */
@@ -35,58 +13,12 @@ export async function textGeneration(
35
13
  args: BaseArgs & TextGenerationInput,
36
14
  options?: Options
37
15
  ): Promise<TextGenerationOutput> {
38
- if (args.provider === "together") {
39
- args.prompt = args.inputs;
40
- const { data: raw } = await innerRequest<TogeteherTextCompletionOutput>(args, {
41
- ...options,
42
- task: "text-generation",
43
- });
44
- const isValidOutput =
45
- typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
46
- if (!isValidOutput) {
47
- throw new InferenceOutputError("Expected ChatCompletionOutput");
48
- }
49
- const completion = raw.choices[0];
50
- return {
51
- generated_text: completion.text,
52
- };
53
- } else if (args.provider === "hyperbolic") {
54
- const payload = {
55
- messages: [{ content: args.inputs, role: "user" }],
56
- ...(args.parameters
57
- ? {
58
- max_tokens: args.parameters.max_new_tokens,
59
- ...omit(args.parameters, "max_new_tokens"),
60
- }
61
- : undefined),
62
- ...omit(args, ["inputs", "parameters"]),
63
- };
64
- const raw = (
65
- await innerRequest<HyperbolicTextCompletionOutput>(payload, {
66
- ...options,
67
- task: "text-generation",
68
- })
69
- ).data;
70
- const isValidOutput =
71
- typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
72
- if (!isValidOutput) {
73
- throw new InferenceOutputError("Expected ChatCompletionOutput");
74
- }
75
- const completion = raw.choices[0];
76
- return {
77
- generated_text: completion.message.content,
78
- };
79
- } else {
80
- const { data: res } = await innerRequest<TextGenerationOutput | TextGenerationOutput[]>(args, {
81
- ...options,
82
- task: "text-generation",
83
- });
84
- const output = toArray(res);
85
- const isValidOutput =
86
- Array.isArray(output) && output.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
87
- if (!isValidOutput) {
88
- throw new InferenceOutputError("Expected Array<{generated_text: string}>");
89
- }
90
- return (output as TextGenerationOutput[])?.[0];
91
- }
16
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
17
+ const { data: response } = await innerRequest<
18
+ HyperbolicTextCompletionOutput | TextGenerationOutput | TextGenerationOutput[]
19
+ >(args, providerHelper, {
20
+ ...options,
21
+ task: "text-generation",
22
+ });
23
+ return providerHelper.getResponse(response);
92
24
  }
@@ -1,4 +1,5 @@
1
1
  import type { TextGenerationInput } from "@huggingface/tasks";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
2
3
  import type { BaseArgs, Options } from "../../types";
3
4
  import { innerStreamingRequest } from "../../utils/request";
4
5
 
@@ -89,7 +90,8 @@ export async function* textGenerationStream(
89
90
  args: BaseArgs & TextGenerationInput,
90
91
  options?: Options
91
92
  ): AsyncGenerator<TextGenerationStreamOutput> {
92
- yield* innerStreamingRequest<TextGenerationStreamOutput>(args, {
93
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
94
+ yield* innerStreamingRequest<TextGenerationStreamOutput>(args, providerHelper, {
93
95
  ...options,
94
96
  task: "text-generation",
95
97
  });
@@ -1,8 +1,7 @@
1
1
  import type { TokenClassificationInput, TokenClassificationOutput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
3
  import type { BaseArgs, Options } from "../../types";
4
4
  import { innerRequest } from "../../utils/request";
5
- import { toArray } from "../../utils/toArray";
6
5
 
7
6
  export type TokenClassificationArgs = BaseArgs & TokenClassificationInput;
8
7
 
@@ -13,25 +12,14 @@ export async function tokenClassification(
13
12
  args: TokenClassificationArgs,
14
13
  options?: Options
15
14
  ): Promise<TokenClassificationOutput> {
16
- const { data: res } = await innerRequest<TokenClassificationOutput[number] | TokenClassificationOutput>(args, {
17
- ...options,
18
- task: "token-classification",
19
- });
20
- const output = toArray(res);
21
- const isValidOutput =
22
- Array.isArray(output) &&
23
- output.every(
24
- (x) =>
25
- typeof x.end === "number" &&
26
- typeof x.entity_group === "string" &&
27
- typeof x.score === "number" &&
28
- typeof x.start === "number" &&
29
- typeof x.word === "string"
30
- );
31
- if (!isValidOutput) {
32
- throw new InferenceOutputError(
33
- "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
34
- );
35
- }
36
- return output;
15
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "token-classification");
16
+ const { data: res } = await innerRequest<TokenClassificationOutput[number] | TokenClassificationOutput>(
17
+ args,
18
+ providerHelper,
19
+ {
20
+ ...options,
21
+ task: "token-classification",
22
+ }
23
+ );
24
+ return providerHelper.getResponse(res);
37
25
  }
@@ -1,5 +1,5 @@
1
1
  import type { TranslationInput, TranslationOutput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
3
  import type { BaseArgs, Options } from "../../types";
4
4
  import { innerRequest } from "../../utils/request";
5
5
 
@@ -8,13 +8,10 @@ export type TranslationArgs = BaseArgs & TranslationInput;
8
8
  * This task is well known to translate text from one language to another. Recommended model: Helsinki-NLP/opus-mt-ru-en.
9
9
  */
10
10
  export async function translation(args: TranslationArgs, options?: Options): Promise<TranslationOutput> {
11
- const { data: res } = await innerRequest<TranslationOutput>(args, {
11
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation");
12
+ const { data: res } = await innerRequest<TranslationOutput>(args, providerHelper, {
12
13
  ...options,
13
14
  task: "translation",
14
15
  });
15
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string");
16
- if (!isValidOutput) {
17
- throw new InferenceOutputError("Expected type Array<{translation_text: string}>");
18
- }
19
- return res?.length === 1 ? res?.[0] : res;
16
+ return providerHelper.getResponse(res);
20
17
  }
@@ -1,8 +1,7 @@
1
1
  import type { ZeroShotClassificationInput, ZeroShotClassificationOutput } from "@huggingface/tasks";
2
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
+ import { getProviderHelper } from "../../lib/getProviderHelper";
3
3
  import type { BaseArgs, Options } from "../../types";
4
4
  import { innerRequest } from "../../utils/request";
5
- import { toArray } from "../../utils/toArray";
6
5
 
7
6
  export type ZeroShotClassificationArgs = BaseArgs & ZeroShotClassificationInput;
8
7
 
@@ -13,23 +12,14 @@ export async function zeroShotClassification(
13
12
  args: ZeroShotClassificationArgs,
14
13
  options?: Options
15
14
  ): Promise<ZeroShotClassificationOutput> {
16
- const { data: res } = await innerRequest<ZeroShotClassificationOutput[number] | ZeroShotClassificationOutput>(args, {
17
- ...options,
18
- task: "zero-shot-classification",
19
- });
20
- const output = toArray(res);
21
- const isValidOutput =
22
- Array.isArray(output) &&
23
- output.every(
24
- (x) =>
25
- Array.isArray(x.labels) &&
26
- x.labels.every((_label) => typeof _label === "string") &&
27
- Array.isArray(x.scores) &&
28
- x.scores.every((_score) => typeof _score === "number") &&
29
- typeof x.sequence === "string"
30
- );
31
- if (!isValidOutput) {
32
- throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
33
- }
34
- return output;
15
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification");
16
+ const { data: res } = await innerRequest<ZeroShotClassificationOutput[number] | ZeroShotClassificationOutput>(
17
+ args,
18
+ providerHelper,
19
+ {
20
+ ...options,
21
+ task: "zero-shot-classification",
22
+ }
23
+ );
24
+ return providerHelper.getResponse(res);
35
25
  }
@@ -1,4 +1,4 @@
1
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
1
+ import { getProviderHelper } from "../../lib/getProviderHelper";
2
2
  import type { BaseArgs, Options } from "../../types";
3
3
  import { innerRequest } from "../../utils/request";
4
4
 
@@ -25,13 +25,10 @@ export async function tabularClassification(
25
25
  args: TabularClassificationArgs,
26
26
  options?: Options
27
27
  ): Promise<TabularClassificationOutput> {
28
- const { data: res } = await innerRequest<TabularClassificationOutput>(args, {
28
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification");
29
+ const { data: res } = await innerRequest<TabularClassificationOutput>(args, providerHelper, {
29
30
  ...options,
30
31
  task: "tabular-classification",
31
32
  });
32
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
33
- if (!isValidOutput) {
34
- throw new InferenceOutputError("Expected number[]");
35
- }
36
- return res;
33
+ return providerHelper.getResponse(res);
37
34
  }
@@ -1,4 +1,4 @@
1
- import { InferenceOutputError } from "../../lib/InferenceOutputError";
1
+ import { getProviderHelper } from "../../lib/getProviderHelper";
2
2
  import type { BaseArgs, Options } from "../../types";
3
3
  import { innerRequest } from "../../utils/request";
4
4
 
@@ -25,13 +25,10 @@ export async function tabularRegression(
25
25
  args: TabularRegressionArgs,
26
26
  options?: Options
27
27
  ): Promise<TabularRegressionOutput> {
28
- const { data: res } = await innerRequest<TabularRegressionOutput>(args, {
28
+ const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression");
29
+ const { data: res } = await innerRequest<TabularRegressionOutput>(args, providerHelper, {
29
30
  ...options,
30
31
  task: "tabular-regression",
31
32
  });
32
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
33
- if (!isValidOutput) {
34
- throw new InferenceOutputError("Expected number[]");
35
- }
36
- return res;
33
+ return providerHelper.getResponse(res);
37
34
  }
package/src/types.ts CHANGED
@@ -1,4 +1,5 @@
1
1
  import type { ChatCompletionInput, PipelineType } from "@huggingface/tasks";
2
+ import type { InferenceProviderModelMapping } from "./lib/getInferenceProviderMapping";
2
3
 
3
4
  /**
4
5
  * HF model id, like "meta-llama/Llama-3.3-70B-Instruct"
@@ -34,7 +35,7 @@ export interface Options {
34
35
  billTo?: string;
35
36
  }
36
37
 
37
- export type InferenceTask = Exclude<PipelineType, "other">;
38
+ export type InferenceTask = Exclude<PipelineType, "other"> | "conversational";
38
39
 
39
40
  export const INFERENCE_PROVIDERS = [
40
41
  "black-forest-labs",
@@ -101,14 +102,6 @@ export type RequestArgs = BaseArgs &
101
102
  parameters?: Record<string, unknown>;
102
103
  };
103
104
 
104
- export interface ProviderConfig {
105
- makeBaseUrl: ((task?: InferenceTask) => string) | (() => string);
106
- makeBody: (params: BodyParams) => Record<string, unknown>;
107
- makeHeaders: (params: HeaderParams) => Record<string, string>;
108
- makeUrl: (params: UrlParams) => string;
109
- clientSideRoutingOnly?: boolean;
110
- }
111
-
112
105
  export type AuthMethod = "none" | "hf-token" | "credentials-include" | "provider-key";
113
106
 
114
107
  export interface HeaderParams {
@@ -118,15 +111,13 @@ export interface HeaderParams {
118
111
 
119
112
  export interface UrlParams {
120
113
  authMethod: AuthMethod;
121
- baseUrl: string;
122
114
  model: string;
123
115
  task?: InferenceTask;
124
- chatCompletion?: boolean;
125
116
  }
126
117
 
127
- export interface BodyParams {
128
- args: Record<string, unknown>;
129
- chatCompletion?: boolean;
118
+ export interface BodyParams<T extends Record<string, unknown> = Record<string, unknown>> {
119
+ args: T;
130
120
  model: string;
121
+ mapping?: InferenceProviderModelMapping | undefined;
131
122
  task?: InferenceTask;
132
123
  }
@@ -1,3 +1,4 @@
1
+ import type { getProviderHelper } from "../lib/getProviderHelper";
1
2
  import { makeRequestOptions } from "../lib/makeRequestOptions";
2
3
  import type { InferenceTask, Options, RequestArgs } from "../types";
3
4
  import type { EventSourceMessage } from "../vendor/fetch-event-source/parse";
@@ -16,6 +17,7 @@ export interface ResponseWrapper<T> {
16
17
  */
17
18
  export async function innerRequest<T>(
18
19
  args: RequestArgs,
20
+ providerHelper: ReturnType<typeof getProviderHelper>,
19
21
  options?: Options & {
20
22
  /** In most cases (unless we pass a endpointUrl) we know the task */
21
23
  task?: InferenceTask;
@@ -23,13 +25,13 @@ export async function innerRequest<T>(
23
25
  chatCompletion?: boolean;
24
26
  }
25
27
  ): Promise<ResponseWrapper<T>> {
26
- const { url, info } = await makeRequestOptions(args, options);
28
+ const { url, info } = await makeRequestOptions(args, providerHelper, options);
27
29
  const response = await (options?.fetch ?? fetch)(url, info);
28
30
 
29
31
  const requestContext: ResponseWrapper<T>["requestContext"] = { url, info };
30
32
 
31
33
  if (options?.retry_on_error !== false && response.status === 503) {
32
- return innerRequest(args, options);
34
+ return innerRequest(args, providerHelper, options);
33
35
  }
34
36
 
35
37
  if (!response.ok) {
@@ -65,6 +67,7 @@ export async function innerRequest<T>(
65
67
  */
66
68
  export async function* innerStreamingRequest<T>(
67
69
  args: RequestArgs,
70
+ providerHelper: ReturnType<typeof getProviderHelper>,
68
71
  options?: Options & {
69
72
  /** In most cases (unless we pass a endpointUrl) we know the task */
70
73
  task?: InferenceTask;
@@ -72,11 +75,11 @@ export async function* innerStreamingRequest<T>(
72
75
  chatCompletion?: boolean;
73
76
  }
74
77
  ): AsyncGenerator<T> {
75
- const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
78
+ const { url, info } = await makeRequestOptions({ ...args, stream: true }, providerHelper, options);
76
79
  const response = await (options?.fetch ?? fetch)(url, info);
77
80
 
78
81
  if (options?.retry_on_error !== false && response.status === 503) {
79
- return yield* innerStreamingRequest(args, options);
82
+ return yield* innerStreamingRequest(args, providerHelper, options);
80
83
  }
81
84
  if (!response.ok) {
82
85
  if (response.headers.get("Content-Type")?.startsWith("application/json")) {
@@ -1,10 +0,0 @@
1
- import type { InferenceProvider, InferenceTask, Options, RequestArgs } from "../types";
2
- export declare function getProviderModelId(params: {
3
- model: string;
4
- provider: InferenceProvider;
5
- }, args: RequestArgs, options?: {
6
- task?: InferenceTask;
7
- chatCompletion?: boolean;
8
- fetch?: Options["fetch"];
9
- }): Promise<string>;
10
- //# sourceMappingURL=getProviderModelId.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"getProviderModelId.d.ts","sourceRoot":"","sources":["../../../src/lib/getProviderModelId.ts"],"names":[],"mappings":"AACA,OAAO,KAAK,EAAE,iBAAiB,EAAE,aAAa,EAAW,OAAO,EAAE,WAAW,EAAE,MAAM,UAAU,CAAC;AAShG,wBAAsB,kBAAkB,CACvC,MAAM,EAAE;IACP,KAAK,EAAE,MAAM,CAAC;IACd,QAAQ,EAAE,iBAAiB,CAAC;CAC5B,EACD,IAAI,EAAE,WAAW,EACjB,OAAO,GAAE;IACR,IAAI,CAAC,EAAE,aAAa,CAAC;IACrB,cAAc,CAAC,EAAE,OAAO,CAAC;IACzB,KAAK,CAAC,EAAE,OAAO,CAAC,OAAO,CAAC,CAAC;CACpB,GACJ,OAAO,CAAC,MAAM,CAAC,CAoDjB"}
@@ -1,74 +0,0 @@
1
- import type { WidgetType } from "@huggingface/tasks";
2
- import type { InferenceProvider, InferenceTask, ModelId, Options, RequestArgs } from "../types";
3
- import { HF_HUB_URL } from "../config";
4
- import { HARDCODED_MODEL_ID_MAPPING } from "../providers/consts";
5
-
6
- type InferenceProviderMapping = Partial<
7
- Record<InferenceProvider, { providerId: string; status: "live" | "staging"; task: WidgetType }>
8
- >;
9
- const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
10
-
11
- export async function getProviderModelId(
12
- params: {
13
- model: string;
14
- provider: InferenceProvider;
15
- },
16
- args: RequestArgs,
17
- options: {
18
- task?: InferenceTask;
19
- chatCompletion?: boolean;
20
- fetch?: Options["fetch"];
21
- } = {}
22
- ): Promise<string> {
23
- if (params.provider === "hf-inference") {
24
- return params.model;
25
- }
26
- if (!options.task) {
27
- throw new Error("task must be specified when using a third-party provider");
28
- }
29
- const task: WidgetType =
30
- options.task === "text-generation" && options.chatCompletion ? "conversational" : options.task;
31
-
32
- // A dict called HARDCODED_MODEL_ID_MAPPING takes precedence in all cases (useful for dev purposes)
33
- if (HARDCODED_MODEL_ID_MAPPING[params.provider]?.[params.model]) {
34
- return HARDCODED_MODEL_ID_MAPPING[params.provider][params.model];
35
- }
36
-
37
- let inferenceProviderMapping: InferenceProviderMapping | null;
38
- if (inferenceProviderMappingCache.has(params.model)) {
39
- // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
40
- inferenceProviderMapping = inferenceProviderMappingCache.get(params.model)!;
41
- } else {
42
- inferenceProviderMapping = await (options?.fetch ?? fetch)(
43
- `${HF_HUB_URL}/api/models/${params.model}?expand[]=inferenceProviderMapping`,
44
- {
45
- headers: args.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${args.accessToken}` } : {},
46
- }
47
- )
48
- .then((resp) => resp.json())
49
- .then((json) => json.inferenceProviderMapping)
50
- .catch(() => null);
51
- }
52
-
53
- if (!inferenceProviderMapping) {
54
- throw new Error(`We have not been able to find inference provider information for model ${params.model}.`);
55
- }
56
-
57
- const providerMapping = inferenceProviderMapping[params.provider];
58
- if (providerMapping) {
59
- if (providerMapping.task !== task) {
60
- throw new Error(
61
- `Model ${params.model} is not supported for task ${task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
62
- );
63
- }
64
- if (providerMapping.status === "staging") {
65
- console.warn(
66
- `Model ${params.model} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
67
- );
68
- }
69
- // TODO: how is it handled server-side if model has multiple tasks (e.g. `text-generation` + `conversational`)?
70
- return providerMapping.providerId;
71
- }
72
-
73
- throw new Error(`Model ${params.model} is not supported provider ${params.provider}.`);
74
- }