@huggingface/inference 4.13.10 → 4.13.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +2 -0
- package/dist/commonjs/lib/getProviderHelper.d.ts.map +1 -1
- package/dist/commonjs/lib/getProviderHelper.js +4 -0
- package/dist/commonjs/package.d.ts +1 -1
- package/dist/commonjs/package.js +1 -1
- package/dist/commonjs/providers/consts.d.ts.map +1 -1
- package/dist/commonjs/providers/consts.js +1 -0
- package/dist/commonjs/providers/nvidia.d.ts +21 -0
- package/dist/commonjs/providers/nvidia.d.ts.map +1 -0
- package/dist/commonjs/providers/nvidia.js +26 -0
- package/dist/commonjs/providers/replicate.d.ts.map +1 -1
- package/dist/commonjs/providers/replicate.js +6 -1
- package/dist/commonjs/snippets/getInferenceSnippets.js +11 -20
- package/dist/commonjs/types.d.ts +2 -2
- package/dist/commonjs/types.d.ts.map +1 -1
- package/dist/commonjs/types.js +2 -0
- package/dist/esm/lib/getProviderHelper.d.ts.map +1 -1
- package/dist/esm/lib/getProviderHelper.js +4 -0
- package/dist/esm/package.d.ts +1 -1
- package/dist/esm/package.js +1 -1
- package/dist/esm/providers/consts.d.ts.map +1 -1
- package/dist/esm/providers/consts.js +1 -0
- package/dist/esm/providers/nvidia.d.ts +21 -0
- package/dist/esm/providers/nvidia.d.ts.map +1 -0
- package/dist/esm/providers/nvidia.js +22 -0
- package/dist/esm/providers/replicate.d.ts.map +1 -1
- package/dist/esm/providers/replicate.js +6 -1
- package/dist/esm/snippets/getInferenceSnippets.js +11 -20
- package/dist/esm/types.d.ts +2 -2
- package/dist/esm/types.d.ts.map +1 -1
- package/dist/esm/types.js +2 -0
- package/package.json +34 -34
- package/src/InferenceClient.ts +2 -2
- package/src/errors.ts +1 -1
- package/src/lib/getDefaultTask.ts +1 -1
- package/src/lib/getInferenceProviderMapping.ts +11 -11
- package/src/lib/getProviderHelper.ts +41 -37
- package/src/lib/makeRequestOptions.ts +11 -11
- package/src/package.ts +1 -1
- package/src/providers/black-forest-labs.ts +3 -3
- package/src/providers/consts.ts +1 -0
- package/src/providers/fal-ai.ts +33 -33
- package/src/providers/featherless-ai.ts +1 -1
- package/src/providers/hf-inference.ts +48 -48
- package/src/providers/hyperbolic.ts +3 -3
- package/src/providers/nebius.ts +1 -1
- package/src/providers/novita.ts +7 -7
- package/src/providers/nscale.ts +2 -2
- package/src/providers/nvidia.ts +23 -0
- package/src/providers/ovhcloud.ts +1 -1
- package/src/providers/providerHelper.ts +7 -7
- package/src/providers/replicate.ts +8 -3
- package/src/providers/sambanova.ts +1 -1
- package/src/providers/together.ts +1 -1
- package/src/providers/wavespeed.ts +10 -10
- package/src/providers/zai-org.ts +7 -7
- package/src/snippets/getInferenceSnippets.ts +26 -26
- package/src/tasks/audio/audioClassification.ts +1 -1
- package/src/tasks/audio/automaticSpeechRecognition.ts +1 -1
- package/src/tasks/audio/utils.ts +1 -1
- package/src/tasks/custom/request.ts +2 -2
- package/src/tasks/custom/streamingRequest.ts +2 -2
- package/src/tasks/cv/imageClassification.ts +1 -1
- package/src/tasks/cv/imageSegmentation.ts +1 -1
- package/src/tasks/cv/textToImage.ts +5 -5
- package/src/tasks/cv/textToVideo.ts +1 -1
- package/src/tasks/cv/zeroShotImageClassification.ts +3 -3
- package/src/tasks/multimodal/documentQuestionAnswering.ts +2 -2
- package/src/tasks/multimodal/visualQuestionAnswering.ts +1 -1
- package/src/tasks/nlp/chatCompletion.ts +1 -1
- package/src/tasks/nlp/chatCompletionStream.ts +1 -1
- package/src/tasks/nlp/featureExtraction.ts +1 -1
- package/src/tasks/nlp/questionAnswering.ts +2 -2
- package/src/tasks/nlp/sentenceSimilarity.ts +1 -1
- package/src/tasks/nlp/tableQuestionAnswering.ts +2 -2
- package/src/tasks/nlp/textClassification.ts +1 -1
- package/src/tasks/nlp/textGeneration.ts +1 -1
- package/src/tasks/nlp/textGenerationStream.ts +1 -1
- package/src/tasks/nlp/tokenClassification.ts +2 -2
- package/src/tasks/nlp/zeroShotClassification.ts +2 -2
- package/src/tasks/tabular/tabularClassification.ts +1 -1
- package/src/tasks/tabular/tabularRegression.ts +1 -1
- package/src/types.ts +2 -0
- package/src/utils/pick.ts +1 -1
- package/src/utils/request.ts +20 -20
- package/src/utils/typedEntries.ts +1 -1
package/src/providers/novita.ts
CHANGED
|
@@ -79,7 +79,7 @@ export class NovitaTextToVideoTask extends TaskProviderHelper implements TextToV
|
|
|
79
79
|
override async getResponse(
|
|
80
80
|
response: NovitaAsyncAPIOutput,
|
|
81
81
|
url?: string,
|
|
82
|
-
headers?: Record<string, string>
|
|
82
|
+
headers?: Record<string, string>,
|
|
83
83
|
): Promise<Blob> {
|
|
84
84
|
if (!url || !headers) {
|
|
85
85
|
throw new InferenceClientInputError("URL and headers are required for text-to-video task");
|
|
@@ -87,7 +87,7 @@ export class NovitaTextToVideoTask extends TaskProviderHelper implements TextToV
|
|
|
87
87
|
const taskId = response.task_id;
|
|
88
88
|
if (!taskId) {
|
|
89
89
|
throw new InferenceClientProviderOutputError(
|
|
90
|
-
"Received malformed response from Novita text-to-video API: no task ID found in the response"
|
|
90
|
+
"Received malformed response from Novita text-to-video API: no task ID found in the response",
|
|
91
91
|
);
|
|
92
92
|
}
|
|
93
93
|
|
|
@@ -111,7 +111,7 @@ export class NovitaTextToVideoTask extends TaskProviderHelper implements TextToV
|
|
|
111
111
|
requestId: resultResponse.headers.get("x-request-id") ?? "",
|
|
112
112
|
status: resultResponse.status,
|
|
113
113
|
body: await resultResponse.text(),
|
|
114
|
-
}
|
|
114
|
+
},
|
|
115
115
|
);
|
|
116
116
|
}
|
|
117
117
|
try {
|
|
@@ -128,12 +128,12 @@ export class NovitaTextToVideoTask extends TaskProviderHelper implements TextToV
|
|
|
128
128
|
status = taskResult.task.status;
|
|
129
129
|
} else {
|
|
130
130
|
throw new InferenceClientProviderOutputError(
|
|
131
|
-
"Received malformed response from Novita text-to-video API: failed to get task status"
|
|
131
|
+
"Received malformed response from Novita text-to-video API: failed to get task status",
|
|
132
132
|
);
|
|
133
133
|
}
|
|
134
134
|
} catch (error) {
|
|
135
135
|
throw new InferenceClientProviderOutputError(
|
|
136
|
-
"Received malformed response from Novita text-to-video API: failed to parse task result"
|
|
136
|
+
"Received malformed response from Novita text-to-video API: failed to parse task result",
|
|
137
137
|
);
|
|
138
138
|
}
|
|
139
139
|
}
|
|
@@ -159,8 +159,8 @@ export class NovitaTextToVideoTask extends TaskProviderHelper implements TextToV
|
|
|
159
159
|
} else {
|
|
160
160
|
throw new InferenceClientProviderOutputError(
|
|
161
161
|
`Received malformed response from Novita text-to-video API: expected { videos: [{ video_url: string }] } format, got instead: ${JSON.stringify(
|
|
162
|
-
taskResult
|
|
163
|
-
)}`
|
|
162
|
+
taskResult,
|
|
163
|
+
)}`,
|
|
164
164
|
);
|
|
165
165
|
}
|
|
166
166
|
}
|
package/src/providers/nscale.ts
CHANGED
|
@@ -42,7 +42,7 @@ export class NscaleTextToImageTask extends TaskProviderHelper implements TextToI
|
|
|
42
42
|
preparePayload(params: BodyParams<TextToImageInput>): Record<string, unknown> {
|
|
43
43
|
if (params.outputType === "url") {
|
|
44
44
|
throw new InferenceClientInputError(
|
|
45
|
-
"nscale provider does not support URL output. Use outputType 'blob', 'dataUrl' or 'json' instead."
|
|
45
|
+
"nscale provider does not support URL output. Use outputType 'blob', 'dataUrl' or 'json' instead.",
|
|
46
46
|
);
|
|
47
47
|
}
|
|
48
48
|
return {
|
|
@@ -62,7 +62,7 @@ export class NscaleTextToImageTask extends TaskProviderHelper implements TextToI
|
|
|
62
62
|
response: NscaleCloudBase64ImageGeneration,
|
|
63
63
|
url?: string,
|
|
64
64
|
headers?: HeadersInit,
|
|
65
|
-
outputType?: OutputType
|
|
65
|
+
outputType?: OutputType,
|
|
66
66
|
): Promise<string | Blob | Record<string, unknown>> {
|
|
67
67
|
if (
|
|
68
68
|
typeof response === "object" &&
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* See the registered mapping of HF model ID => NVIDIA model ID here:
|
|
3
|
+
*
|
|
4
|
+
* https://huggingface.co/api/partners/nvidia/models
|
|
5
|
+
*
|
|
6
|
+
* This is a publicly available mapping.
|
|
7
|
+
*
|
|
8
|
+
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
|
|
9
|
+
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
|
|
10
|
+
*
|
|
11
|
+
* - If you work at NVIDIA and want to update this mapping, please use the model mapping API we provide on huggingface.co
|
|
12
|
+
* - If you're a community member and want to add a new supported HF model to NVIDIA, please open an issue on the present repo
|
|
13
|
+
* and we will tag NVIDIA team members.
|
|
14
|
+
*
|
|
15
|
+
* Thanks!
|
|
16
|
+
*/
|
|
17
|
+
import { BaseConversationalTask } from "./providerHelper.js";
|
|
18
|
+
|
|
19
|
+
export class NvidiaConversationalTask extends BaseConversationalTask {
|
|
20
|
+
constructor() {
|
|
21
|
+
super("nvidia", "https://integrate.api.nvidia.com");
|
|
22
|
+
}
|
|
23
|
+
}
|
|
@@ -52,7 +52,7 @@ export class OvhCloudTextGenerationTask extends BaseTextGenerationTask {
|
|
|
52
52
|
? {
|
|
53
53
|
max_tokens: (params.args.parameters as Record<string, unknown>).max_new_tokens,
|
|
54
54
|
...omit(params.args.parameters as Record<string, unknown>, "max_new_tokens"),
|
|
55
|
-
|
|
55
|
+
}
|
|
56
56
|
: undefined),
|
|
57
57
|
prompt: params.args.inputs,
|
|
58
58
|
};
|
|
@@ -75,7 +75,7 @@ export abstract class TaskProviderHelper {
|
|
|
75
75
|
constructor(
|
|
76
76
|
readonly provider: InferenceProvider,
|
|
77
77
|
protected baseUrl: string,
|
|
78
|
-
readonly clientSideRoutingOnly: boolean = false
|
|
78
|
+
readonly clientSideRoutingOnly: boolean = false,
|
|
79
79
|
) {}
|
|
80
80
|
|
|
81
81
|
/**
|
|
@@ -86,7 +86,7 @@ export abstract class TaskProviderHelper {
|
|
|
86
86
|
response: unknown,
|
|
87
87
|
url?: string,
|
|
88
88
|
headers?: HeadersInit,
|
|
89
|
-
outputType?: OutputType
|
|
89
|
+
outputType?: OutputType,
|
|
90
90
|
): Promise<unknown>;
|
|
91
91
|
|
|
92
92
|
/**
|
|
@@ -152,7 +152,7 @@ export interface TextToImageTaskHelper {
|
|
|
152
152
|
response: unknown,
|
|
153
153
|
url?: string,
|
|
154
154
|
headers?: HeadersInit,
|
|
155
|
-
outputType?: OutputType
|
|
155
|
+
outputType?: OutputType,
|
|
156
156
|
): Promise<string | Blob | Record<string, unknown>>;
|
|
157
157
|
preparePayload(params: BodyParams<TextToImageInput & BaseArgs>): Record<string, unknown>;
|
|
158
158
|
}
|
|
@@ -282,7 +282,7 @@ export interface TextToAudioTaskHelper {
|
|
|
282
282
|
export interface AudioToAudioTaskHelper {
|
|
283
283
|
getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<AudioToAudioOutput[]>;
|
|
284
284
|
preparePayload(
|
|
285
|
-
params: BodyParams<BaseArgs & { inputs: Blob } & Record<string, unknown>>
|
|
285
|
+
params: BodyParams<BaseArgs & { inputs: Blob } & Record<string, unknown>>,
|
|
286
286
|
): Record<string, unknown> | BodyInit;
|
|
287
287
|
}
|
|
288
288
|
export interface AutomaticSpeechRecognitionTaskHelper {
|
|
@@ -315,14 +315,14 @@ export interface VisualQuestionAnsweringTaskHelper {
|
|
|
315
315
|
export interface TabularClassificationTaskHelper {
|
|
316
316
|
getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<number[]>;
|
|
317
317
|
preparePayload(
|
|
318
|
-
params: BodyParams<BaseArgs & { inputs: { data: Record<string, string[]> } } & Record<string, unknown>>
|
|
318
|
+
params: BodyParams<BaseArgs & { inputs: { data: Record<string, string[]> } } & Record<string, unknown>>,
|
|
319
319
|
): Record<string, unknown> | BodyInit;
|
|
320
320
|
}
|
|
321
321
|
|
|
322
322
|
export interface TabularRegressionTaskHelper {
|
|
323
323
|
getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<number[]>;
|
|
324
324
|
preparePayload(
|
|
325
|
-
params: BodyParams<BaseArgs & { inputs: { data: Record<string, string[]> } } & Record<string, unknown>>
|
|
325
|
+
params: BodyParams<BaseArgs & { inputs: { data: Record<string, string[]> } } & Record<string, unknown>>,
|
|
326
326
|
): Record<string, unknown> | BodyInit;
|
|
327
327
|
}
|
|
328
328
|
|
|
@@ -387,7 +387,7 @@ export class BaseTextGenerationTask extends TaskProviderHelper implements TextGe
|
|
|
387
387
|
res.length > 0 &&
|
|
388
388
|
res.every(
|
|
389
389
|
(x): x is { generated_text: string } =>
|
|
390
|
-
typeof x === "object" && !!x && "generated_text" in x && typeof x.generated_text === "string"
|
|
390
|
+
typeof x === "object" && !!x && "generated_text" in x && typeof x.generated_text === "string",
|
|
391
391
|
)
|
|
392
392
|
) {
|
|
393
393
|
return res[0];
|
|
@@ -92,7 +92,7 @@ export class ReplicateTextToImageTask extends ReplicateTask implements TextToIma
|
|
|
92
92
|
res: ReplicateOutput | Blob,
|
|
93
93
|
url?: string,
|
|
94
94
|
headers?: Record<string, string>,
|
|
95
|
-
outputType?: OutputType
|
|
95
|
+
outputType?: OutputType,
|
|
96
96
|
): Promise<string | Blob | Record<string, unknown>> {
|
|
97
97
|
void url;
|
|
98
98
|
void headers;
|
|
@@ -236,18 +236,23 @@ export class ReplicateAutomaticSpeechRecognitionTask
|
|
|
236
236
|
}
|
|
237
237
|
}
|
|
238
238
|
throw new InferenceClientProviderOutputError(
|
|
239
|
-
"Received malformed response from Replicate automatic-speech-recognition API"
|
|
239
|
+
"Received malformed response from Replicate automatic-speech-recognition API",
|
|
240
240
|
);
|
|
241
241
|
}
|
|
242
242
|
}
|
|
243
243
|
|
|
244
244
|
export class ReplicateImageToImageTask extends ReplicateTask implements ImageToImageTaskHelper {
|
|
245
245
|
override preparePayload(params: BodyParams<ImageToImageArgs>): Record<string, unknown> {
|
|
246
|
+
const imageInput = params.args.inputs; // This will be processed in preparePayloadAsync
|
|
246
247
|
return {
|
|
247
248
|
input: {
|
|
248
249
|
...omit(params.args, ["inputs", "parameters"]),
|
|
249
250
|
...params.args.parameters,
|
|
250
|
-
|
|
251
|
+
// Different Replicate models expect the image in different keys
|
|
252
|
+
image: imageInput,
|
|
253
|
+
images: [imageInput],
|
|
254
|
+
input_image: imageInput,
|
|
255
|
+
input_images: [imageInput],
|
|
251
256
|
lora_weights:
|
|
252
257
|
params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath
|
|
253
258
|
? `https://huggingface.co/${params.mapping.hfModelId}`
|
|
@@ -54,7 +54,7 @@ export class SambanovaFeatureExtractionTask extends TaskProviderHelper implement
|
|
|
54
54
|
return response.data.map((item) => item.embedding);
|
|
55
55
|
}
|
|
56
56
|
throw new InferenceClientProviderOutputError(
|
|
57
|
-
"Received malformed response from Sambanova feature-extraction (embeddings) API"
|
|
57
|
+
"Received malformed response from Sambanova feature-extraction (embeddings) API",
|
|
58
58
|
);
|
|
59
59
|
}
|
|
60
60
|
|
|
@@ -119,7 +119,7 @@ export class TogetherTextToImageTask extends TaskProviderHelper implements TextT
|
|
|
119
119
|
response: TogetherImageGeneration,
|
|
120
120
|
url?: string,
|
|
121
121
|
headers?: HeadersInit,
|
|
122
|
-
outputType?: OutputType
|
|
122
|
+
outputType?: OutputType,
|
|
123
123
|
): Promise<string | Blob | Record<string, unknown>> {
|
|
124
124
|
if (
|
|
125
125
|
typeof response === "object" &&
|
|
@@ -76,10 +76,10 @@ interface WaveSpeedAISubmitTaskResponse {
|
|
|
76
76
|
|
|
77
77
|
async function buildImagesField(
|
|
78
78
|
inputs: Blob | ArrayBuffer,
|
|
79
|
-
hasImages: unknown
|
|
79
|
+
hasImages: unknown,
|
|
80
80
|
): Promise<{ base: string; images: string[] }> {
|
|
81
81
|
const base = base64FromBytes(
|
|
82
|
-
new Uint8Array(inputs instanceof ArrayBuffer ? inputs : await (inputs as Blob).arrayBuffer())
|
|
82
|
+
new Uint8Array(inputs instanceof ArrayBuffer ? inputs : await (inputs as Blob).arrayBuffer()),
|
|
83
83
|
);
|
|
84
84
|
const images =
|
|
85
85
|
Array.isArray(hasImages) && hasImages.every((value): value is string => typeof value === "string")
|
|
@@ -105,7 +105,7 @@ abstract class WavespeedAITask extends TaskProviderHelper {
|
|
|
105
105
|
| TextToImageArgs
|
|
106
106
|
| TextToVideoArgs
|
|
107
107
|
| ImageToVideoArgs
|
|
108
|
-
|
|
108
|
+
>,
|
|
109
109
|
): Record<string, unknown> {
|
|
110
110
|
const payload: Record<string, unknown> = {
|
|
111
111
|
...omit(params.args, ["inputs", "parameters"]),
|
|
@@ -128,7 +128,7 @@ abstract class WavespeedAITask extends TaskProviderHelper {
|
|
|
128
128
|
response: WaveSpeedAISubmitTaskResponse,
|
|
129
129
|
url?: string,
|
|
130
130
|
headers?: Record<string, string>,
|
|
131
|
-
outputType?: OutputType
|
|
131
|
+
outputType?: OutputType,
|
|
132
132
|
): Promise<string | Blob | Record<string, unknown>> {
|
|
133
133
|
if (!url || !headers) {
|
|
134
134
|
throw new InferenceClientInputError("Headers are required for WaveSpeed AI API calls");
|
|
@@ -154,7 +154,7 @@ abstract class WavespeedAITask extends TaskProviderHelper {
|
|
|
154
154
|
requestId: resultResponse.headers.get("x-request-id") ?? "",
|
|
155
155
|
status: resultResponse.status,
|
|
156
156
|
body: await resultResponse.text(),
|
|
157
|
-
}
|
|
157
|
+
},
|
|
158
158
|
);
|
|
159
159
|
}
|
|
160
160
|
|
|
@@ -166,7 +166,7 @@ abstract class WavespeedAITask extends TaskProviderHelper {
|
|
|
166
166
|
// Get the media data from the first output URL
|
|
167
167
|
if (!taskResult.outputs?.[0]) {
|
|
168
168
|
throw new InferenceClientProviderOutputError(
|
|
169
|
-
"Received malformed response from WaveSpeed AI API: No output URL in completed response"
|
|
169
|
+
"Received malformed response from WaveSpeed AI API: No output URL in completed response",
|
|
170
170
|
);
|
|
171
171
|
}
|
|
172
172
|
const mediaUrl = taskResult.outputs[0];
|
|
@@ -188,7 +188,7 @@ abstract class WavespeedAITask extends TaskProviderHelper {
|
|
|
188
188
|
requestId: mediaResponse.headers.get("x-request-id") ?? "",
|
|
189
189
|
status: mediaResponse.status,
|
|
190
190
|
body: await mediaResponse.text(),
|
|
191
|
-
}
|
|
191
|
+
},
|
|
192
192
|
);
|
|
193
193
|
}
|
|
194
194
|
const blob = await mediaResponse.blob();
|
|
@@ -222,7 +222,7 @@ export class WavespeedAITextToVideoTask extends WavespeedAITask implements TextT
|
|
|
222
222
|
override async getResponse(
|
|
223
223
|
response: WaveSpeedAISubmitTaskResponse,
|
|
224
224
|
url?: string,
|
|
225
|
-
headers?: Record<string, string>
|
|
225
|
+
headers?: Record<string, string>,
|
|
226
226
|
): Promise<Blob> {
|
|
227
227
|
return super.getResponse(response, url, headers) as Promise<Blob>;
|
|
228
228
|
}
|
|
@@ -243,7 +243,7 @@ export class WavespeedAIImageToImageTask extends WavespeedAITask implements Imag
|
|
|
243
243
|
override async getResponse(
|
|
244
244
|
response: WaveSpeedAISubmitTaskResponse,
|
|
245
245
|
url?: string,
|
|
246
|
-
headers?: Record<string, string>
|
|
246
|
+
headers?: Record<string, string>,
|
|
247
247
|
): Promise<Blob> {
|
|
248
248
|
return super.getResponse(response, url, headers) as Promise<Blob>;
|
|
249
249
|
}
|
|
@@ -264,7 +264,7 @@ export class WavespeedAIImageToVideoTask extends WavespeedAITask implements Imag
|
|
|
264
264
|
override async getResponse(
|
|
265
265
|
response: WaveSpeedAISubmitTaskResponse,
|
|
266
266
|
url?: string,
|
|
267
|
-
headers?: Record<string, string>
|
|
267
|
+
headers?: Record<string, string>,
|
|
268
268
|
): Promise<Blob> {
|
|
269
269
|
return super.getResponse(response, url, headers) as Promise<Blob>;
|
|
270
270
|
}
|
package/src/providers/zai-org.ts
CHANGED
|
@@ -97,7 +97,7 @@ export class ZaiTextToImageTask extends TaskProviderHelper implements TextToImag
|
|
|
97
97
|
response: ZaiTextToImageResponse,
|
|
98
98
|
url?: string,
|
|
99
99
|
headers?: Record<string, string>,
|
|
100
|
-
outputType?: OutputType
|
|
100
|
+
outputType?: OutputType,
|
|
101
101
|
): Promise<string | Blob | Record<string, unknown>> {
|
|
102
102
|
if (!url || !headers) {
|
|
103
103
|
throw new InferenceClientInputError(`URL and headers are required for 'text-to-image' task`);
|
|
@@ -111,8 +111,8 @@ export class ZaiTextToImageTask extends TaskProviderHelper implements TextToImag
|
|
|
111
111
|
) {
|
|
112
112
|
throw new InferenceClientProviderOutputError(
|
|
113
113
|
`Received malformed response from ZAI text-to-image API: expected { id: string, task_status: string }, got: ${JSON.stringify(
|
|
114
|
-
response
|
|
115
|
-
)}`
|
|
114
|
+
response,
|
|
115
|
+
)}`,
|
|
116
116
|
);
|
|
117
117
|
}
|
|
118
118
|
|
|
@@ -145,7 +145,7 @@ export class ZaiTextToImageTask extends TaskProviderHelper implements TextToImag
|
|
|
145
145
|
throw new InferenceClientProviderApiError(
|
|
146
146
|
`Failed to fetch result from ZAI text-to-image API: ${resp.status}`,
|
|
147
147
|
{ url: pollUrl, method: "GET" },
|
|
148
|
-
{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
|
|
148
|
+
{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() },
|
|
149
149
|
);
|
|
150
150
|
}
|
|
151
151
|
|
|
@@ -165,8 +165,8 @@ export class ZaiTextToImageTask extends TaskProviderHelper implements TextToImag
|
|
|
165
165
|
) {
|
|
166
166
|
throw new InferenceClientProviderOutputError(
|
|
167
167
|
`Received malformed response from ZAI text-to-image API: expected { image_result: Array<{ url: string }> }, got: ${JSON.stringify(
|
|
168
|
-
result
|
|
169
|
-
)}`
|
|
168
|
+
result,
|
|
169
|
+
)}`,
|
|
170
170
|
);
|
|
171
171
|
}
|
|
172
172
|
|
|
@@ -186,7 +186,7 @@ export class ZaiTextToImageTask extends TaskProviderHelper implements TextToImag
|
|
|
186
186
|
}
|
|
187
187
|
|
|
188
188
|
throw new InferenceClientProviderOutputError(
|
|
189
|
-
`Timed out while waiting for the result from ZAI API - aborting after ${MAX_POLL_ATTEMPTS} attempts`
|
|
189
|
+
`Timed out while waiting for the result from ZAI API - aborting after ${MAX_POLL_ATTEMPTS} attempts`,
|
|
190
190
|
);
|
|
191
191
|
}
|
|
192
192
|
}
|
|
@@ -72,7 +72,7 @@ const hasTemplate = (language: InferenceSnippetLanguage, client: Client, templat
|
|
|
72
72
|
const loadTemplate = (
|
|
73
73
|
language: InferenceSnippetLanguage,
|
|
74
74
|
client: Client,
|
|
75
|
-
templateName: string
|
|
75
|
+
templateName: string,
|
|
76
76
|
): ((data: TemplateParams) => string) => {
|
|
77
77
|
const template = templates[language]?.[client]?.[templateName];
|
|
78
78
|
if (!template) {
|
|
@@ -148,7 +148,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
|
|
|
148
148
|
model: ModelDataMinimal,
|
|
149
149
|
provider: InferenceProviderOrPolicy,
|
|
150
150
|
inferenceProviderMapping?: InferenceProviderMappingEntry,
|
|
151
|
-
opts?: InferenceSnippetOptions
|
|
151
|
+
opts?: InferenceSnippetOptions,
|
|
152
152
|
): InferenceSnippet[] => {
|
|
153
153
|
const logger = getLogger();
|
|
154
154
|
const providerModelId = inferenceProviderMapping?.providerId ?? model.id;
|
|
@@ -180,8 +180,8 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
|
|
|
180
180
|
const inputs = opts?.inputs
|
|
181
181
|
? { inputs: opts.inputs }
|
|
182
182
|
: inputPreparationFn
|
|
183
|
-
|
|
184
|
-
|
|
183
|
+
? inputPreparationFn(model, opts)
|
|
184
|
+
: { inputs: getModelInputSnippet(model) };
|
|
185
185
|
const request = makeRequestOptionsFromResolvedModel(
|
|
186
186
|
providerModelId,
|
|
187
187
|
providerHelper,
|
|
@@ -195,7 +195,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
|
|
|
195
195
|
{
|
|
196
196
|
task,
|
|
197
197
|
billTo: opts?.billTo,
|
|
198
|
-
}
|
|
198
|
+
},
|
|
199
199
|
);
|
|
200
200
|
|
|
201
201
|
/// Parse request.info.body if not a binary.
|
|
@@ -218,11 +218,11 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
|
|
|
218
218
|
? {
|
|
219
219
|
...inputs,
|
|
220
220
|
model: `${model.id}:${provider}`,
|
|
221
|
-
|
|
221
|
+
}
|
|
222
222
|
: {
|
|
223
223
|
...inputs,
|
|
224
224
|
model: `${model.id}`, // if no :provider => auto
|
|
225
|
-
|
|
225
|
+
}
|
|
226
226
|
: providerInputs;
|
|
227
227
|
|
|
228
228
|
/// Prepare template injection data
|
|
@@ -265,7 +265,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
|
|
|
265
265
|
? provider !== "auto"
|
|
266
266
|
? `${model.id}:${provider}` // e.g. "moonshotai/Kimi-K2-Instruct:groq"
|
|
267
267
|
: model.id
|
|
268
|
-
: providerModelId ?? model.id,
|
|
268
|
+
: (providerModelId ?? model.id),
|
|
269
269
|
billTo: opts?.billTo,
|
|
270
270
|
endpointUrl: opts?.endpointUrl,
|
|
271
271
|
task,
|
|
@@ -324,7 +324,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
|
|
|
324
324
|
snippet,
|
|
325
325
|
language,
|
|
326
326
|
provider,
|
|
327
|
-
opts?.endpointUrl
|
|
327
|
+
opts?.endpointUrl,
|
|
328
328
|
);
|
|
329
329
|
}
|
|
330
330
|
|
|
@@ -354,7 +354,7 @@ const prepareConversationalInput = (
|
|
|
354
354
|
temperature?: GenerationParameters["temperature"];
|
|
355
355
|
max_tokens?: GenerationParameters["max_new_tokens"];
|
|
356
356
|
top_p?: GenerationParameters["top_p"];
|
|
357
|
-
}
|
|
357
|
+
},
|
|
358
358
|
): object => {
|
|
359
359
|
return {
|
|
360
360
|
messages: opts?.messages ?? getModelInputSnippet(model),
|
|
@@ -381,7 +381,7 @@ const snippets: Partial<
|
|
|
381
381
|
model: ModelDataMinimal,
|
|
382
382
|
provider: InferenceProviderOrPolicy,
|
|
383
383
|
inferenceProviderMapping?: InferenceProviderMappingEntry,
|
|
384
|
-
opts?: InferenceSnippetOptions
|
|
384
|
+
opts?: InferenceSnippetOptions,
|
|
385
385
|
) => InferenceSnippet[]
|
|
386
386
|
>
|
|
387
387
|
> = {
|
|
@@ -422,10 +422,10 @@ export function getInferenceSnippets(
|
|
|
422
422
|
model: ModelDataMinimal,
|
|
423
423
|
provider: InferenceProviderOrPolicy,
|
|
424
424
|
inferenceProviderMapping?: InferenceProviderMappingEntry,
|
|
425
|
-
opts?: Record<string, unknown>
|
|
425
|
+
opts?: Record<string, unknown>,
|
|
426
426
|
): InferenceSnippet[] {
|
|
427
427
|
return model.pipeline_tag && model.pipeline_tag in snippets
|
|
428
|
-
? snippets[model.pipeline_tag]?.(model, provider, inferenceProviderMapping, opts) ?? []
|
|
428
|
+
? (snippets[model.pipeline_tag]?.(model, provider, inferenceProviderMapping, opts) ?? [])
|
|
429
429
|
: [];
|
|
430
430
|
}
|
|
431
431
|
|
|
@@ -447,7 +447,7 @@ function formatBody(obj: object, format: "curl" | "json" | "python" | "ts"): str
|
|
|
447
447
|
const formattedValue = JSON.stringify(value, null, 4).replace(/"/g, '"');
|
|
448
448
|
return `${key}=${formattedValue},`;
|
|
449
449
|
})
|
|
450
|
-
.join("\n")
|
|
450
|
+
.join("\n"),
|
|
451
451
|
);
|
|
452
452
|
|
|
453
453
|
case "ts":
|
|
@@ -507,7 +507,7 @@ function replaceAccessTokenPlaceholder(
|
|
|
507
507
|
snippet: string,
|
|
508
508
|
language: InferenceSnippetLanguage,
|
|
509
509
|
provider: InferenceProviderOrPolicy,
|
|
510
|
-
endpointUrl?: string
|
|
510
|
+
endpointUrl?: string,
|
|
511
511
|
): string {
|
|
512
512
|
// If "opts.accessToken" is not set, the snippets are generated with a placeholder.
|
|
513
513
|
// Once snippets are rendered, we replace the placeholder with code to fetch the access token from an environment variable.
|
|
@@ -522,49 +522,49 @@ function replaceAccessTokenPlaceholder(
|
|
|
522
522
|
const accessTokenEnvVar = useHfToken
|
|
523
523
|
? "HF_TOKEN" // e.g. routed request or hf-inference
|
|
524
524
|
: endpointUrl
|
|
525
|
-
|
|
526
|
-
|
|
525
|
+
? "API_TOKEN"
|
|
526
|
+
: provider.toUpperCase().replace("-", "_") + "_API_KEY"; // e.g. "REPLICATE_API_KEY"
|
|
527
527
|
|
|
528
528
|
// Replace the placeholder with the env variable
|
|
529
529
|
if (language === "sh") {
|
|
530
530
|
snippet = snippet.replace(
|
|
531
531
|
`'Authorization: Bearer ${placeholder}'`,
|
|
532
|
-
`"Authorization: Bearer $${accessTokenEnvVar}"` // e.g. "Authorization: Bearer $HF_TOKEN"
|
|
532
|
+
`"Authorization: Bearer $${accessTokenEnvVar}"`, // e.g. "Authorization: Bearer $HF_TOKEN"
|
|
533
533
|
);
|
|
534
534
|
} else if (language === "python") {
|
|
535
535
|
snippet = "import os\n" + snippet;
|
|
536
536
|
snippet = snippet.replace(
|
|
537
537
|
`"${placeholder}"`,
|
|
538
|
-
`os.environ["${accessTokenEnvVar}"]` // e.g. os.environ["HF_TOKEN")
|
|
538
|
+
`os.environ["${accessTokenEnvVar}"]`, // e.g. os.environ["HF_TOKEN")
|
|
539
539
|
);
|
|
540
540
|
snippet = snippet.replace(
|
|
541
541
|
`"Bearer ${placeholder}"`,
|
|
542
|
-
`f"Bearer {os.environ['${accessTokenEnvVar}']}"` // e.g. f"Bearer {os.environ['HF_TOKEN']}"
|
|
542
|
+
`f"Bearer {os.environ['${accessTokenEnvVar}']}"`, // e.g. f"Bearer {os.environ['HF_TOKEN']}"
|
|
543
543
|
);
|
|
544
544
|
snippet = snippet.replace(
|
|
545
545
|
`"Key ${placeholder}"`,
|
|
546
|
-
`f"Key {os.environ['${accessTokenEnvVar}']}"` // e.g. f"Key {os.environ['FAL_AI_API_KEY']}"
|
|
546
|
+
`f"Key {os.environ['${accessTokenEnvVar}']}"`, // e.g. f"Key {os.environ['FAL_AI_API_KEY']}"
|
|
547
547
|
);
|
|
548
548
|
snippet = snippet.replace(
|
|
549
549
|
`"X-Key ${placeholder}"`,
|
|
550
|
-
`f"X-Key {os.environ['${accessTokenEnvVar}']}"` // e.g. f"X-Key {os.environ['BLACK_FOREST_LABS_API_KEY']}"
|
|
550
|
+
`f"X-Key {os.environ['${accessTokenEnvVar}']}"`, // e.g. f"X-Key {os.environ['BLACK_FOREST_LABS_API_KEY']}"
|
|
551
551
|
);
|
|
552
552
|
} else if (language === "js") {
|
|
553
553
|
snippet = snippet.replace(
|
|
554
554
|
`"${placeholder}"`,
|
|
555
|
-
`process.env.${accessTokenEnvVar}` // e.g. process.env.HF_TOKEN
|
|
555
|
+
`process.env.${accessTokenEnvVar}`, // e.g. process.env.HF_TOKEN
|
|
556
556
|
);
|
|
557
557
|
snippet = snippet.replace(
|
|
558
558
|
`Authorization: "Bearer ${placeholder}",`,
|
|
559
|
-
`Authorization: \`Bearer $\{process.env.${accessTokenEnvVar}}\`,` // e.g. Authorization: `Bearer ${process.env.HF_TOKEN}`,
|
|
559
|
+
`Authorization: \`Bearer $\{process.env.${accessTokenEnvVar}}\`,`, // e.g. Authorization: `Bearer ${process.env.HF_TOKEN}`,
|
|
560
560
|
);
|
|
561
561
|
snippet = snippet.replace(
|
|
562
562
|
`Authorization: "Key ${placeholder}",`,
|
|
563
|
-
`Authorization: \`Key $\{process.env.${accessTokenEnvVar}}\`,` // e.g. Authorization: `Key ${process.env.FAL_AI_API_KEY}`,
|
|
563
|
+
`Authorization: \`Key $\{process.env.${accessTokenEnvVar}}\`,`, // e.g. Authorization: `Key ${process.env.FAL_AI_API_KEY}`,
|
|
564
564
|
);
|
|
565
565
|
snippet = snippet.replace(
|
|
566
566
|
`Authorization: "X-Key ${placeholder}",`,
|
|
567
|
-
`Authorization: \`X-Key $\{process.env.${accessTokenEnvVar}}\`,` // e.g. Authorization: `X-Key ${process.env.BLACK_FOREST_LABS_AI_API_KEY}`,
|
|
567
|
+
`Authorization: \`X-Key $\{process.env.${accessTokenEnvVar}}\`,`, // e.g. Authorization: `X-Key ${process.env.BLACK_FOREST_LABS_AI_API_KEY}`,
|
|
568
568
|
);
|
|
569
569
|
}
|
|
570
570
|
return snippet;
|
|
@@ -14,7 +14,7 @@ export type AudioClassificationArgs = BaseArgs & (AudioClassificationInput | Leg
|
|
|
14
14
|
*/
|
|
15
15
|
export async function audioClassification(
|
|
16
16
|
args: AudioClassificationArgs,
|
|
17
|
-
options?: Options
|
|
17
|
+
options?: Options,
|
|
18
18
|
): Promise<AudioClassificationOutput> {
|
|
19
19
|
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
20
20
|
const providerHelper = getProviderHelper(provider, "audio-classification");
|
|
@@ -12,7 +12,7 @@ export type AutomaticSpeechRecognitionArgs = BaseArgs & (AutomaticSpeechRecognit
|
|
|
12
12
|
*/
|
|
13
13
|
export async function automaticSpeechRecognition(
|
|
14
14
|
args: AutomaticSpeechRecognitionArgs,
|
|
15
|
-
options?: Options
|
|
15
|
+
options?: Options,
|
|
16
16
|
): Promise<AutomaticSpeechRecognitionOutput> {
|
|
17
17
|
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
18
18
|
const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
|
package/src/tasks/audio/utils.ts
CHANGED
|
@@ -13,11 +13,11 @@ export async function request<T>(
|
|
|
13
13
|
options?: Options & {
|
|
14
14
|
/** In most cases (unless we pass a endpointUrl) we know the task */
|
|
15
15
|
task?: InferenceTask;
|
|
16
|
-
}
|
|
16
|
+
},
|
|
17
17
|
): Promise<T> {
|
|
18
18
|
const logger = getLogger();
|
|
19
19
|
logger.warn(
|
|
20
|
-
"The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
|
|
20
|
+
"The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead.",
|
|
21
21
|
);
|
|
22
22
|
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
23
23
|
const providerHelper = getProviderHelper(provider, options?.task);
|
|
@@ -13,11 +13,11 @@ export async function* streamingRequest<T>(
|
|
|
13
13
|
options?: Options & {
|
|
14
14
|
/** In most cases (unless we pass a endpointUrl) we know the task */
|
|
15
15
|
task?: InferenceTask;
|
|
16
|
-
}
|
|
16
|
+
},
|
|
17
17
|
): AsyncGenerator<T> {
|
|
18
18
|
const logger = getLogger();
|
|
19
19
|
logger.warn(
|
|
20
|
-
"The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
|
|
20
|
+
"The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead.",
|
|
21
21
|
);
|
|
22
22
|
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
23
23
|
const providerHelper = getProviderHelper(provider, options?.task);
|
|
@@ -13,7 +13,7 @@ export type ImageClassificationArgs = BaseArgs & (ImageClassificationInput | Leg
|
|
|
13
13
|
*/
|
|
14
14
|
export async function imageClassification(
|
|
15
15
|
args: ImageClassificationArgs,
|
|
16
|
-
options?: Options
|
|
16
|
+
options?: Options,
|
|
17
17
|
): Promise<ImageClassificationOutput> {
|
|
18
18
|
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
19
19
|
const providerHelper = getProviderHelper(provider, "image-classification");
|
|
@@ -13,7 +13,7 @@ export type ImageSegmentationArgs = BaseArgs & ImageSegmentationInput;
|
|
|
13
13
|
*/
|
|
14
14
|
export async function imageSegmentation(
|
|
15
15
|
args: ImageSegmentationArgs,
|
|
16
|
-
options?: Options
|
|
16
|
+
options?: Options,
|
|
17
17
|
): Promise<ImageSegmentationOutput> {
|
|
18
18
|
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
19
19
|
const providerHelper = getProviderHelper(provider, "image-segmentation");
|
|
@@ -17,23 +17,23 @@ interface TextToImageOptions extends Options {
|
|
|
17
17
|
*/
|
|
18
18
|
export async function textToImage(
|
|
19
19
|
args: TextToImageArgs,
|
|
20
|
-
options?: TextToImageOptions & { outputType: "url" }
|
|
20
|
+
options?: TextToImageOptions & { outputType: "url" },
|
|
21
21
|
): Promise<string>;
|
|
22
22
|
export async function textToImage(
|
|
23
23
|
args: TextToImageArgs,
|
|
24
|
-
options?: TextToImageOptions & { outputType: "dataUrl" }
|
|
24
|
+
options?: TextToImageOptions & { outputType: "dataUrl" },
|
|
25
25
|
): Promise<string>;
|
|
26
26
|
export async function textToImage(
|
|
27
27
|
args: TextToImageArgs,
|
|
28
|
-
options?: TextToImageOptions & { outputType?: undefined | "blob" }
|
|
28
|
+
options?: TextToImageOptions & { outputType?: undefined | "blob" },
|
|
29
29
|
): Promise<Blob>;
|
|
30
30
|
export async function textToImage(
|
|
31
31
|
args: TextToImageArgs,
|
|
32
|
-
options?: TextToImageOptions & { outputType?: undefined | "json" }
|
|
32
|
+
options?: TextToImageOptions & { outputType?: undefined | "json" },
|
|
33
33
|
): Promise<Record<string, unknown>>;
|
|
34
34
|
export async function textToImage(
|
|
35
35
|
args: TextToImageArgs,
|
|
36
|
-
options?: TextToImageOptions
|
|
36
|
+
options?: TextToImageOptions,
|
|
37
37
|
): Promise<Blob | string | Record<string, unknown>> {
|
|
38
38
|
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
39
39
|
const providerHelper = getProviderHelper(provider, "text-to-image");
|
|
@@ -21,7 +21,7 @@ export async function textToVideo(args: TextToVideoArgs, options?: Options): Pro
|
|
|
21
21
|
{
|
|
22
22
|
...options,
|
|
23
23
|
task: "text-to-video",
|
|
24
|
-
}
|
|
24
|
+
},
|
|
25
25
|
);
|
|
26
26
|
const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "text-to-video" });
|
|
27
27
|
return providerHelper.getResponse(response, url, info.headers as Record<string, string>);
|