@huggingface/inference 3.7.0 → 3.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. package/dist/index.cjs +1369 -941
  2. package/dist/index.js +1371 -943
  3. package/dist/src/lib/getInferenceProviderMapping.d.ts +21 -0
  4. package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -0
  5. package/dist/src/lib/getProviderHelper.d.ts +37 -0
  6. package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
  7. package/dist/src/lib/makeRequestOptions.d.ts +5 -5
  8. package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
  9. package/dist/src/providers/black-forest-labs.d.ts +14 -18
  10. package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
  11. package/dist/src/providers/cerebras.d.ts +4 -2
  12. package/dist/src/providers/cerebras.d.ts.map +1 -1
  13. package/dist/src/providers/cohere.d.ts +5 -2
  14. package/dist/src/providers/cohere.d.ts.map +1 -1
  15. package/dist/src/providers/consts.d.ts +2 -3
  16. package/dist/src/providers/consts.d.ts.map +1 -1
  17. package/dist/src/providers/fal-ai.d.ts +50 -3
  18. package/dist/src/providers/fal-ai.d.ts.map +1 -1
  19. package/dist/src/providers/fireworks-ai.d.ts +5 -2
  20. package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
  21. package/dist/src/providers/hf-inference.d.ts +126 -2
  22. package/dist/src/providers/hf-inference.d.ts.map +1 -1
  23. package/dist/src/providers/hyperbolic.d.ts +31 -2
  24. package/dist/src/providers/hyperbolic.d.ts.map +1 -1
  25. package/dist/src/providers/nebius.d.ts +20 -18
  26. package/dist/src/providers/nebius.d.ts.map +1 -1
  27. package/dist/src/providers/novita.d.ts +21 -18
  28. package/dist/src/providers/novita.d.ts.map +1 -1
  29. package/dist/src/providers/openai.d.ts +4 -2
  30. package/dist/src/providers/openai.d.ts.map +1 -1
  31. package/dist/src/providers/providerHelper.d.ts +182 -0
  32. package/dist/src/providers/providerHelper.d.ts.map +1 -0
  33. package/dist/src/providers/replicate.d.ts +23 -19
  34. package/dist/src/providers/replicate.d.ts.map +1 -1
  35. package/dist/src/providers/sambanova.d.ts +4 -2
  36. package/dist/src/providers/sambanova.d.ts.map +1 -1
  37. package/dist/src/providers/together.d.ts +32 -2
  38. package/dist/src/providers/together.d.ts.map +1 -1
  39. package/dist/src/snippets/getInferenceSnippets.d.ts +2 -1
  40. package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
  41. package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
  42. package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
  43. package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
  44. package/dist/src/tasks/audio/utils.d.ts +2 -1
  45. package/dist/src/tasks/audio/utils.d.ts.map +1 -1
  46. package/dist/src/tasks/custom/request.d.ts +0 -2
  47. package/dist/src/tasks/custom/request.d.ts.map +1 -1
  48. package/dist/src/tasks/custom/streamingRequest.d.ts +0 -2
  49. package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
  50. package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
  51. package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
  52. package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
  53. package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
  54. package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
  55. package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
  56. package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
  57. package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
  58. package/dist/src/tasks/index.d.ts +6 -6
  59. package/dist/src/tasks/index.d.ts.map +1 -1
  60. package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
  61. package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
  62. package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
  63. package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
  64. package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
  65. package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
  66. package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
  67. package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
  68. package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
  69. package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
  70. package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
  71. package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
  72. package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
  73. package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
  74. package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
  75. package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
  76. package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
  77. package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
  78. package/dist/src/types.d.ts +5 -13
  79. package/dist/src/types.d.ts.map +1 -1
  80. package/dist/src/utils/request.d.ts +3 -2
  81. package/dist/src/utils/request.d.ts.map +1 -1
  82. package/package.json +3 -3
  83. package/src/lib/getInferenceProviderMapping.ts +96 -0
  84. package/src/lib/getProviderHelper.ts +270 -0
  85. package/src/lib/makeRequestOptions.ts +78 -97
  86. package/src/providers/black-forest-labs.ts +73 -22
  87. package/src/providers/cerebras.ts +6 -27
  88. package/src/providers/cohere.ts +9 -28
  89. package/src/providers/consts.ts +5 -2
  90. package/src/providers/fal-ai.ts +224 -77
  91. package/src/providers/fireworks-ai.ts +8 -29
  92. package/src/providers/hf-inference.ts +557 -34
  93. package/src/providers/hyperbolic.ts +107 -29
  94. package/src/providers/nebius.ts +65 -29
  95. package/src/providers/novita.ts +68 -32
  96. package/src/providers/openai.ts +6 -32
  97. package/src/providers/providerHelper.ts +354 -0
  98. package/src/providers/replicate.ts +124 -34
  99. package/src/providers/sambanova.ts +5 -30
  100. package/src/providers/together.ts +92 -28
  101. package/src/snippets/getInferenceSnippets.ts +39 -14
  102. package/src/snippets/templates.exported.ts +25 -25
  103. package/src/tasks/audio/audioClassification.ts +5 -8
  104. package/src/tasks/audio/audioToAudio.ts +4 -27
  105. package/src/tasks/audio/automaticSpeechRecognition.ts +5 -4
  106. package/src/tasks/audio/textToSpeech.ts +5 -29
  107. package/src/tasks/audio/utils.ts +2 -1
  108. package/src/tasks/custom/request.ts +3 -3
  109. package/src/tasks/custom/streamingRequest.ts +4 -3
  110. package/src/tasks/cv/imageClassification.ts +4 -8
  111. package/src/tasks/cv/imageSegmentation.ts +4 -9
  112. package/src/tasks/cv/imageToImage.ts +4 -7
  113. package/src/tasks/cv/imageToText.ts +4 -7
  114. package/src/tasks/cv/objectDetection.ts +4 -19
  115. package/src/tasks/cv/textToImage.ts +9 -137
  116. package/src/tasks/cv/textToVideo.ts +17 -64
  117. package/src/tasks/cv/zeroShotImageClassification.ts +4 -8
  118. package/src/tasks/index.ts +6 -6
  119. package/src/tasks/multimodal/documentQuestionAnswering.ts +4 -19
  120. package/src/tasks/multimodal/visualQuestionAnswering.ts +4 -12
  121. package/src/tasks/nlp/chatCompletion.ts +5 -20
  122. package/src/tasks/nlp/chatCompletionStream.ts +4 -3
  123. package/src/tasks/nlp/featureExtraction.ts +4 -19
  124. package/src/tasks/nlp/fillMask.ts +4 -17
  125. package/src/tasks/nlp/questionAnswering.ts +11 -26
  126. package/src/tasks/nlp/sentenceSimilarity.ts +4 -8
  127. package/src/tasks/nlp/summarization.ts +4 -7
  128. package/src/tasks/nlp/tableQuestionAnswering.ts +10 -30
  129. package/src/tasks/nlp/textClassification.ts +4 -9
  130. package/src/tasks/nlp/textGeneration.ts +11 -79
  131. package/src/tasks/nlp/textGenerationStream.ts +3 -1
  132. package/src/tasks/nlp/tokenClassification.ts +11 -23
  133. package/src/tasks/nlp/translation.ts +4 -7
  134. package/src/tasks/nlp/zeroShotClassification.ts +11 -21
  135. package/src/tasks/tabular/tabularClassification.ts +4 -7
  136. package/src/tasks/tabular/tabularRegression.ts +4 -7
  137. package/src/types.ts +5 -14
  138. package/src/utils/request.ts +7 -4
  139. package/dist/src/lib/getProviderModelId.d.ts +0 -10
  140. package/dist/src/lib/getProviderModelId.d.ts.map +0 -1
  141. package/src/lib/getProviderModelId.ts +0 -74
