@huggingface/inference 3.3.6 → 3.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. package/dist/index.cjs +315 -174
  2. package/dist/index.js +315 -174
  3. package/dist/src/lib/getProviderModelId.d.ts +1 -1
  4. package/dist/src/lib/getProviderModelId.d.ts.map +1 -1
  5. package/dist/src/lib/makeRequestOptions.d.ts +2 -2
  6. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  7. package/dist/src/providers/black-forest-labs.d.ts +2 -1
  8. package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
  9. package/dist/src/providers/fal-ai.d.ts +2 -1
  10. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  11. package/dist/src/providers/fireworks-ai.d.ts +2 -1
  12. package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
  13. package/dist/src/providers/hf-inference.d.ts +3 -0
  14. package/dist/src/providers/hf-inference.d.ts.map +1 -0
  15. package/dist/src/providers/hyperbolic.d.ts +2 -1
  16. package/dist/src/providers/hyperbolic.d.ts.map +1 -1
  17. package/dist/src/providers/nebius.d.ts +2 -1
  18. package/dist/src/providers/nebius.d.ts.map +1 -1
  19. package/dist/src/providers/novita.d.ts +2 -1
  20. package/dist/src/providers/novita.d.ts.map +1 -1
  21. package/dist/src/providers/replicate.d.ts +3 -1
  22. package/dist/src/providers/replicate.d.ts.map +1 -1
  23. package/dist/src/providers/sambanova.d.ts +2 -1
  24. package/dist/src/providers/sambanova.d.ts.map +1 -1
  25. package/dist/src/providers/together.d.ts +2 -1
  26. package/dist/src/providers/together.d.ts.map +1 -1
  27. package/dist/src/tasks/custom/request.d.ts +2 -4
  28. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  29. package/dist/src/tasks/custom/streamingRequest.d.ts +2 -4
  30. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  31. package/dist/src/tasks/nlp/featureExtraction.d.ts +2 -9
  32. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  33. package/dist/src/types.d.ts +24 -3
  34. package/dist/src/types.d.ts.map +1 -1
  35. package/package.json +2 -2
  36. package/src/lib/getProviderModelId.ts +4 -4
  37. package/src/lib/makeRequestOptions.ts +72 -186
  38. package/src/providers/black-forest-labs.ts +26 -2
  39. package/src/providers/consts.ts +1 -1
  40. package/src/providers/fal-ai.ts +24 -2
  41. package/src/providers/fireworks-ai.ts +28 -2
  42. package/src/providers/hf-inference.ts +43 -0
  43. package/src/providers/hyperbolic.ts +28 -2
  44. package/src/providers/nebius.ts +34 -2
  45. package/src/providers/novita.ts +31 -2
  46. package/src/providers/replicate.ts +30 -2
  47. package/src/providers/sambanova.ts +28 -2
  48. package/src/providers/together.ts +34 -2
  49. package/src/tasks/audio/audioClassification.ts +1 -1
  50. package/src/tasks/audio/audioToAudio.ts +1 -1
  51. package/src/tasks/audio/automaticSpeechRecognition.ts +1 -1
  52. package/src/tasks/audio/textToSpeech.ts +1 -1
  53. package/src/tasks/custom/request.ts +2 -4
  54. package/src/tasks/custom/streamingRequest.ts +2 -4
  55. package/src/tasks/cv/imageClassification.ts +1 -1
  56. package/src/tasks/cv/imageSegmentation.ts +1 -1
  57. package/src/tasks/cv/imageToImage.ts +1 -1
  58. package/src/tasks/cv/imageToText.ts +1 -1
  59. package/src/tasks/cv/objectDetection.ts +1 -1
  60. package/src/tasks/cv/textToImage.ts +1 -1
  61. package/src/tasks/cv/textToVideo.ts +1 -1
  62. package/src/tasks/cv/zeroShotImageClassification.ts +1 -1
  63. package/src/tasks/multimodal/documentQuestionAnswering.ts +1 -1
  64. package/src/tasks/multimodal/visualQuestionAnswering.ts +1 -1
  65. package/src/tasks/nlp/chatCompletion.ts +1 -1
  66. package/src/tasks/nlp/chatCompletionStream.ts +1 -1
  67. package/src/tasks/nlp/featureExtraction.ts +3 -10
  68. package/src/tasks/nlp/fillMask.ts +1 -1
  69. package/src/tasks/nlp/questionAnswering.ts +1 -1
  70. package/src/tasks/nlp/sentenceSimilarity.ts +1 -1
  71. package/src/tasks/nlp/summarization.ts +1 -1
  72. package/src/tasks/nlp/tableQuestionAnswering.ts +1 -1
  73. package/src/tasks/nlp/textClassification.ts +1 -1
  74. package/src/tasks/nlp/textGeneration.ts +3 -3
  75. package/src/tasks/nlp/textGenerationStream.ts +1 -1
  76. package/src/tasks/nlp/tokenClassification.ts +1 -1
  77. package/src/tasks/nlp/translation.ts +1 -1
  78. package/src/tasks/nlp/zeroShotClassification.ts +1 -1
  79. package/src/tasks/tabular/tabularClassification.ts +1 -1
  80. package/src/tasks/tabular/tabularRegression.ts +1 -1
  81. package/src/types.ts +28 -2
