@huggingface/inference 4.11.2 → 4.12.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +2 -0
- package/dist/commonjs/errors.d.ts +3 -0
- package/dist/commonjs/errors.d.ts.map +1 -1
- package/dist/commonjs/errors.js +8 -1
- package/dist/commonjs/lib/getInferenceProviderMapping.d.ts.map +1 -1
- package/dist/commonjs/lib/getInferenceProviderMapping.js +11 -0
- package/dist/commonjs/lib/getProviderHelper.d.ts.map +1 -1
- package/dist/commonjs/lib/getProviderHelper.js +6 -0
- package/dist/commonjs/package.d.ts +1 -1
- package/dist/commonjs/package.js +1 -1
- package/dist/commonjs/providers/consts.d.ts.map +1 -1
- package/dist/commonjs/providers/consts.js +1 -0
- package/dist/commonjs/providers/providerHelper.d.ts +5 -1
- package/dist/commonjs/providers/providerHelper.d.ts.map +1 -1
- package/dist/commonjs/providers/providerHelper.js +13 -1
- package/dist/commonjs/providers/wavespeed.d.ts +41 -0
- package/dist/commonjs/providers/wavespeed.d.ts.map +1 -0
- package/dist/commonjs/providers/wavespeed.js +103 -0
- package/dist/commonjs/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/commonjs/tasks/nlp/chatCompletion.js +10 -2
- package/dist/commonjs/types.d.ts +8 -2
- package/dist/commonjs/types.d.ts.map +1 -1
- package/dist/commonjs/types.js +32 -1
- package/dist/esm/errors.d.ts +3 -0
- package/dist/esm/errors.d.ts.map +1 -1
- package/dist/esm/errors.js +6 -0
- package/dist/esm/lib/getInferenceProviderMapping.d.ts.map +1 -1
- package/dist/esm/lib/getInferenceProviderMapping.js +11 -0
- package/dist/esm/lib/getProviderHelper.d.ts.map +1 -1
- package/dist/esm/lib/getProviderHelper.js +6 -0
- package/dist/esm/package.d.ts +1 -1
- package/dist/esm/package.js +1 -1
- package/dist/esm/providers/consts.d.ts.map +1 -1
- package/dist/esm/providers/consts.js +1 -0
- package/dist/esm/providers/providerHelper.d.ts +5 -1
- package/dist/esm/providers/providerHelper.d.ts.map +1 -1
- package/dist/esm/providers/providerHelper.js +12 -1
- package/dist/esm/providers/wavespeed.d.ts +41 -0
- package/dist/esm/providers/wavespeed.d.ts.map +1 -0
- package/dist/esm/providers/wavespeed.js +97 -0
- package/dist/esm/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/esm/tasks/nlp/chatCompletion.js +10 -2
- package/dist/esm/types.d.ts +8 -2
- package/dist/esm/types.d.ts.map +1 -1
- package/dist/esm/types.js +31 -0
- package/package.json +2 -2
- package/src/errors.ts +7 -0
- package/src/lib/getInferenceProviderMapping.ts +11 -0
- package/src/lib/getProviderHelper.ts +6 -0
- package/src/package.ts +1 -1
- package/src/providers/consts.ts +1 -0
- package/src/providers/providerHelper.ts +15 -2
- package/src/providers/wavespeed.ts +185 -0
- package/src/tasks/nlp/chatCompletion.ts +10 -2
- package/src/types.ts +32 -0
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
import type { TextToImageArgs } from "../tasks/cv/textToImage.js";
|
|
2
|
+
import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js";
|
|
3
|
+
import type { TextToVideoArgs } from "../tasks/cv/textToVideo.js";
|
|
4
|
+
import type { BodyParams, RequestArgs, UrlParams } from "../types.js";
|
|
5
|
+
import { delay } from "../utils/delay.js";
|
|
6
|
+
import { omit } from "../utils/omit.js";
|
|
7
|
+
import { base64FromBytes } from "../utils/base64FromBytes.js";
|
|
8
|
+
import type { TextToImageTaskHelper, TextToVideoTaskHelper, ImageToImageTaskHelper } from "./providerHelper.js";
|
|
9
|
+
import { TaskProviderHelper } from "./providerHelper.js";
|
|
10
|
+
import {
|
|
11
|
+
InferenceClientInputError,
|
|
12
|
+
InferenceClientProviderApiError,
|
|
13
|
+
InferenceClientProviderOutputError,
|
|
14
|
+
} from "../errors.js";
|
|
15
|
+
|
|
16
|
+
// Base URL of the WaveSpeed AI public API; per-model routes are built on top of it in makeRoute().
const WAVESPEEDAI_API_BASE_URL = "https://api.wavespeed.ai";
|
|
17
|
+
|
|
18
|
+
/**
 * Response structure for task status and results.
 *
 * Returned (wrapped in {@link WaveSpeedAIResponse}) by the polling endpoint
 * referenced by `urls.get` until the task reaches a terminal status.
 */