@@ -1,23 +1,10 @@
1
- import { HF_HUB_URL, HF_ROUTER_URL, HF_HEADER_X_BILL_TO } from "../config";
2
- import { BLACK_FOREST_LABS_CONFIG } from "../providers/black-forest-labs";
3
- import { CEREBRAS_CONFIG } from "../providers/cerebras";
4
- import { COHERE_CONFIG } from "../providers/cohere";
5
- import { FAL_AI_CONFIG } from "../providers/fal-ai";
6
- import { FIREWORKS_AI_CONFIG } from "../providers/fireworks-ai";
7
- import { HF_INFERENCE_CONFIG } from "../providers/hf-inference";
8
- import { HYPERBOLIC_CONFIG } from "../providers/hyperbolic";
9
- import { NEBIUS_CONFIG } from "../providers/nebius";
10
- import { NOVITA_CONFIG } from "../providers/novita";
11
- import { REPLICATE_CONFIG } from "../providers/replicate";
12
- import { SAMBANOVA_CONFIG } from "../providers/sambanova";
13
- import { TOGETHER_CONFIG } from "../providers/together";
14
- import { OPENAI_CONFIG } from "../providers/openai";
15
- import type { InferenceProvider, InferenceTask, Options, ProviderConfig, RequestArgs } from "../types";
1
+ import { name as packageName, version as packageVersion } from "../../package.json";
2
+ import { HF_HEADER_X_BILL_TO, HF_HUB_URL } from "../config";
3
+ import type { InferenceTask, Options, RequestArgs } from "../types";
4
+ import type { InferenceProviderModelMapping } from "./getInferenceProviderMapping";
5
+ import { getInferenceProviderMapping } from "./getInferenceProviderMapping";
6
+ import type { getProviderHelper } from "./getProviderHelper";
16
7
  import { isUrl } from "./isUrl";
