@huggingface/inference 3.9.2 → 3.11.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. package/README.md +9 -7
  2. package/dist/index.cjs +771 -646
  3. package/dist/index.js +770 -646
  4. package/dist/src/InferenceClient.d.ts +16 -17
  5. package/dist/src/InferenceClient.d.ts.map +1 -1
  6. package/dist/src/lib/getInferenceProviderMapping.d.ts +6 -2
  7. package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -1
  8. package/dist/src/lib/getProviderHelper.d.ts.map +1 -1
  9. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  10. package/dist/src/providers/consts.d.ts.map +1 -1
  11. package/dist/src/providers/ovhcloud.d.ts +38 -0
  12. package/dist/src/providers/ovhcloud.d.ts.map +1 -0
  13. package/dist/src/providers/providerHelper.d.ts +1 -1
  14. package/dist/src/providers/providerHelper.d.ts.map +1 -1
  15. package/dist/src/snippets/getInferenceSnippets.d.ts +1 -1
  16. package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
  17. package/dist/src/snippets/templates.exported.d.ts.map +1 -1
  18. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
  19. package/dist/src/tasks/audio/audioToAudio.d.ts.map +1 -1
  20. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  21. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  22. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  23. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  24. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
  25. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
  26. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
  27. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
  28. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
  29. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  30. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
  31. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
  32. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  33. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
  34. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  35. package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
  36. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  37. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
  38. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  39. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  40. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
  41. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  42. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
  43. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  44. package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
  45. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  46. package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
  47. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  48. package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
  49. package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
  50. package/dist/src/types.d.ts +7 -5
  51. package/dist/src/types.d.ts.map +1 -1
  52. package/dist/src/utils/typedEntries.d.ts +4 -0
  53. package/dist/src/utils/typedEntries.d.ts.map +1 -0
  54. package/package.json +3 -3
  55. package/src/InferenceClient.ts +32 -43
  56. package/src/lib/getInferenceProviderMapping.ts +68 -19
  57. package/src/lib/getProviderHelper.ts +5 -0
  58. package/src/lib/makeRequestOptions.ts +4 -3
  59. package/src/providers/consts.ts +1 -0
  60. package/src/providers/ovhcloud.ts +75 -0
  61. package/src/providers/providerHelper.ts +1 -1
  62. package/src/snippets/getInferenceSnippets.ts +5 -4
  63. package/src/snippets/templates.exported.ts +7 -3
  64. package/src/tasks/audio/audioClassification.ts +3 -1
  65. package/src/tasks/audio/audioToAudio.ts +4 -1
  66. package/src/tasks/audio/automaticSpeechRecognition.ts +3 -1
  67. package/src/tasks/audio/textToSpeech.ts +2 -1
  68. package/src/tasks/custom/request.ts +3 -1
  69. package/src/tasks/custom/streamingRequest.ts +3 -1
  70. package/src/tasks/cv/imageClassification.ts +3 -1
  71. package/src/tasks/cv/imageSegmentation.ts +3 -1
  72. package/src/tasks/cv/imageToImage.ts +3 -1
  73. package/src/tasks/cv/imageToText.ts +3 -1
  74. package/src/tasks/cv/objectDetection.ts +3 -1
  75. package/src/tasks/cv/textToImage.ts +2 -1
  76. package/src/tasks/cv/textToVideo.ts +2 -1
  77. package/src/tasks/cv/zeroShotImageClassification.ts +3 -1
  78. package/src/tasks/multimodal/documentQuestionAnswering.ts +3 -1
  79. package/src/tasks/multimodal/visualQuestionAnswering.ts +3 -1
  80. package/src/tasks/nlp/chatCompletion.ts +3 -1
  81. package/src/tasks/nlp/chatCompletionStream.ts +3 -1
  82. package/src/tasks/nlp/featureExtraction.ts +3 -1
  83. package/src/tasks/nlp/fillMask.ts +3 -1
  84. package/src/tasks/nlp/questionAnswering.ts +4 -1
  85. package/src/tasks/nlp/sentenceSimilarity.ts +3 -1
  86. package/src/tasks/nlp/summarization.ts +3 -1
  87. package/src/tasks/nlp/tableQuestionAnswering.ts +3 -1
  88. package/src/tasks/nlp/textClassification.ts +3 -1
  89. package/src/tasks/nlp/textGeneration.ts +3 -1
  90. package/src/tasks/nlp/textGenerationStream.ts +3 -1
  91. package/src/tasks/nlp/tokenClassification.ts +3 -1
  92. package/src/tasks/nlp/translation.ts +3 -1
  93. package/src/tasks/nlp/zeroShotClassification.ts +3 -1
  94. package/src/tasks/tabular/tabularClassification.ts +3 -1
  95. package/src/tasks/tabular/tabularRegression.ts +3 -1
  96. package/src/types.ts +9 -4
  97. package/src/utils/typedEntries.ts +5 -0
@@ -1,73 +1,62 @@
1
1
  import * as tasks from "./tasks";
2
- import type { Options, RequestArgs } from "./types";
3
- import type { DistributiveOmit } from "./utils/distributive-omit";
2
+ import type { Options } from "./types";
3
+ import { omit } from "./utils/omit";
4
+ import { typedEntries } from "./utils/typedEntries";
4
5
 