interface WaveSpeedAITaskResponse {
	// Provider-assigned task identifier.
	id: string;
	// Model identifier the task was submitted against.
	model: string;
	// URLs of the generated artifacts; only the first entry is consumed by getResponse().
	outputs: string[];
	urls: {
		// Polling URL for this task's status/result.
		get: string;
	};
	// Per-output NSFW flags — presumably parallel to `outputs`; TODO confirm against provider docs.
	has_nsfw_contents: boolean[];
	// Task lifecycle; "completed" and "failed" are terminal, the others mean "keep polling".
	status: "created" | "processing" | "completed" | "failed";
	created_at: string;
	// Human-readable failure reason; only meaningful when status === "failed".
	error: string;
	// Total execution time reported by the provider (units not specified here — verify).
	executionTime: number;
	timings: {
		inference: number;
	};
}
|
|
37
|
+
|
|
38
|
+
/**
 * Response structure for initial task submission.
 *
 * Only the polling URL is needed by the client; the full result is fetched
 * later from `urls.get`.
 */
interface WaveSpeedAISubmitResponse {
	// Provider-assigned task identifier.
	id: string;
	urls: {
		// URL to poll for task status and outputs.
		get: string;
	};
}
|
|
47
|
+
|
|
48
|
+
/**
 * Response structure for WaveSpeed AI API.
 *
 * Generic envelope wrapping a task-status payload; `code`/`message` mirror the
 * HTTP-level outcome (not inspected by the current client code).
 */
interface WaveSpeedAIResponse {
	code: number;
	message: string;
	// The task status/result payload.
	data: WaveSpeedAITaskResponse;
}
|
|
56
|
+
|
|
57
|
+
/**
 * Response structure for WaveSpeed AI API with submit response data.
 *
 * Same envelope shape as {@link WaveSpeedAIResponse} but carrying the initial
 * submission payload instead of a task status.
 */