17
- import { version as packageVersion, name as packageName } from "../../package.json";
18
- import { getProviderModelId } from "./getProviderModelId";
19
-
20
- const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
21
8
 
22
9
  /**
23
10
  * Lazy-loaded from huggingface.co/api/tasks when needed
@@ -25,25 +12,6 @@ const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
25
12
  */
26
13
  let tasks: Record<string, { models: { id: string }[] }> | null = null;
27
14
 
28
- /**
29
- * Config to define how to serialize requests for each provider
30
- */
31
- const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
32
- "black-forest-labs": BLACK_FOREST_LABS_CONFIG,
33
- cerebras: CEREBRAS_CONFIG,
34
- cohere: COHERE_CONFIG,
35
- "fal-ai": FAL_AI_CONFIG,
36
- "fireworks-ai": FIREWORKS_AI_CONFIG,
37
- "hf-inference": HF_INFERENCE_CONFIG,
38
- hyperbolic: HYPERBOLIC_CONFIG,
39
- openai: OPENAI_CONFIG,
40
- nebius: NEBIUS_CONFIG,
41
- novita: NOVITA_CONFIG,
42
- replicate: REPLICATE_CONFIG,
43
- sambanova: SAMBANOVA_CONFIG,
44
- together: TOGETHER_CONFIG,
45
- };
46
-
47
15
  /**
48
16
  * Helper that prepares request arguments.
49
17
  * This async version handle the model ID resolution step.
@@ -53,16 +21,15 @@ export async function makeRequestOptions(
53
21
  data?: Blob | ArrayBuffer;
54
22
  stream?: boolean;
55
23
  },
24
+ providerHelper: ReturnType<typeof getProviderHelper>,
56
25
  options?: Options & {
57
26
  /** In most cases (unless we pass a endpointUrl) we know the task */
58
27
  task?: InferenceTask;
59
- chatCompletion?: boolean;
60
28
  }