5
6
  /* eslint-disable @typescript-eslint/no-empty-interface */
6
7
  /* eslint-disable @typescript-eslint/no-unsafe-declaration-merging */
7
8
 
8
9
  type Task = typeof tasks;
9
10
 
10
- type TaskWithNoAccessToken = {
11
- [key in keyof Task]: (
12
- args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken">,
13
- options?: Parameters<Task[key]>[1]
14
- ) => ReturnType<Task[key]>;
15
- };
16
-
17
- type TaskWithNoAccessTokenNoEndpointUrl = {
18
- [key in keyof Task]: (
19
- args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "endpointUrl">,
20
- options?: Parameters<Task[key]>[1]
21
- ) => ReturnType<Task[key]>;
22
- };
23
-
24
11
  export class InferenceClient {
25
12
  private readonly accessToken: string;
26
13
  private readonly defaultOptions: Options;
27
14
 
28
- constructor(accessToken = "", defaultOptions: Options = {}) {
15
+ constructor(
16
+ accessToken = "",
17
+ defaultOptions: Options & {
18
+ endpointUrl?: string;
19
+ } = {}
20
+ ) {
29
21
  this.accessToken = accessToken;
30
22
  this.defaultOptions = defaultOptions;
31
23
 
32
- for (const [name, fn] of Object.entries(tasks)) {
24
+ for (const [name, fn] of typedEntries(tasks)) {
33
25
  Object.defineProperty(this, name, {
34
26
  enumerable: false,
35
- value: (params: RequestArgs, options: Options) =>
27
+ value: (params: Parameters<typeof fn>[0], options: Parameters<typeof fn>[1]) =>
36
28
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
37
- fn({ ...params, accessToken } as any, { ...defaultOptions, ...options }),
29
+ (fn as any)(
30
+ /// ^ The cast of fn to any is necessary, otherwise TS can't compile because the generated union type is too complex
31
+ { endpointUrl: defaultOptions.endpointUrl, accessToken, ...params },
32
+ {
33
+ ...omit(defaultOptions, ["endpointUrl"]),
34
+ ...options,
35
+ }
36
+ ),
38
37
  });
39
38
  }
40
39
  }
41
40
 
42
41
  /**
43
- * Returns copy of InferenceClient tied to a specified endpoint.
42
+ * Returns a new instance of InferenceClient tied to a specified endpoint.
43
+ *
44
+ * For backward compatibility mostly.
44
45
  */
45
- public endpoint(endpointUrl: string): InferenceClientEndpoint {
46
- return new InferenceClientEndpoint(endpointUrl, this.accessToken, this.defaultOptions);
47
- }
48
- }
49
-
50
- export class InferenceClientEndpoint {
51
- constructor(endpointUrl: string, accessToken = "", defaultOptions: Options = {}) {
52
- accessToken;
53
- defaultOptions;
54
-
55
- for (const [name, fn] of Object.entries(tasks)) {
56
- Object.defineProperty(this, name, {
57
- enumerable: false,
58
- value: (params: RequestArgs, options: Options) =>
59
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
60
- fn({ ...params, accessToken, endpointUrl } as any, { ...defaultOptions, ...options }),
61
- });
62
- }
46
+ public endpoint(endpointUrl: string): InferenceClient {
47
+ return new InferenceClient(this.accessToken, { ...this.defaultOptions, endpointUrl });
63
48
  }
64
49
  }
65
50
 
66
- export interface InferenceClient extends TaskWithNoAccessToken {}
67
-
68
- export interface InferenceClientEndpoint extends TaskWithNoAccessTokenNoEndpointUrl {}
51
+ export interface InferenceClient extends Task {}
69
52
 
70
53
  /**
71
- * For backward compatibility only.
54
+ * For backward compatibility only, will remove soon.
55
+ * @deprecated replace with InferenceClient
72
56
  */
73
57
  export class HfInference extends InferenceClient {}
58
+ /**
59
+ * For backward compatibility only, will remove soon.
60
+ * @deprecated replace with InferenceClient
61
+ */
62
+ export class InferenceClientEndpoint extends InferenceClient {}
@@ -1,8 +1,8 @@
1
1
  import type { WidgetType } from "@huggingface/tasks";
2
- import type { InferenceProvider, ModelId } from "../types";
3
2
  import { HF_HUB_URL } from "../config";
4
3
  import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts";
5
4
  import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference";
5
+ import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../types";
6
6
  import { typedInclude } from "../utils/typedInclude";
7
7
 
8
8
  export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
@@ -20,44 +20,62 @@ export interface InferenceProviderModelMapping {
20
20
  task: WidgetType;
21
21
  }
22
22
 
23
- export async function getInferenceProviderMapping(
24
- params: {
25
- accessToken?: string;
26
- modelId: ModelId;
27
- provider: InferenceProvider;
28
- task: WidgetType;
29
- },
30
- options: {
23
+ export async function fetchInferenceProviderMappingForModel(
24
+ modelId: ModelId,
25
+ accessToken?: string,
26
+ options?: {
31
27
  fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
32
28
  }
33
- ): Promise<InferenceProviderModelMapping | null> {
34
- if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
35
- return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
36
- }
29
+ ): Promise<InferenceProviderMapping> {
37
30
  let inferenceProviderMapping: InferenceProviderMapping | null;
38
- if (inferenceProviderMappingCache.has(params.modelId)) {
31
+ if (inferenceProviderMappingCache.has(modelId)) {
39
32
  // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
40
- inferenceProviderMapping = inferenceProviderMappingCache.get(params.modelId)!;
33
+ inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
41
34
  } else {
42
35
  const resp = await (options?.fetch ?? fetch)(
43
- `${HF_HUB_URL}/api/models/${params.modelId}?expand[]=inferenceProviderMapping`,
36
+ `${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
44
37
  {
45
- headers: params.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${params.accessToken}` } : {},
38
+ headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
46
39
  }
47
40
  );
48
41
  if (resp.status === 404) {
49
- throw new Error(`Model ${params.modelId} does not exist`);
42
+ throw new Error(`Model ${modelId} does not exist`);
50
43
  }
51
44
  inferenceProviderMapping = await resp
52
45
  .json()
53
46
  .then((json) => json.inferenceProviderMapping)
54
47
  .catch(() => null);
48
+
49
+ if (inferenceProviderMapping) {
50
+ inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
51
+ }
55
52
  }
56
53
 
57
54
  if (!inferenceProviderMapping) {
58
- throw new Error(`We have not been able to find inference provider information for model ${params.modelId}.`);
55
+ throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
59
56
  }
57
+ return inferenceProviderMapping;
58
+ }
60
59
 
60
+ export async function getInferenceProviderMapping(
61
+ params: {
62
+ accessToken?: string;
63
+ modelId: ModelId;
64
+ provider: InferenceProvider;
65
+ task: WidgetType;
66
+ },
67
+ options: {
68
+ fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
69
+ }
70
+ ): Promise<InferenceProviderModelMapping | null> {
71
+ if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
72
+ return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
73
+ }
74
+ const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
75
+ params.modelId,
76
+ params.accessToken,
77
+ options
78
+ );
61
79
  const providerMapping = inferenceProviderMapping[params.provider];
62
80
  if (providerMapping) {
63
81
  const equivalentTasks =
@@ -78,3 +96,34 @@ export async function getInferenceProviderMapping(
78
96
  }
79
97
  return null;
80
98
  }
99
+
100
+ export async function resolveProvider(
101
+ provider?: InferenceProviderOrPolicy,
102
+ modelId?: string,
103
+ endpointUrl?: string
104
+ ): Promise<InferenceProvider> {
105
+ if (endpointUrl) {
106
+ if (provider) {
107
+ throw new Error("Specifying both endpointUrl and provider is not supported.");
108
+ }
109
+ /// Defaulting to hf-inference helpers / API
110
+ return "hf-inference";
111
+ }
112
+ if (!provider) {
113
+ console.log(
114
+ "Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
115
+ );
116
+ provider = "auto";
117
+ }
118
+ if (provider === "auto") {
119
+ if (!modelId) {
120
+ throw new Error("Specifying a model is required when provider is 'auto'");
121
+ }
122
+ const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
123
+ provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider | undefined;
124
+ }
125
+ if (!provider) {
126
+ throw new Error(`No Inference Provider available for model ${modelId}.`);
127
+ }
128
+ return provider;
129
+ }
@@ -11,6 +11,7 @@ import * as Nebius from "../providers/nebius";
11
11
  import * as Novita from "../providers/novita";
12
12
  import * as Nscale from "../providers/nscale";
13
13
  import * as OpenAI from "../providers/openai";
14
+ import * as OvhCloud from "../providers/ovhcloud";
14
15
  import type {
15
16
  AudioClassificationTaskHelper,
16
17
  AudioToAudioTaskHelper,
@@ -126,6 +127,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
126
127
  openai: {
127
128
  conversational: new OpenAI.OpenAIConversationalTask(),
128
129
  },
130
+ ovhcloud: {
131
+ conversational: new OvhCloud.OvhCloudConversationalTask(),
132
+ "text-generation": new OvhCloud.OvhCloudTextGenerationTask(),
133
+ },
129
134
  replicate: {
130
135
  "text-to-image": new Replicate.ReplicateTextToImageTask(),
131
136
  "text-to-speech": new Replicate.ReplicateTextToSpeechTask(),
@@ -27,8 +27,8 @@ export async function makeRequestOptions(
27
27
  task?: InferenceTask;
28
28
  }
29
29
  ): Promise<{ url: string; info: RequestInit }> {
30
- const { provider: maybeProvider, model: maybeModel } = args;
31
- const provider = maybeProvider ?? "hf-inference";
30
+ const { model: maybeModel } = args;
31
+ const provider = providerHelper.provider;
32
32
  const { task } = options ?? {};
33
33
 
34
34
  // Validate inputs
@@ -113,8 +113,9 @@ export function makeRequestOptionsFromResolvedModel(
113
113
  ): { url: string; info: RequestInit } {
114
114
  const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
115
115
  void model;
116
+ void maybeProvider;
116
117
 
117
- const provider = maybeProvider ?? "hf-inference";
118
+ const provider = providerHelper.provider;
118
119
 
119
120
  const { includeCredentials, task, signal, billTo } = options ?? {};
120
121
  const authMethod = (() => {
@@ -32,6 +32,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
32
32
  novita: {},
33
33
  nscale: {},
34
34
  openai: {},
35
+ ovhcloud: {},
35
36
  replicate: {},
36
37
  sambanova: {},
37
38
  together: {},
@@ -0,0 +1,75 @@
1
+ /**
2
+ * See the registered mapping of HF model ID => OVHcloud model ID here:
3
+ *
4
+ * https://huggingface.co/api/partners/ovhcloud/models
5
+ *
6
+ * This is a publicly available mapping.
7
+ *
8
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
9
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
10
+ *
11
+ * - If you work at OVHcloud and want to update this mapping, please use the model mapping API we provide on huggingface.co
12
+ * - If you're a community member and want to add a new supported HF model to OVHcloud, please open an issue on the present repo
13
+ * and we will tag OVHcloud team members.
14
+ *
15
+ * Thanks!
16
+ */
17
+
18
+ import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";
19
+ import type { ChatCompletionOutput, TextGenerationOutput, TextGenerationOutputFinishReason } from "@huggingface/tasks";
20
+ import { InferenceOutputError } from "../lib/InferenceOutputError";
21
+ import type { BodyParams } from "../types";
22
+ import { omit } from "../utils/omit";
23
+ import type { TextGenerationInput } from "@huggingface/tasks";
24
+
25
+ const OVHCLOUD_API_BASE_URL = "https://oai.endpoints.kepler.ai.cloud.ovh.net";
26
+
27
+ interface OvhCloudTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
28
+ choices: Array<{
29
+ text: string;
30
+ finish_reason: TextGenerationOutputFinishReason;
31
+ logprobs: unknown;
32
+ index: number;
33
+ }>;
34
+ }
35
+
36
+ export class OvhCloudConversationalTask extends BaseConversationalTask {
37
+ constructor() {
38
+ super("ovhcloud", OVHCLOUD_API_BASE_URL);
39
+ }
40
+ }
41
+
42
+ export class OvhCloudTextGenerationTask extends BaseTextGenerationTask {
43
+ constructor() {
44
+ super("ovhcloud", OVHCLOUD_API_BASE_URL);
45
+ }
46
+
47
+ override preparePayload(params: BodyParams<TextGenerationInput>): Record<string, unknown> {
48
+ return {
49
+ model: params.model,
50
+ ...omit(params.args, ["inputs", "parameters"]),
51
+ ...(params.args.parameters
52
+ ? {
53
+ max_tokens: (params.args.parameters as Record<string, unknown>).max_new_tokens,
54
+ ...omit(params.args.parameters as Record<string, unknown>, "max_new_tokens"),
55
+ }
56
+ : undefined),
57
+ prompt: params.args.inputs,
58
+ };
59
+ }
60
+
61
+ override async getResponse(response: OvhCloudTextCompletionOutput): Promise<TextGenerationOutput> {
62
+ if (
63
+ typeof response === "object" &&
64
+ "choices" in response &&
65
+ Array.isArray(response?.choices) &&
66
+ typeof response?.model === "string"
67
+ ) {
68
+ const completion = response.choices[0];
69
+ return {
70
+ generated_text: completion.text,
71
+ };
72
+ }
73
+ throw new InferenceOutputError("Expected OVHcloud text generation response format");
74
+ }
75
+ }
@@ -56,7 +56,7 @@ import { toArray } from "../utils/toArray";
56
56
  */
57
57
  export abstract class TaskProviderHelper {
58
58
  constructor(
59
- private provider: InferenceProvider,
59
+ readonly provider: InferenceProvider,
60
60
  private baseUrl: string,
61
61
  readonly clientSideRoutingOnly: boolean = false
62
62
  ) {}
@@ -8,11 +8,11 @@ import {
8
8
  } from "@huggingface/tasks";
9
9
  import type { PipelineType, WidgetType } from "@huggingface/tasks/src/pipelines.js";
10
10
  import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingface/tasks/src/tasks/index.js";
11
+ import type { InferenceProviderModelMapping } from "../lib/getInferenceProviderMapping";
12
+ import { getProviderHelper } from "../lib/getProviderHelper";
11
13
  import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions";
12
14
  import type { InferenceProvider, InferenceTask, RequestArgs } from "../types";
13
15
  import { templates } from "./templates.exported";
14
- import type { InferenceProviderModelMapping } from "../lib/getInferenceProviderMapping";
15
- import { getProviderHelper } from "../lib/getProviderHelper";
16
16
 
17
17
  export type InferenceSnippetOptions = { streaming?: boolean; billTo?: string } & Record<string, unknown>;
18
18
 
@@ -112,6 +112,7 @@ const HF_JS_METHODS: Partial<Record<WidgetType, string>> = {
112
112
  "text-generation": "textGeneration",
113
113
  "text2text-generation": "textGeneration",
114
114
  "token-classification": "tokenClassification",
115
+ "text-to-speech": "textToSpeech",
115
116
  translation: "translation",
116
117
  };
117
118
 
@@ -271,7 +272,7 @@ const prepareConversationalInput = (
271
272
  return {
272
273
  messages: opts?.messages ?? getModelInputSnippet(model),
273
274
  ...(opts?.temperature ? { temperature: opts?.temperature } : undefined),
274
- max_tokens: opts?.max_tokens ?? 512,
275
+ ...(opts?.max_tokens ? { max_tokens: opts?.max_tokens } : undefined),
275
276
  ...(opts?.top_p ? { top_p: opts?.top_p } : undefined),
276
277
  };
277
278
  };
@@ -310,7 +311,7 @@ const snippets: Partial<
310
311
  "text-generation": snippetGenerator("basic"),
311
312
  "text-to-audio": snippetGenerator("textToAudio"),
312
313
  "text-to-image": snippetGenerator("textToImage"),
313
- "text-to-speech": snippetGenerator("textToAudio"),
314
+ "text-to-speech": snippetGenerator("textToSpeech"),
314
315
  "text-to-video": snippetGenerator("textToVideo"),
315
316
  "text2text-generation": snippetGenerator("basic"),
316
317
  "token-classification": snippetGenerator("basic"),
@@ -7,6 +7,7 @@ export const templates: Record<string, Record<string, Record<string, string>>> =
7
7
  "basicImage": "async function query(data) {\n\tconst response = await fetch(\n\t\t\"{{ fullUrl }}\",\n\t\t{\n\t\t\theaders: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n\t\t\t\t\"Content-Type\": \"image/jpeg\",\n{% if billTo %}\n\t\t\t\t\"X-HF-Bill-To\": \"{{ billTo }}\",\n{% endif %}\t\t\t},\n\t\t\tmethod: \"POST\",\n\t\t\tbody: JSON.stringify(data),\n\t\t}\n\t);\n\tconst result = await response.json();\n\treturn result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});",
8
8
  "textToAudio": "{% if model.library_name == \"transformers\" %}\nasync function query(data) {\n\tconst response = await fetch(\n\t\t\"{{ fullUrl }}\",\n\t\t{\n\t\t\theaders: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n\t\t\t\t\"Content-Type\": \"application/json\",\n{% if billTo %}\n\t\t\t\t\"X-HF-Bill-To\": \"{{ billTo }}\",\n{% endif %}\t\t\t},\n\t\t\tmethod: \"POST\",\n\t\t\tbody: JSON.stringify(data),\n\t\t}\n\t);\n\tconst result = await response.blob();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n // Returns a byte object of the Audio wavform. Use it directly!\n});\n{% else %}\nasync function query(data) {\n\tconst response = await fetch(\n\t\t\"{{ fullUrl }}\",\n\t\t{\n\t\t\theaders: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n\t\t\t\t\"Content-Type\": \"application/json\",\n\t\t\t},\n\t\t\tmethod: \"POST\",\n\t\t\tbody: JSON.stringify(data),\n\t\t}\n\t);\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});\n{% endif %} ",
9
9
  "textToImage": "async function query(data) {\n\tconst response = await fetch(\n\t\t\"{{ fullUrl }}\",\n\t\t{\n\t\t\theaders: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n\t\t\t\t\"Content-Type\": \"application/json\",\n{% if billTo %}\n\t\t\t\t\"X-HF-Bill-To\": \"{{ billTo }}\",\n{% endif %}\t\t\t},\n\t\t\tmethod: \"POST\",\n\t\t\tbody: JSON.stringify(data),\n\t\t}\n\t);\n\tconst result = await response.blob();\n\treturn result;\n}\n\n\nquery({ {{ providerInputs.asTsString }} }).then((response) => {\n // Use image\n});",
10
+ "textToSpeech": "{% if model.library_name == \"transformers\" %}\nasync function query(data) {\n\tconst response = await fetch(\n\t\t\"{{ fullUrl }}\",\n\t\t{\n\t\t\theaders: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n\t\t\t\t\"Content-Type\": \"application/json\",\n{% if billTo %}\n\t\t\t\t\"X-HF-Bill-To\": \"{{ billTo }}\",\n{% endif %}\t\t\t},\n\t\t\tmethod: \"POST\",\n\t\t\tbody: JSON.stringify(data),\n\t\t}\n\t);\n\tconst result = await response.blob();\n return result;\n}\n\nquery({ text: {{ inputs.asObj.inputs }} }).then((response) => {\n // Returns a byte object of the Audio wavform. Use it directly!\n});\n{% else %}\nasync function query(data) {\n\tconst response = await fetch(\n\t\t\"{{ fullUrl }}\",\n\t\t{\n\t\t\theaders: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n\t\t\t\t\"Content-Type\": \"application/json\",\n\t\t\t},\n\t\t\tmethod: \"POST\",\n\t\t\tbody: JSON.stringify(data),\n\t\t}\n\t);\n const result = await response.json();\n return result;\n}\n\nquery({ text: {{ inputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});\n{% endif %} ",
10
11
  "zeroShotClassification": "async function query(data) {\n const response = await fetch(\n\t\t\"{{ fullUrl }}\",\n {\n headers: {\n\t\t\t\tAuthorization: \"{{ authorizationHeader }}\",\n \"Content-Type\": \"application/json\",\n{% if billTo %}\n \"X-HF-Bill-To\": \"{{ billTo }}\",\n{% endif %} },\n method: \"POST\",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({\n inputs: {{ providerInputs.asObj.inputs }},\n parameters: { candidate_labels: [\"refund\", \"legal\", \"faq\"] }\n}).then((response) => {\n console.log(JSON.stringify(response));\n});"
11
12
  },
12
13
  "huggingface.js": {
@@ -16,7 +17,8 @@ export const templates: Record<string, Record<string, Record<string, string>>> =
16
17
  "conversational": "import { InferenceClient } from \"@huggingface/inference\";\n\nconst client = new InferenceClient(\"{{ accessToken }}\");\n\nconst chatCompletion = await client.chatCompletion({\n provider: \"{{ provider }}\",\n model: \"{{ model.id }}\",\n{{ inputs.asTsString }}\n}{% if billTo %}, {\n billTo: \"{{ billTo }}\",\n}{% endif %});\n\nconsole.log(chatCompletion.choices[0].message);",
17
18
  "conversationalStream": "import { InferenceClient } from \"@huggingface/inference\";\n\nconst client = new InferenceClient(\"{{ accessToken }}\");\n\nlet out = \"\";\n\nconst stream = client.chatCompletionStream({\n provider: \"{{ provider }}\",\n model: \"{{ model.id }}\",\n{{ inputs.asTsString }}\n}{% if billTo %}, {\n billTo: \"{{ billTo }}\",\n}{% endif %});\n\nfor await (const chunk of stream) {\n\tif (chunk.choices && chunk.choices.length > 0) {\n\t\tconst newContent = chunk.choices[0].delta.content;\n\t\tout += newContent;\n\t\tconsole.log(newContent);\n\t}\n}",
18
19
  "textToImage": "import { InferenceClient } from \"@huggingface/inference\";\n\nconst client = new InferenceClient(\"{{ accessToken }}\");\n\nconst image = await client.textToImage({\n provider: \"{{ provider }}\",\n model: \"{{ model.id }}\",\n\tinputs: {{ inputs.asObj.inputs }},\n\tparameters: { num_inference_steps: 5 },\n}{% if billTo %}, {\n billTo: \"{{ billTo }}\",\n}{% endif %});\n/// Use the generated image (it's a Blob)",
19
- "textToVideo": "import { InferenceClient } from \"@huggingface/inference\";\n\nconst client = new InferenceClient(\"{{ accessToken }}\");\n\nconst image = await client.textToVideo({\n provider: \"{{ provider }}\",\n model: \"{{ model.id }}\",\n\tinputs: {{ inputs.asObj.inputs }},\n}{% if billTo %}, {\n billTo: \"{{ billTo }}\",\n}{% endif %});\n// Use the generated video (it's a Blob)"
20
+ "textToSpeech": "import { InferenceClient } from \"@huggingface/inference\";\n\nconst client = new InferenceClient(\"{{ accessToken }}\");\n\nconst audio = await client.textToSpeech({\n provider: \"{{ provider }}\",\n model: \"{{ model.id }}\",\n\tinputs: {{ inputs.asObj.inputs }},\n}{% if billTo %}, {\n billTo: \"{{ billTo }}\",\n}{% endif %});\n// Use the generated audio (it's a Blob)",
21
+ "textToVideo": "import { InferenceClient } from \"@huggingface/inference\";\n\nconst client = new InferenceClient(\"{{ accessToken }}\");\n\nconst video = await client.textToVideo({\n provider: \"{{ provider }}\",\n model: \"{{ model.id }}\",\n\tinputs: {{ inputs.asObj.inputs }},\n}{% if billTo %}, {\n billTo: \"{{ billTo }}\",\n}{% endif %});\n// Use the generated video (it's a Blob)"
20
22
  },
21
23
  "openai": {
22
24
  "conversational": "import { OpenAI } from \"openai\";\n\nconst client = new OpenAI({\n\tbaseURL: \"{{ baseUrl }}\",\n\tapiKey: \"{{ accessToken }}\",\n{% if billTo %}\n\tdefaultHeaders: {\n\t\t\"X-HF-Bill-To\": \"{{ billTo }}\" \n\t}\n{% endif %}\n});\n\nconst chatCompletion = await client.chat.completions.create({\n\tmodel: \"{{ providerModelId }}\",\n{{ inputs.asTsString }}\n});\n\nconsole.log(chatCompletion.choices[0].message);",
@@ -25,7 +27,7 @@ export const templates: Record<string, Record<string, Record<string, string>>> =
25
27
  },
26
28
  "python": {
27
29
  "fal_client": {
28
- "textToImage": "{% if provider == \"fal-ai\" %}\nimport fal_client\n\nresult = fal_client.subscribe(\n \"{{ providerModelId }}\",\n arguments={\n \"prompt\": {{ inputs.asObj.inputs }},\n },\n)\nprint(result)\n{% endif %} "
30
+ "textToImage": "{% if provider == \"fal-ai\" %}\nimport fal_client\n\n{% if providerInputs.asObj.loras is defined and providerInputs.asObj.loras != none %}\nresult = fal_client.subscribe(\n \"{{ providerModelId }}\",\n arguments={\n \"prompt\": {{ inputs.asObj.inputs }},\n \"loras\":{{ providerInputs.asObj.loras | tojson }},\n },\n)\n{% else %}\nresult = fal_client.subscribe(\n \"{{ providerModelId }}\",\n arguments={\n \"prompt\": {{ inputs.asObj.inputs }},\n },\n)\n{% endif %} \nprint(result)\n{% endif %} "
29
31
  },
30
32
  "huggingface_hub": {
31
33
  "basic": "result = client.{{ methodName }}(\n inputs={{ inputs.asObj.inputs }},\n model=\"{{ model.id }}\",\n)",
@@ -37,6 +39,7 @@ export const templates: Record<string, Record<string, Record<string, string>>> =
37
39
  "imageToImage": "# output is a PIL.Image object\nimage = client.image_to_image(\n \"{{ inputs.asObj.inputs }}\",\n prompt=\"{{ inputs.asObj.parameters.prompt }}\",\n model=\"{{ model.id }}\",\n) ",
38
40
  "importInferenceClient": "from huggingface_hub import InferenceClient\n\nclient = InferenceClient(\n provider=\"{{ provider }}\",\n api_key=\"{{ accessToken }}\",\n{% if billTo %}\n bill_to=\"{{ billTo }}\",\n{% endif %}\n)",
39
41
  "textToImage": "# output is a PIL.Image object\nimage = client.text_to_image(\n {{ inputs.asObj.inputs }},\n model=\"{{ model.id }}\",\n) ",
42
+ "textToSpeech": "# audio is returned as bytes\naudio = client.text_to_speech(\n {{ inputs.asObj.inputs }},\n model=\"{{ model.id }}\",\n) \n",
40
43
  "textToVideo": "video = client.text_to_video(\n {{ inputs.asObj.inputs }},\n model=\"{{ model.id }}\",\n) "
41
44
  },
42
45
  "openai": {
@@ -53,8 +56,9 @@ export const templates: Record<string, Record<string, Record<string, string>>> =
53
56
  "imageToImage": "def query(payload):\n with open(payload[\"inputs\"], \"rb\") as f:\n img = f.read()\n payload[\"inputs\"] = base64.b64encode(img).decode(\"utf-8\")\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nimage_bytes = query({\n{{ providerInputs.asJsonString }}\n})\n\n# You can access the image with PIL.Image for example\nimport io\nfrom PIL import Image\nimage = Image.open(io.BytesIO(image_bytes)) ",
54
57
  "importRequests": "{% if importBase64 %}\nimport base64\n{% endif %}\n{% if importJson %}\nimport json\n{% endif %}\nimport requests\n\nAPI_URL = \"{{ fullUrl }}\"\nheaders = {\n \"Authorization\": \"{{ authorizationHeader }}\",\n{% if billTo %}\n \"X-HF-Bill-To\": \"{{ billTo }}\"\n{% endif %}\n}",
55
58
  "tabular": "def query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nresponse = query({\n \"inputs\": {\n \"data\": {{ providerInputs.asObj.inputs }}\n },\n}) ",
56
- "textToAudio": "{% if model.library_name == \"transformers\" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\naudio_bytes = query({\n \"inputs\": {{ providerInputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio_bytes)\n{% else %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\naudio, sampling_rate = query({\n \"inputs\": {{ providerInputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio, rate=sampling_rate)\n{% endif %} ",
59
+ "textToAudio": "{% if model.library_name == \"transformers\" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\naudio_bytes = query({\n \"inputs\": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio_bytes)\n{% else %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\naudio, sampling_rate = query({\n \"inputs\": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio, rate=sampling_rate)\n{% endif %} ",
57
60
  "textToImage": "{% if provider == \"hf-inference\" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nimage_bytes = query({\n \"inputs\": {{ providerInputs.asObj.inputs }},\n})\n\n# You can access the image with PIL.Image for example\nimport io\nfrom PIL import Image\nimage = Image.open(io.BytesIO(image_bytes))\n{% endif %}",
61
+ "textToSpeech": "{% if model.library_name == \"transformers\" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\naudio_bytes = query({\n \"text\": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio_bytes)\n{% else %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\naudio, sampling_rate = query({\n \"text\": {{ inputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio, rate=sampling_rate)\n{% endif %} ",
58
62
  "zeroShotClassification": "def query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\noutput = query({\n \"inputs\": {{ providerInputs.asObj.inputs }},\n \"parameters\": {\"candidate_labels\": [\"refund\", \"legal\", \"faq\"]},\n}) ",
59
63
  "zeroShotImageClassification": "def query(data):\n with open(data[\"image_path\"], \"rb\") as f:\n img = f.read()\n payload={\n \"parameters\": data[\"parameters\"],\n \"inputs\": base64.b64encode(img).decode(\"utf-8\")\n }\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\noutput = query({\n \"image_path\": {{ providerInputs.asObj.inputs }},\n \"parameters\": {\"candidate_labels\": [\"cat\", \"dog\", \"llama\"]},\n}) "
60
64
  }
@@ -1,4 +1,5 @@
1
1
  import type { AudioClassificationInput, AudioClassificationOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -15,7 +16,8 @@ export async function audioClassification(
15
16
  args: AudioClassificationArgs,
16
17
  options?: Options
17
18
  ): Promise<AudioClassificationOutput> {
18
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
19
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
20
+ const providerHelper = getProviderHelper(provider, "audio-classification");
19
21
  const payload = preparePayload(args);
20
22
  const { data: res } = await innerRequest<AudioClassificationOutput>(payload, providerHelper, {
21
23
  ...options,
@@ -1,3 +1,4 @@
1
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
1
2
  import { getProviderHelper } from "../../lib/getProviderHelper";
2
3
  import type { BaseArgs, Options } from "../../types";
3
4
  import { innerRequest } from "../../utils/request";
@@ -36,7 +37,9 @@ export interface AudioToAudioOutput {
36
37
  * Example model: speechbrain/sepformer-wham does audio source separation.
37
38
  */
38
39
  export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
39
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
40
+ const model = "inputs" in args ? args.model : undefined;
41
+ const provider = await resolveProvider(args.provider, model);
42
+ const providerHelper = getProviderHelper(provider, "audio-to-audio");
40
43
  const payload = preparePayload(args);
41
44
  const { data: res } = await innerRequest<AudioToAudioOutput>(payload, providerHelper, {
42
45
  ...options,
@@ -1,4 +1,5 @@
1
1
  import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import { InferenceOutputError } from "../../lib/InferenceOutputError";
4
5
  import { FAL_AI_SUPPORTED_BLOB_TYPES } from "../../providers/fal-ai";
@@ -18,7 +19,8 @@ export async function automaticSpeechRecognition(
18
19
  args: AutomaticSpeechRecognitionArgs,
19
20
  options?: Options
20
21
  ): Promise<AutomaticSpeechRecognitionOutput> {
21
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
22
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
23
+ const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
22
24
  const payload = await buildPayload(args);
23
25
  const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, providerHelper, {
24
26
  ...options,
@@ -1,4 +1,5 @@
1
1
  import type { TextToSpeechInput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -12,7 +13,7 @@ interface OutputUrlTextToSpeechGeneration {
12
13
  * Recommended model: espnet/kan-bayashi_ljspeech_vits
13
14
  */
14
15
  export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
15
- const provider = args.provider ?? "hf-inference";
16
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
16
17
  const providerHelper = getProviderHelper(provider, "text-to-speech");
17
18
  const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, providerHelper, {
18
19
  ...options,
@@ -1,3 +1,4 @@
1
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
1
2
  import { getProviderHelper } from "../../lib/getProviderHelper";
2
3
  import type { InferenceTask, Options, RequestArgs } from "../../types";
3
4
  import { innerRequest } from "../../utils/request";
@@ -16,7 +17,8 @@ export async function request<T>(
16
17
  console.warn(
17
18
  "The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
18
19
  );
19
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
20
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
21
+ const providerHelper = getProviderHelper(provider, options?.task);
20
22
  const result = await innerRequest<T>(args, providerHelper, options);
21
23
  return result.data;
22
24
  }
@@ -1,3 +1,4 @@
1
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
1
2
  import { getProviderHelper } from "../../lib/getProviderHelper";
2
3
  import type { InferenceTask, Options, RequestArgs } from "../../types";
3
4
  import { innerStreamingRequest } from "../../utils/request";
@@ -16,6 +17,7 @@ export async function* streamingRequest<T>(
16
17
  console.warn(
17
18
  "The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
18
19
  );
19
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
20
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
21
+ const providerHelper = getProviderHelper(provider, options?.task);
20
22
  yield* innerStreamingRequest(args, providerHelper, options);
21
23
  }
@@ -1,4 +1,5 @@
1
1
  import type { ImageClassificationInput, ImageClassificationOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -14,7 +15,8 @@ export async function imageClassification(
14
15
  args: ImageClassificationArgs,
15
16
  options?: Options
16
17
  ): Promise<ImageClassificationOutput> {
17
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
18
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
19
+ const providerHelper = getProviderHelper(provider, "image-classification");
18
20
  const payload = preparePayload(args);
19
21
  const { data: res } = await innerRequest<ImageClassificationOutput>(payload, providerHelper, {
20
22
  ...options,
@@ -1,4 +1,5 @@
1
1
  import type { ImageSegmentationInput, ImageSegmentationOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -14,7 +15,8 @@ export async function imageSegmentation(
14
15
  args: ImageSegmentationArgs,
15
16
  options?: Options
16
17
  ): Promise<ImageSegmentationOutput> {
17
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
18
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
19
+ const providerHelper = getProviderHelper(provider, "image-segmentation");
18
20
  const payload = preparePayload(args);
19
21
  const { data: res } = await innerRequest<ImageSegmentationOutput>(payload, providerHelper, {
20
22
  ...options,
@@ -1,4 +1,5 @@
1
1
  import type { ImageToImageInput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options, RequestArgs } from "../../types";
4
5
  import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -11,7 +12,8 @@ export type ImageToImageArgs = BaseArgs & ImageToImageInput;
11
12
  * Recommended model: lllyasviel/sd-controlnet-depth
12
13
  */
13
14
  export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
14
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
15
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
16
+ const providerHelper = getProviderHelper(provider, "image-to-image");
15
17
  let reqArgs: RequestArgs;
16
18
  if (!args.parameters) {
17
19
  reqArgs = {
@@ -1,4 +1,5 @@
1
1
  import type { ImageToTextInput, ImageToTextOutput } from "@huggingface/tasks";
2
+ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
2
3
  import { getProviderHelper } from "../../lib/getProviderHelper";
3
4
  import type { BaseArgs, Options } from "../../types";
4
5
  import { innerRequest } from "../../utils/request";
@@ -10,7 +11,8 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput);
10
11
  * This task reads some image input and outputs the text caption.
11
12
  */
12
13
  export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
13
- const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
14
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
15
+ const providerHelper = getProviderHelper(provider, "image-to-text");
14
16
  const payload = preparePayload(args);
15
17
  const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, providerHelper, {
16
18
  ...options,