@@ -1,5 +1,3 @@
1
- export const TOGETHER_API_BASE_URL = "https://api.together.xyz";
2
-
3
1
  /**
4
2
  * See the registered mapping of HF model ID => Together model ID here:
5
3
  *
@@ -16,3 +14,37 @@ export const TOGETHER_API_BASE_URL = "https://api.together.xyz";
16
14
  *
17
15
  * Thanks!
18
16
  */
17
+ import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18
+
19
+ const TOGETHER_API_BASE_URL = "https://api.together.xyz";
20
+
21
+ const makeBody = (params: BodyParams): Record<string, unknown> => {
22
+ return {
23
+ ...params.args,
24
+ model: params.model,
25
+ };
26
+ };
27
+
28
+ const makeHeaders = (params: HeaderParams): Record<string, string> => {
29
+ return { Authorization: `Bearer ${params.accessToken}` };
30
+ };
31
+
32
+ const makeUrl = (params: UrlParams): string => {
33
+ if (params.task === "text-to-image") {
34
+ return `${params.baseUrl}/v1/images/generations`;
35
+ }
36
+ if (params.task === "text-generation") {
37
+ if (params.chatCompletion) {
38
+ return `${params.baseUrl}/v1/chat/completions`;
39
+ }
40
+ return `${params.baseUrl}/v1/completions`;
41
+ }
42
+ return params.baseUrl;
43
+ };
44
+
45
+ export const TOGETHER_CONFIG: ProviderConfig = {
46
+ baseUrl: TOGETHER_API_BASE_URL,
47
+ makeBody,
48
+ makeHeaders,
49
+ makeUrl,
50
+ };
@@ -18,7 +18,7 @@ export async function audioClassification(
18
18
  const payload = preparePayload(args);
19
19
  const res = await request<AudioClassificationOutput>(payload, {
20
20
  ...options,
21
- taskHint: "audio-classification",
21
+ task: "audio-classification",
22
22
  });
23
23
  const isValidOutput =
24
24
  Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
@@ -39,7 +39,7 @@ export async function audioToAudio(args: AudioToAudioArgs, options?: Options): P
39
39
  const payload = preparePayload(args);
40
40
  const res = await request<AudioToAudioOutput>(payload, {
41
41
  ...options,
42
- taskHint: "audio-to-audio",
42
+ task: "audio-to-audio",
43
43
  });
44
44
 
45
45
  return validateOutput(res);
@@ -19,7 +19,7 @@ export async function automaticSpeechRecognition(
19
19
  const payload = await buildPayload(args);
20
20
  const res = await request<AutomaticSpeechRecognitionOutput>(payload, {
21
21
  ...options,
22
- taskHint: "automatic-speech-recognition",
22
+ task: "automatic-speech-recognition",
23
23
  });
24
24
  const isValidOutput = typeof res?.text === "string";
25
25
  if (!isValidOutput) {
@@ -24,7 +24,7 @@ export async function textToSpeech(args: TextToSpeechArgs, options?: Options): P
24
24
  : args;
25
25
  const res = await request<Blob | OutputUrlTextToSpeechGeneration>(payload, {
26
26
  ...options,
27
- taskHint: "text-to-speech",
27
+ task: "text-to-speech",
28
28
  });
29
29
  if (res instanceof Blob) {
30
30
  return res;
@@ -7,10 +7,8 @@ import { makeRequestOptions } from "../../lib/makeRequestOptions";
7
7
  export async function request<T>(
8
8
  args: RequestArgs,
9
9
  options?: Options & {
10
- /** When a model can be used for multiple tasks, and we want to run a non-default task */
11
- task?: string | InferenceTask;
12
- /** To load default model if needed */
13
- taskHint?: InferenceTask;
10
+ /** In most cases (unless we pass a endpointUrl) we know the task */
11
+ task?: InferenceTask;
14
12
  /** Is chat completion compatible */
15
13
  chatCompletion?: boolean;
16
14
  }
@@ -9,10 +9,8 @@ import { getLines, getMessages } from "../../vendor/fetch-event-source/parse";
9
9
  export async function* streamingRequest<T>(
10
10
  args: RequestArgs,
11
11
  options?: Options & {
12
- /** When a model can be used for multiple tasks, and we want to run a non-default task */
13
- task?: string | InferenceTask;
14
- /** To load default model if needed */
15
- taskHint?: InferenceTask;
12
+ /** In most cases (unless we pass a endpointUrl) we know the task */
13
+ task?: InferenceTask;
16
14
  /** Is chat completion compatible */
17
15
  chatCompletion?: boolean;
18
16
  }
@@ -17,7 +17,7 @@ export async function imageClassification(
17
17
  const payload = preparePayload(args);
18
18
  const res = await request<ImageClassificationOutput>(payload, {
19
19
  ...options,
20
- taskHint: "image-classification",
20
+ task: "image-classification",
21
21
  });
22
22
  const isValidOutput =
23
23
  Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
@@ -17,7 +17,7 @@ export async function imageSegmentation(
17
17
  const payload = preparePayload(args);
18
18
  const res = await request<ImageSegmentationOutput>(payload, {
19
19
  ...options,
20
- taskHint: "image-segmentation",
20
+ task: "image-segmentation",
21
21
  });
22
22
  const isValidOutput =
23
23
  Array.isArray(res) &&
@@ -28,7 +28,7 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P
28
28
  }
29
29
  const res = await request<Blob>(reqArgs, {
30
30
  ...options,
31
- taskHint: "image-to-image",
31
+ task: "image-to-image",
32
32
  });
33
33
  const isValidOutput = res && res instanceof Blob;
34
34
  if (!isValidOutput) {
@@ -14,7 +14,7 @@ export async function imageToText(args: ImageToTextArgs, options?: Options): Pro
14
14
  const res = (
15
15
  await request<[ImageToTextOutput]>(payload, {
16
16
  ...options,
17
- taskHint: "image-to-text",
17
+ task: "image-to-text",
18
18
  })
19
19
  )?.[0];
20
20
 
@@ -14,7 +14,7 @@ export async function objectDetection(args: ObjectDetectionArgs, options?: Optio
14
14
  const payload = preparePayload(args);
15
15
  const res = await request<ObjectDetectionOutput>(payload, {
16
16
  ...options,
17
- taskHint: "object-detection",
17
+ task: "object-detection",
18
18
  });
19
19
  const isValidOutput =
20
20
  Array.isArray(res) &&
@@ -73,7 +73,7 @@ export async function textToImage(args: TextToImageArgs, options?: TextToImageOp
73
73
  | HyperbolicTextToImageOutput
74
74
  >(payload, {
75
75
  ...options,
76
- taskHint: "text-to-image",
76
+ task: "text-to-image",
77
77
  });
78
78
 
79
79
  if (res && typeof res === "object") {
@@ -35,7 +35,7 @@ export async function textToVideo(args: TextToVideoArgs, options?: Options): Pro
35
35
  : args;
36
36
  const res = await request<FalAiOutput | ReplicateOutput>(payload, {
37
37
  ...options,
38
- taskHint: "text-to-video",
38
+ task: "text-to-video",
39
39
  });
40
40
 
41
41
  if (args.provider === "fal-ai") {
@@ -48,7 +48,7 @@ export async function zeroShotImageClassification(
48
48
  const payload = await preparePayload(args);
49
49
  const res = await request<ZeroShotImageClassificationOutput>(payload, {
50
50
  ...options,
51
- taskHint: "zero-shot-image-classification",
51
+ task: "zero-shot-image-classification",
52
52
  });
53
53
  const isValidOutput =
54
54
  Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
@@ -32,7 +32,7 @@ export async function documentQuestionAnswering(
32
32
  const res = toArray(
33
33
  await request<DocumentQuestionAnsweringOutput | DocumentQuestionAnsweringOutput[number]>(reqArgs, {
34
34
  ...options,
35
- taskHint: "document-question-answering",
35
+ task: "document-question-answering",
36
36
  })
37
37
  );
38
38
 
@@ -29,7 +29,7 @@ export async function visualQuestionAnswering(
29
29
  } as RequestArgs;
30
30
  const res = await request<VisualQuestionAnsweringOutput>(reqArgs, {
31
31
  ...options,
32
- taskHint: "visual-question-answering",
32
+ task: "visual-question-answering",
33
33
  });
34
34
  const isValidOutput =
35
35
  Array.isArray(res) &&
@@ -12,7 +12,7 @@ export async function chatCompletion(
12
12
  ): Promise<ChatCompletionOutput> {
13
13
  const res = await request<ChatCompletionOutput>(args, {
14
14
  ...options,
15
- taskHint: "text-generation",
15
+ task: "text-generation",
16
16
  chatCompletion: true,
17
17
  });
18
18
 
@@ -11,7 +11,7 @@ export async function* chatCompletionStream(
11
11
  ): AsyncGenerator<ChatCompletionStreamOutput> {
12
12
  yield* streamingRequest<ChatCompletionStreamOutput>(args, {
13
13
  ...options,
14
- taskHint: "text-generation",
14
+ task: "text-generation",
15
15
  chatCompletion: true,
16
16
  });
17
17
  }
@@ -1,16 +1,9 @@
1
+ import type { FeatureExtractionInput } from "@huggingface/tasks";
1
2
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
2
3
  import type { BaseArgs, Options } from "../../types";
3
4
  import { request } from "../custom/request";
4
5
 
5
- export type FeatureExtractionArgs = BaseArgs & {
6
- /**
7
- * The inputs is a string or a list of strings to get the features from.
8
- *
9
- * inputs: "That is a happy person",
10
- *
11
- */
12
- inputs: string | string[];
13
- };
6
+ export type FeatureExtractionArgs = BaseArgs & FeatureExtractionInput;
14
7
 
15
8
  /**
16
9
  * Returned values are a multidimensional array of floats (dimension depending on if you sent a string or a list of string, and if the automatic reduction, usually mean_pooling for instance was applied for you or not. This should be explained on the model's README).
@@ -26,7 +19,7 @@ export async function featureExtraction(
26
19
  ): Promise<FeatureExtractionOutput> {
27
20
  const res = await request<FeatureExtractionOutput>(args, {
28
21
  ...options,
29
- taskHint: "feature-extraction",
22
+ task: "feature-extraction",
30
23
  });
31
24
  let isValidOutput = true;
32
25
 
@@ -11,7 +11,7 @@ export type FillMaskArgs = BaseArgs & FillMaskInput;
11
11
  export async function fillMask(args: FillMaskArgs, options?: Options): Promise<FillMaskOutput> {
12
12
  const res = await request<FillMaskOutput>(args, {
13
13
  ...options,
14
- taskHint: "fill-mask",
14
+ task: "fill-mask",
15
15
  });
16
16
  const isValidOutput =
17
17
  Array.isArray(res) &&
@@ -14,7 +14,7 @@ export async function questionAnswering(
14
14
  ): Promise<QuestionAnsweringOutput[number]> {
15
15
  const res = await request<QuestionAnsweringOutput | QuestionAnsweringOutput[number]>(args, {
16
16
  ...options,
17
- taskHint: "question-answering",
17
+ task: "question-answering",
18
18
  });
19
19
  const isValidOutput = Array.isArray(res)
20
20
  ? res.every(
@@ -15,7 +15,7 @@ export async function sentenceSimilarity(
15
15
  ): Promise<SentenceSimilarityOutput> {
16
16
  const res = await request<SentenceSimilarityOutput>(prepareInput(args), {
17
17
  ...options,
18
- taskHint: "sentence-similarity",
18
+ task: "sentence-similarity",
19
19
  });
20
20
 
21
21
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
@@ -11,7 +11,7 @@ export type SummarizationArgs = BaseArgs & SummarizationInput;
11
11
  export async function summarization(args: SummarizationArgs, options?: Options): Promise<SummarizationOutput> {
12
12
  const res = await request<SummarizationOutput[]>(args, {
13
13
  ...options,
14
- taskHint: "summarization",
14
+ task: "summarization",
15
15
  });
16
16
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string");
17
17
  if (!isValidOutput) {
@@ -14,7 +14,7 @@ export async function tableQuestionAnswering(
14
14
  ): Promise<TableQuestionAnsweringOutput[number]> {
15
15
  const res = await request<TableQuestionAnsweringOutput | TableQuestionAnsweringOutput[number]>(args, {
16
16
  ...options,
17
- taskHint: "table-question-answering",
17
+ task: "table-question-answering",
18
18
  });
19
19
  const isValidOutput = Array.isArray(res) ? res.every((elem) => validate(elem)) : validate(res);
20
20
  if (!isValidOutput) {
@@ -15,7 +15,7 @@ export async function textClassification(
15
15
  const res = (
16
16
  await request<TextClassificationOutput>(args, {
17
17
  ...options,
18
- taskHint: "text-classification",
18
+ task: "text-classification",
19
19
  })
20
20
  )?.[0];
21
21
  const isValidOutput =
@@ -39,7 +39,7 @@ export async function textGeneration(
39
39
  args.prompt = args.inputs;
40
40
  const raw = await request<TogeteherTextCompletionOutput>(args, {
41
41
  ...options,
42
- taskHint: "text-generation",
42
+ task: "text-generation",
43
43
  });
44
44
  const isValidOutput =
45
45
  typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
@@ -63,7 +63,7 @@ export async function textGeneration(
63
63
  };
64
64
  const raw = await request<HyperbolicTextCompletionOutput>(payload, {
65
65
  ...options,
66
- taskHint: "text-generation",
66
+ task: "text-generation",
67
67
  });
68
68
  const isValidOutput =
69
69
  typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
@@ -78,7 +78,7 @@ export async function textGeneration(
78
78
  const res = toArray(
79
79
  await request<TextGenerationOutput | TextGenerationOutput[]>(args, {
80
80
  ...options,
81
- taskHint: "text-generation",
81
+ task: "text-generation",
82
82
  })
83
83
  );
84
84
 
@@ -91,6 +91,6 @@ export async function* textGenerationStream(
91
91
  ): AsyncGenerator<TextGenerationStreamOutput> {
92
92
  yield* streamingRequest<TextGenerationStreamOutput>(args, {
93
93
  ...options,
94
- taskHint: "text-generation",
94
+ task: "text-generation",
95
95
  });
96
96
  }
@@ -16,7 +16,7 @@ export async function tokenClassification(
16
16
  const res = toArray(
17
17
  await request<TokenClassificationOutput[number] | TokenClassificationOutput>(args, {
18
18
  ...options,
19
- taskHint: "token-classification",
19
+ task: "token-classification",
20
20
  })
21
21
  );
22
22
  const isValidOutput =
@@ -10,7 +10,7 @@ export type TranslationArgs = BaseArgs & TranslationInput;
10
10
  export async function translation(args: TranslationArgs, options?: Options): Promise<TranslationOutput> {
11
11
  const res = await request<TranslationOutput>(args, {
12
12
  ...options,
13
- taskHint: "translation",
13
+ task: "translation",
14
14
  });
15
15
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string");
16
16
  if (!isValidOutput) {
@@ -16,7 +16,7 @@ export async function zeroShotClassification(
16
16
  const res = toArray(
17
17
  await request<ZeroShotClassificationOutput[number] | ZeroShotClassificationOutput>(args, {
18
18
  ...options,
19
- taskHint: "zero-shot-classification",
19
+ task: "zero-shot-classification",
20
20
  })
21
21
  );
22
22
  const isValidOutput =
@@ -27,7 +27,7 @@ export async function tabularClassification(
27
27
  ): Promise<TabularClassificationOutput> {
28
28
  const res = await request<TabularClassificationOutput>(args, {
29
29
  ...options,
30
- taskHint: "tabular-classification",
30
+ task: "tabular-classification",
31
31
  });
32
32
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
33
33
  if (!isValidOutput) {
@@ -27,7 +27,7 @@ export async function tabularRegression(
27
27
  ): Promise<TabularRegressionOutput> {
28
28
  const res = await request<TabularRegressionOutput>(args, {
29
29
  ...options,
30
- taskHint: "tabular-regression",
30
+ task: "tabular-regression",
31
31
  });
32
32
  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
33
33
  if (!isValidOutput) {
package/src/types.ts CHANGED
@@ -1,4 +1,4 @@
1
- import type { ChatCompletionInput, PipelineType } from "@huggingface/tasks";
1
+ import type { ChatCompletionInput, FeatureExtractionInput, PipelineType } from "@huggingface/tasks";
2
2
 
3
3
  /**
4
4
  * HF model id, like "meta-llama/Llama-3.3-70B-Instruct"
@@ -86,7 +86,33 @@ export type RequestArgs = BaseArgs &
86
86
  | { text: string }
87
87
  | { audio_url: string }
88
88
  | ChatCompletionInput
89
+ | FeatureExtractionInput
89
90
  ) & {
90
91
  parameters?: Record<string, unknown>;
91
- accessToken?: string;
92
92
  };
93
+
94
+ export interface ProviderConfig {
95
+ baseUrl: string;
96
+ makeBody: (params: BodyParams) => Record<string, unknown>;
97
+ makeHeaders: (params: HeaderParams) => Record<string, string>;
98
+ makeUrl: (params: UrlParams) => string;
99
+ }
100
+
101
+ export interface HeaderParams {
102
+ accessToken?: string;
103
+ authMethod: "none" | "hf-token" | "credentials-include" | "provider-key";
104
+ }
105
+
106
+ export interface UrlParams {
107
+ baseUrl: string;
108
+ model: string;
109
+ task?: InferenceTask;
110
+ chatCompletion?: boolean;
111
+ }
112
+
113
+ export interface BodyParams {
114
+ args: Record<string, unknown>;
115
+ chatCompletion?: boolean;
116
+ model: string;
117
+ task?: InferenceTask;
118
+ }