61
29
  ): Promise<{ url: string; info: RequestInit }> {
62
30
  const { provider: maybeProvider, model: maybeModel } = args;
63
31
  const provider = maybeProvider ?? "hf-inference";
64
- const providerConfig = providerConfigs[provider];
65
- const { task, chatCompletion } = options ?? {};
32
+ const { task } = options ?? {};
66
33
 
67
34
  // Validate inputs
68
35
  if (args.endpointUrl && provider !== "hf-inference") {
@@ -71,29 +38,61 @@ export async function makeRequestOptions(
71
38
  if (maybeModel && isUrl(maybeModel)) {
72
39
  throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
73
40
  }
41
+
42
+ if (args.endpointUrl) {
43
+ // No need to have maybeModel, or to load default model for a task
44
+ return makeRequestOptionsFromResolvedModel(
45
+ maybeModel ?? args.endpointUrl,
46
+ providerHelper,
47
+ args,
48
+ undefined,
49
+ options
50
+ );
51
+ }
52
+
74
53
  if (!maybeModel && !task) {
75
54
  throw new Error("No model provided, and no task has been specified.");
76
55
  }
77
- if (!providerConfig) {
78
- throw new Error(`No provider config found for provider ${provider}`);
79
- }
80
- if (providerConfig.clientSideRoutingOnly && !maybeModel) {
81
- throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
82
- }
83
56
 
84
57
  // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
85
58
  const hfModel = maybeModel ?? (await loadDefaultModel(task!));
86
- const resolvedModel = providerConfig.clientSideRoutingOnly
87
- ? // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
88
- removeProviderPrefix(maybeModel!, provider)
89
- : await getProviderModelId({ model: hfModel, provider }, args, {
90
- task,
91
- chatCompletion,
92
- fetch: options?.fetch,
93
- });
59
+
60
+ if (providerHelper.clientSideRoutingOnly && !maybeModel) {
61
+ throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
62
+ }
63
+
64
+ const inferenceProviderMapping = providerHelper.clientSideRoutingOnly
65
+ ? ({
66
+ // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
67
+ providerId: removeProviderPrefix(maybeModel!, provider),
68
+ // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
69
+ hfModelId: maybeModel!,
70
+ status: "live",
71
+ // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
72
+ task: task!,
73
+ } satisfies InferenceProviderModelMapping)
74
+ : await getInferenceProviderMapping(
75
+ {
76
+ modelId: hfModel,
77
+ // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
78
+ task: task!,
79
+ provider,
80
+ accessToken: args.accessToken,
81
+ },
82
+ { fetch: options?.fetch }
83
+ );
84
+ if (!inferenceProviderMapping) {
85
+ throw new Error(`We have not been able to find inference provider information for model ${hfModel}.`);
86
+ }
94
87
 
95
88
  // Use the sync version with the resolved model
96
- return makeRequestOptionsFromResolvedModel(resolvedModel, args, options);
89
+ return makeRequestOptionsFromResolvedModel(
90
+ inferenceProviderMapping.providerId,
91
+ providerHelper,
92
+ args,
93
+ inferenceProviderMapping,
94
+ options
95
+ );
97
96
  }
98
97
 
99
98
  /**
@@ -102,25 +101,24 @@ export async function makeRequestOptions(
102
101
  */
103
102
  export function makeRequestOptionsFromResolvedModel(
104
103
  resolvedModel: string,
104
+ providerHelper: ReturnType<typeof getProviderHelper>,
105
105
  args: RequestArgs & {
106
106
  data?: Blob | ArrayBuffer;
107
107
  stream?: boolean;
108
108
  },
109
+ mapping: InferenceProviderModelMapping | undefined,
109
110
  options?: Options & {
110
111
  task?: InferenceTask;
111
- chatCompletion?: boolean;
112
112
  }
113
113
  ): { url: string; info: RequestInit } {
114
114
  const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
115
115
  void model;
116
116
 
117
117
  const provider = maybeProvider ?? "hf-inference";
118
- const providerConfig = providerConfigs[provider];
119
-
120
- const { includeCredentials, task, chatCompletion, signal, billTo } = options ?? {};
121
118
 
119
+ const { includeCredentials, task, signal, billTo } = options ?? {};
122
120
  const authMethod = (() => {
123
- if (providerConfig.clientSideRoutingOnly) {
121
+ if (providerHelper.clientSideRoutingOnly) {
124
122
  // Closed-source providers require an accessToken (cannot be routed).
125
123
  if (accessToken && accessToken.startsWith("hf_")) {
126
124
  throw new Error(`Provider ${provider} is closed-source and does not support HF tokens.`);
@@ -138,36 +136,25 @@ export function makeRequestOptionsFromResolvedModel(
138
136
  })();
139
137
 
140
138
  // Make URL
141
- const url = endpointUrl
142
- ? chatCompletion
143
- ? endpointUrl + `/v1/chat/completions`
144
- : endpointUrl
145
- : providerConfig.makeUrl({
146
- authMethod,
147
- baseUrl:
148
- authMethod !== "provider-key"
149
- ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider)
150
- : providerConfig.makeBaseUrl(task),
151
- model: resolvedModel,
152
- chatCompletion,
153
- task,
154
- });
155
139
 
156
- // Make headers
157
- const binary = "data" in args && !!args.data;
158
- const headers = providerConfig.makeHeaders({
159
- accessToken,
140
+ const modelId = endpointUrl ?? resolvedModel;
141
+ const url = providerHelper.makeUrl({
160
142
  authMethod,
143
+ model: modelId,
144
+ task,
161
145
  });
146
+ // Make headers
147
+ const headers = providerHelper.prepareHeaders(
148
+ {
149
+ accessToken,
150
+ authMethod,
151
+ },
152
+ "data" in args && !!args.data
153
+ );
162
154
  if (billTo) {
163
155
  headers[HF_HEADER_X_BILL_TO] = billTo;
164
156
  }
165
157
 
166
- // Add content-type to headers
167
- if (!binary) {
168
- headers["Content-Type"] = "application/json";
169
- }
170
-
171
158
  // Add user-agent to headers
172
159
  // e.g. @huggingface/inference/3.1.3
173
160
  const ownUserAgent = `${packageName}/${packageVersion}`;
@@ -177,17 +164,12 @@ export function makeRequestOptionsFromResolvedModel(
177
164
  headers["User-Agent"] = userAgent;
178
165
 
179
166
  // Make body
180
- const body = binary
181
- ? args.data
182
- : JSON.stringify(
183
- providerConfig.makeBody({
184
- args: remainingArgs as Record<string, unknown>,
185
- model: resolvedModel,
186
- task,
187
- chatCompletion,
188
- })
189
- );
190
-
167
+ const body = providerHelper.makeBody({
168
+ args: remainingArgs as Record<string, unknown>,
169
+ model: resolvedModel,
170
+ task,
171
+ mapping,
172
+ });
191
173
  /**
192
174
  * For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
193
175
  */
@@ -201,11 +183,10 @@ export function makeRequestOptionsFromResolvedModel(
201
183
  const info: RequestInit = {
202
184
  headers,
203
185
  method: "POST",
204
- body,
186
+ body: body,
205
187
  ...(credentials ? { credentials } : undefined),
206
188
  signal,
207
189
  };
208
-
209
190
  return { url, info };
210
191
  }
211
192
 
@@ -14,33 +14,84 @@
14
14
  *
15
15
  * Thanks!
16
16
  */
17
- import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
17
+ import { InferenceOutputError } from "../lib/InferenceOutputError";
18
+ import type { BodyParams, HeaderParams, UrlParams } from "../types";
19
+ import { delay } from "../utils/delay";
20
+ import { omit } from "../utils/omit";
21
+ import { TaskProviderHelper, type TextToImageTaskHelper } from "./providerHelper";
18
22
 
19
23
  const BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
24
+ interface BlackForestLabsResponse {
25
+ id: string;
26
+ polling_url: string;
27
+ }
20
28
 
21
- const makeBaseUrl = (): string => {
22
- return BLACK_FOREST_LABS_AI_API_BASE_URL;
23
- };
29
+ export class BlackForestLabsTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
30
+ constructor() {
31
+ super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
32
+ }
24
33
 
25
- const makeBody = (params: BodyParams): Record<string, unknown> => {
26
- return params.args;
27
- };
34
+ preparePayload(params: BodyParams): Record<string, unknown> {
35
+ return {
36
+ ...omit(params.args, ["inputs", "parameters"]),
37
+ ...(params.args.parameters as Record<string, unknown>),
38
+ prompt: params.args.inputs,
39
+ };
40
+ }
28
41
 
29
- const makeHeaders = (params: HeaderParams): Record<string, string> => {
30
- if (params.authMethod === "provider-key") {
31
- return { "X-Key": `${params.accessToken}` };
32
- } else {
33
- return { Authorization: `Bearer ${params.accessToken}` };
42
+ override prepareHeaders(params: HeaderParams, binary: boolean): Record<string, string> {
43
+ const headers: Record<string, string> = {
44
+ Authorization:
45
+ params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`,
46
+ };
47
+ if (!binary) {
48
+ headers["Content-Type"] = "application/json";
49
+ }
50
+ return headers;
34
51
  }
35
- };
36
52
 
37
- const makeUrl = (params: UrlParams): string => {
38
- return `${params.baseUrl}/v1/${params.model}`;
39
- };
53
+ makeRoute(params: UrlParams): string {
54
+ if (!params) {
55
+ throw new Error("Params are required");
56
+ }
57
+ return `/v1/${params.model}`;
58
+ }
40
59
 
41
- export const BLACK_FOREST_LABS_CONFIG: ProviderConfig = {
42
- makeBaseUrl,
43
- makeBody,
44
- makeHeaders,
45
- makeUrl,
46
- };
60
+ async getResponse(
61
+ response: BlackForestLabsResponse,
62
+ url?: string,
63
+ headers?: HeadersInit,
64
+ outputType?: "url" | "blob"
65
+ ): Promise<string | Blob> {
66
+ const urlObj = new URL(response.polling_url);
67
+ for (let step = 0; step < 5; step++) {
68
+ await delay(1000);
69
+ console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
70
+ urlObj.searchParams.set("attempt", step.toString(10));
71
+ const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
72
+ if (!resp.ok) {
73
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
74
+ }
75
+ const payload = await resp.json();
76
+ if (
77
+ typeof payload === "object" &&
78
+ payload &&
79
+ "status" in payload &&
80
+ typeof payload.status === "string" &&
81
+ payload.status === "Ready" &&
82
+ "result" in payload &&
83
+ typeof payload.result === "object" &&
84
+ payload.result &&
85
+ "sample" in payload.result &&
86
+ typeof payload.result.sample === "string"
87
+ ) {
88
+ if (outputType === "url") {
89
+ return payload.result.sample;
90
+ }
91
+ const image = await fetch(payload.result.sample);
92
+ return await image.blob();
93
+ }
94
+ }
95
+ throw new InferenceOutputError("Failed to fetch result from black forest labs API");
96
+ }
97
+ }
@@ -14,32 +14,11 @@
14
14
  *
15
15
  * Thanks!
16
16
  */
17
- import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
18
17
 
19
- const CEREBRAS_API_BASE_URL = "https://api.cerebras.ai";
18
+ import { BaseConversationalTask } from "./providerHelper";
20
19
 
21
- const makeBaseUrl = (): string => {
22
- return CEREBRAS_API_BASE_URL;
23
- };
24
-
25
- const makeBody = (params: BodyParams): Record<string, unknown> => {
26
- return {
27
- ...params.args,
28
- model: params.model,
29
- };
30
- };
31
-
32
- const makeHeaders = (params: HeaderParams): Record<string, string> => {
33
- return { Authorization: `Bearer ${params.accessToken}` };
34
- };
35
-
36
- const makeUrl = (params: UrlParams): string => {
37
- return `${params.baseUrl}/v1/chat/completions`;
38
- };
39
-
40
- export const CEREBRAS_CONFIG: ProviderConfig = {
41
- makeBaseUrl,
42
- makeBody,
43
- makeHeaders,
44
- makeUrl,
45
- };
20
+ export class CerebrasConversationalTask extends BaseConversationalTask {
21
+ constructor() {
22
+ super("cerebras", "https://api.cerebras.ai");
23
+ }
24
+ }
@@ -14,32 +14,13 @@
14
14
  *
15
15
  * Thanks!
16
16
  */
17
- import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
17
+ import { BaseConversationalTask } from "./providerHelper";
18
18
 
19
- const COHERE_API_BASE_URL = "https://api.cohere.com";
20
-
21
- const makeBaseUrl = (): string => {
22
- return COHERE_API_BASE_URL;
23
- };
24
-
25
- const makeBody = (params: BodyParams): Record<string, unknown> => {
26
- return {
27
- ...params.args,
28
- model: params.model,
29
- };
30
- };
31
-
32
- const makeHeaders = (params: HeaderParams): Record<string, string> => {
33
- return { Authorization: `Bearer ${params.accessToken}` };
34
- };
35
-
36
- const makeUrl = (params: UrlParams): string => {
37
- return `${params.baseUrl}/compatibility/v1/chat/completions`;
38
- };
39
-
40
- export const COHERE_CONFIG: ProviderConfig = {
41
- makeBaseUrl,
42
- makeBody,
43
- makeHeaders,
44
- makeUrl,
45
- };
19
+ export class CohereConversationalTask extends BaseConversationalTask {
20
+ constructor() {
21
+ super("cohere", "https://api.cohere.com");
22
+ }
23
+ override makeRoute(): string {
24
+ return "/compatibility/v1/chat/completions";
25
+ }
26
+ }
@@ -1,7 +1,7 @@
1
+ import type { InferenceProviderModelMapping } from "../lib/getInferenceProviderMapping";
1
2
  import type { InferenceProvider } from "../types";
2
3
  import { type ModelId } from "../types";
3
4
 
4
- type ProviderId = string;
5
5
  /**
6
6
  * If you want to try to run inference for a new model locally before it's registered on huggingface.co
7
7
  * for a given Inference Provider,
@@ -9,7 +9,10 @@ type ProviderId = string;
9
9
  *
10
10
  * We also inject into this dictionary from tests.
11
11
  */
12
- export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelId, ProviderId>> = {
12
+ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
13
+ InferenceProvider,
14
+ Record<ModelId, InferenceProviderModelMapping>
15
+ > = {
13
16
  /**
14
17
  * "HF model ID" => "Model ID on Inference Provider's side"
15
18
  *