interface WaveSpeedAISubmitTaskResponse {
	code: number;
	message: string;
	// The submission payload (task id + polling URL).
	data: WaveSpeedAISubmitResponse;
}
|
|
65
|
+
|
|
66
|
+
abstract class WavespeedAITask extends TaskProviderHelper {
|
|
67
|
+
constructor(url?: string) {
|
|
68
|
+
super("wavespeed", url || WAVESPEEDAI_API_BASE_URL);
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
makeRoute(params: UrlParams): string {
|
|
72
|
+
return `/api/v3/${params.model}`;
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
preparePayload(params: BodyParams<ImageToImageArgs | TextToImageArgs | TextToVideoArgs>): Record<string, unknown> {
|
|
76
|
+
const payload: Record<string, unknown> = {
|
|
77
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
78
|
+
...params.args.parameters,
|
|
79
|
+
prompt: params.args.inputs,
|
|
80
|
+
};
|
|
81
|
+
// Add LoRA support if adapter is specified in the mapping
|
|
82
|
+
if (params.mapping?.adapter === "lora") {
|
|
83
|
+
payload.loras = [
|
|
84
|
+
{
|
|
85
|
+
path: params.mapping.hfModelId,
|
|
86
|
+
scale: 1, // Default scale value
|
|
87
|
+
},
|
|
88
|
+
];
|
|
89
|
+
}
|
|
90
|
+
return payload;
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
override async getResponse(
|
|
94
|
+
response: WaveSpeedAISubmitTaskResponse,
|
|
95
|
+
url?: string,
|
|
96
|
+
headers?: Record<string, string>
|
|
97
|
+
): Promise<Blob> {
|
|
98
|
+
if (!headers) {
|
|
99
|
+
throw new InferenceClientInputError("Headers are required for WaveSpeed AI API calls");
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
const resultUrl = response.data.urls.get;
|
|
103
|
+
|
|
104
|
+
// Poll for results until completion
|
|
105
|
+
while (true) {
|
|
106
|
+
const resultResponse = await fetch(resultUrl, { headers });
|
|
107
|
+
|
|
108
|
+
if (!resultResponse.ok) {
|
|
109
|
+
throw new InferenceClientProviderApiError(
|
|
110
|
+
"Failed to fetch response status from WaveSpeed AI API",
|
|
111
|
+
{ url: resultUrl, method: "GET" },
|
|
112
|
+
{
|
|
113
|
+
requestId: resultResponse.headers.get("x-request-id") ?? "",
|
|
114
|
+
status: resultResponse.status,
|
|
115
|
+
body: await resultResponse.text(),
|
|
116
|
+
}
|
|
117
|
+
);
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
const result: WaveSpeedAIResponse = await resultResponse.json();
|
|
121
|
+
const taskResult = result.data;
|
|
122
|
+
|
|
123
|
+
switch (taskResult.status) {
|
|
124
|
+
case "completed": {
|
|
125
|
+
// Get the media data from the first output URL
|
|
126
|
+
if (!taskResult.outputs?.[0]) {
|
|
127
|
+
throw new InferenceClientProviderOutputError(
|
|
128
|
+
"Received malformed response from WaveSpeed AI API: No output URL in completed response"
|
|
129
|
+
);
|
|
130
|
+
}
|
|
131
|
+
const mediaResponse = await fetch(taskResult.outputs[0]);
|
|
132
|
+
if (!mediaResponse.ok) {
|
|
133
|
+
throw new InferenceClientProviderApiError(
|
|
134
|
+
"Failed to fetch generation output from WaveSpeed AI API",
|
|
135
|
+
{ url: taskResult.outputs[0], method: "GET" },
|
|
136
|
+
{
|
|
137
|
+
requestId: mediaResponse.headers.get("x-request-id") ?? "",
|
|
138
|
+
status: mediaResponse.status,
|
|
139
|
+
body: await mediaResponse.text(),
|
|
140
|
+
}
|
|
141
|
+
);
|
|
142
|
+
}
|
|
143
|
+
return await mediaResponse.blob();
|
|
144
|
+
}
|
|
145
|
+
case "failed": {
|
|
146
|
+
throw new InferenceClientProviderOutputError(taskResult.error || "Task failed");
|
|
147
|
+
}
|
|
148
|
+
|
|
149
|
+
default: {
|
|
150
|
+
// Wait before polling again
|
|
151
|
+
await delay(500);
|
|
152
|
+
continue;
|
|
153
|
+
}
|
|
154
|
+
}
|
|
155
|
+
}
|
|
156
|
+
}
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
export class WavespeedAITextToImageTask extends WavespeedAITask implements TextToImageTaskHelper {
|
|
160
|
+
constructor() {
|
|
161
|
+
super(WAVESPEEDAI_API_BASE_URL);
|
|
162
|
+
}
|
|
163
|
+
}
|
|
164
|
+
|
|
165
|
+
export class WavespeedAITextToVideoTask extends WavespeedAITask implements TextToVideoTaskHelper {
|
|
166
|
+
constructor() {
|
|
167
|
+
super(WAVESPEEDAI_API_BASE_URL);
|
|
168
|
+
}
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
export class WavespeedAIImageToImageTask extends WavespeedAITask implements ImageToImageTaskHelper {
|
|
172
|
+
constructor() {
|
|
173
|
+
super(WAVESPEEDAI_API_BASE_URL);
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> {
|
|
177
|
+
return {
|
|
178
|
+
...args,
|
|
179
|
+
inputs: args.parameters?.prompt,
|
|
180
|
+
image: base64FromBytes(
|
|
181
|
+
new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer())
|
|
182
|
+
),
|
|
183
|
+
};
|
|
184
|
+
}
|
|
185
|
+
}
|
|
@@ -3,6 +3,8 @@ import { resolveProvider } from "../../lib/getInferenceProviderMapping.js";
|
|
|
3
3
|
import { getProviderHelper } from "../../lib/getProviderHelper.js";
|
|
4
4
|
import type { BaseArgs, Options } from "../../types.js";
|
|
5
5
|
import { innerRequest } from "../../utils/request.js";
|
|
6
|
+
import type { ConversationalTaskHelper, TaskProviderHelper } from "../../providers/providerHelper.js";
|
|
7
|
+
import { AutoRouterConversationalTask } from "../../providers/providerHelper.js";
|
|
6
8
|
|
|
7
9
|
/**
|
|
8
10
|
* Use the chat completion endpoint to generate a response to a prompt, using OpenAI message completion API no stream
|
|
@@ -11,8 +13,14 @@ export async function chatCompletion(
|
|
|
11
13
|
args: BaseArgs & ChatCompletionInput,
|
|
12
14
|
options?: Options
|
|
13
15
|
): Promise<ChatCompletionOutput> {
|
|
14
|
-
|
|
15
|
-
|
|
16
|
+
let providerHelper: ConversationalTaskHelper & TaskProviderHelper;
|
|
17
|
+
if (!args.provider || args.provider === "auto") {
|
|
18
|
+
// Special case: we have a dedicated auto-router for conversational models. No need to fetch provider mapping.
|
|
19
|
+
providerHelper = new AutoRouterConversationalTask();
|
|
20
|
+
} else {
|
|
21
|
+
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
|
|
22
|
+
providerHelper = getProviderHelper(provider, "conversational");
|
|
23
|
+
}
|
|
16
24
|
const { data: response } = await innerRequest<ChatCompletionOutput>(args, providerHelper, {
|
|
17
25
|
...options,
|
|
18
26
|
task: "conversational",
|
package/src/types.ts
CHANGED
|
@@ -66,6 +66,7 @@ export const INFERENCE_PROVIDERS = [
|
|
|
66
66
|
"sambanova",
|
|
67
67
|
"scaleway",
|
|
68
68
|
"together",
|
|
69
|
+
"wavespeed",
|
|
69
70
|
"zai-org",
|
|
70
71
|
] as const;
|
|
71
72
|
|
|
@@ -75,6 +76,37 @@ export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
|
|
|
75
76
|
|
|
76
77
|
export type InferenceProviderOrPolicy = (typeof PROVIDERS_OR_POLICIES)[number];
|
|
77
78
|
|
|
79
|
+
/**
 * The org namespace on the HF Hub, i.e. hf.co/…
 *
 * Whenever possible, InferenceProvider should == org namespace.
 * Entries that differ (e.g. cohere → "CohereLabs", "fal-ai" → "fal") map a
 * provider slug to the distinct organization name it uses on the Hub.
 */
export const PROVIDERS_HUB_ORGS: Record<InferenceProvider, string> = {
	baseten: "baseten",
	"black-forest-labs": "black-forest-labs",
	cerebras: "cerebras",
	clarifai: "clarifai",
	cohere: "CohereLabs",
	"fal-ai": "fal",
	"featherless-ai": "featherless-ai",
	"fireworks-ai": "fireworks-ai",
	groq: "groq",
	"hf-inference": "hf-inference",
	hyperbolic: "Hyperbolic",
	nebius: "nebius",
	novita: "novita",
	nscale: "nscale",
	openai: "openai",
	ovhcloud: "ovhcloud",
	publicai: "publicai",
	replicate: "replicate",
	sambanova: "sambanovasystems",
	scaleway: "scaleway",
	together: "togethercomputer",
	wavespeed: "wavespeed",
	"zai-org": "zai-org",
};
|
|
109
|
+
|
|
78
110
|
export interface InferenceProviderMappingEntry {
|
|
79
111
|
adapter?: string;
|
|
80
112
|
adapterWeightsPath?: string;
|