@huggingface/inference 4.11.2 → 4.12.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. package/README.md +2 -0
  2. package/dist/commonjs/errors.d.ts +3 -0
  3. package/dist/commonjs/errors.d.ts.map +1 -1
  4. package/dist/commonjs/errors.js +8 -1
  5. package/dist/commonjs/lib/getInferenceProviderMapping.d.ts.map +1 -1
  6. package/dist/commonjs/lib/getInferenceProviderMapping.js +11 -0
  7. package/dist/commonjs/lib/getProviderHelper.d.ts.map +1 -1
  8. package/dist/commonjs/lib/getProviderHelper.js +6 -0
  9. package/dist/commonjs/package.d.ts +1 -1
  10. package/dist/commonjs/package.js +1 -1
  11. package/dist/commonjs/providers/consts.d.ts.map +1 -1
  12. package/dist/commonjs/providers/consts.js +1 -0
  13. package/dist/commonjs/providers/providerHelper.d.ts +5 -1
  14. package/dist/commonjs/providers/providerHelper.d.ts.map +1 -1
  15. package/dist/commonjs/providers/providerHelper.js +13 -1
  16. package/dist/commonjs/providers/wavespeed.d.ts +41 -0
  17. package/dist/commonjs/providers/wavespeed.d.ts.map +1 -0
  18. package/dist/commonjs/providers/wavespeed.js +103 -0
  19. package/dist/commonjs/tasks/nlp/chatCompletion.d.ts.map +1 -1
  20. package/dist/commonjs/tasks/nlp/chatCompletion.js +10 -2
  21. package/dist/commonjs/types.d.ts +8 -2
  22. package/dist/commonjs/types.d.ts.map +1 -1
  23. package/dist/commonjs/types.js +32 -1
  24. package/dist/esm/errors.d.ts +3 -0
  25. package/dist/esm/errors.d.ts.map +1 -1
  26. package/dist/esm/errors.js +6 -0
  27. package/dist/esm/lib/getInferenceProviderMapping.d.ts.map +1 -1
  28. package/dist/esm/lib/getInferenceProviderMapping.js +11 -0
  29. package/dist/esm/lib/getProviderHelper.d.ts.map +1 -1
  30. package/dist/esm/lib/getProviderHelper.js +6 -0
  31. package/dist/esm/package.d.ts +1 -1
  32. package/dist/esm/package.js +1 -1
  33. package/dist/esm/providers/consts.d.ts.map +1 -1
  34. package/dist/esm/providers/consts.js +1 -0
  35. package/dist/esm/providers/providerHelper.d.ts +5 -1
  36. package/dist/esm/providers/providerHelper.d.ts.map +1 -1
  37. package/dist/esm/providers/providerHelper.js +12 -1
  38. package/dist/esm/providers/wavespeed.d.ts +41 -0
  39. package/dist/esm/providers/wavespeed.d.ts.map +1 -0
  40. package/dist/esm/providers/wavespeed.js +97 -0
  41. package/dist/esm/tasks/nlp/chatCompletion.d.ts.map +1 -1
  42. package/dist/esm/tasks/nlp/chatCompletion.js +10 -2
  43. package/dist/esm/types.d.ts +8 -2
  44. package/dist/esm/types.d.ts.map +1 -1
  45. package/dist/esm/types.js +31 -0
  46. package/package.json +2 -2
  47. package/src/errors.ts +7 -0
  48. package/src/lib/getInferenceProviderMapping.ts +11 -0
  49. package/src/lib/getProviderHelper.ts +6 -0
  50. package/src/package.ts +1 -1
  51. package/src/providers/consts.ts +1 -0
  52. package/src/providers/providerHelper.ts +15 -2
  53. package/src/providers/wavespeed.ts +185 -0
  54. package/src/tasks/nlp/chatCompletion.ts +10 -2
  55. package/src/types.ts +32 -0
@@ -0,0 +1,185 @@
1
+ import type { TextToImageArgs } from "../tasks/cv/textToImage.js";
2
+ import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js";
3
+ import type { TextToVideoArgs } from "../tasks/cv/textToVideo.js";
4
+ import type { BodyParams, RequestArgs, UrlParams } from "../types.js";
5
+ import { delay } from "../utils/delay.js";
6
+ import { omit } from "../utils/omit.js";
7
+ import { base64FromBytes } from "../utils/base64FromBytes.js";
8
+ import type { TextToImageTaskHelper, TextToVideoTaskHelper, ImageToImageTaskHelper } from "./providerHelper.js";
9
+ import { TaskProviderHelper } from "./providerHelper.js";
10
+ import {
11
+ InferenceClientInputError,
12
+ InferenceClientProviderApiError,
13
+ InferenceClientProviderOutputError,
14
+ } from "../errors.js";
15
+
16
+ const WAVESPEEDAI_API_BASE_URL = "https://api.wavespeed.ai";
17
+
18
+ /**
19
+ * Response structure for task status and results
20
+ */
21
+ interface WaveSpeedAITaskResponse {
22
+ id: string;
23
+ model: string;
24
+ outputs: string[];
25
+ urls: {
26
+ get: string;
27
+ };
28
+ has_nsfw_contents: boolean[];
29
+ status: "created" | "processing" | "completed" | "failed";
30
+ created_at: string;
31
+ error: string;
32
+ executionTime: number;
33
+ timings: {
34
+ inference: number;
35
+ };
36
+ }
37
+
38
+ /**
39
+ * Response structure for initial task submission
40
+ */
41
+ interface WaveSpeedAISubmitResponse {
42
+ id: string;
43
+ urls: {
44
+ get: string;
45
+ };
46
+ }
47
+
48
+ /**
49
+ * Response structure for WaveSpeed AI API
50
+ */
51
+ interface WaveSpeedAIResponse {
52
+ code: number;
53
+ message: string;
54
+ data: WaveSpeedAITaskResponse;
55
+ }
56
+
57
+ /**
58
+ * Response structure for WaveSpeed AI API with submit response data
59
+ */
60
+ interface WaveSpeedAISubmitTaskResponse {
61
+ code: number;
62
+ message: string;
63
+ data: WaveSpeedAISubmitResponse;
64
+ }
65
+
66
+ abstract class WavespeedAITask extends TaskProviderHelper {
67
+ constructor(url?: string) {
68
+ super("wavespeed", url || WAVESPEEDAI_API_BASE_URL);
69
+ }
70
+
71
+ makeRoute(params: UrlParams): string {
72
+ return `/api/v3/${params.model}`;
73
+ }
74
+
75
+ preparePayload(params: BodyParams<ImageToImageArgs | TextToImageArgs | TextToVideoArgs>): Record<string, unknown> {
76
+ const payload: Record<string, unknown> = {
77
+ ...omit(params.args, ["inputs", "parameters"]),
78
+ ...params.args.parameters,
79
+ prompt: params.args.inputs,
80
+ };
81
+ // Add LoRA support if adapter is specified in the mapping
82
+ if (params.mapping?.adapter === "lora") {
83
+ payload.loras = [
84
+ {
85
+ path: params.mapping.hfModelId,
86
+ scale: 1, // Default scale value
87
+ },
88
+ ];
89
+ }
90
+ return payload;
91
+ }
92
+
93
+ override async getResponse(
94
+ response: WaveSpeedAISubmitTaskResponse,
95
+ url?: string,
96
+ headers?: Record<string, string>
97
+ ): Promise<Blob> {
98
+ if (!headers) {
99
+ throw new InferenceClientInputError("Headers are required for WaveSpeed AI API calls");
100
+ }
101
+
102
+ const resultUrl = response.data.urls.get;
103
+
104
+ // Poll for results until completion
105
+ while (true) {
106
+ const resultResponse = await fetch(resultUrl, { headers });
107
+
108
+ if (!resultResponse.ok) {
109
+ throw new InferenceClientProviderApiError(
110
+ "Failed to fetch response status from WaveSpeed AI API",
111
+ { url: resultUrl, method: "GET" },
112
+ {
113
+ requestId: resultResponse.headers.get("x-request-id") ?? "",
114
+ status: resultResponse.status,
115
+ body: await resultResponse.text(),
116
+ }
117
+ );
118
+ }
119
+
120
+ const result: WaveSpeedAIResponse = await resultResponse.json();
121
+ const taskResult = result.data;
122
+
123
+ switch (taskResult.status) {
124
+ case "completed": {
125
+ // Get the media data from the first output URL
126
+ if (!taskResult.outputs?.[0]) {
127
+ throw new InferenceClientProviderOutputError(
128
+ "Received malformed response from WaveSpeed AI API: No output URL in completed response"
129
+ );
130
+ }
131
+ const mediaResponse = await fetch(taskResult.outputs[0]);
132
+ if (!mediaResponse.ok) {
133
+ throw new InferenceClientProviderApiError(
134
+ "Failed to fetch generation output from WaveSpeed AI API",
135
+ { url: taskResult.outputs[0], method: "GET" },
136
+ {
137
+ requestId: mediaResponse.headers.get("x-request-id") ?? "",
138
+ status: mediaResponse.status,
139
+ body: await mediaResponse.text(),
140
+ }
141
+ );
142
+ }
143
+ return await mediaResponse.blob();
144
+ }
145
+ case "failed": {
146
+ throw new InferenceClientProviderOutputError(taskResult.error || "Task failed");
147
+ }
148
+
149
+ default: {
150
+ // Wait before polling again
151
+ await delay(500);
152
+ continue;
153
+ }
154
+ }
155
+ }
156
+ }
157
+ }
158
+
159
+ export class WavespeedAITextToImageTask extends WavespeedAITask implements TextToImageTaskHelper {
160
+ constructor() {
161
+ super(WAVESPEEDAI_API_BASE_URL);
162
+ }
163
+ }
164
+
165
+ export class WavespeedAITextToVideoTask extends WavespeedAITask implements TextToVideoTaskHelper {
166
+ constructor() {
167
+ super(WAVESPEEDAI_API_BASE_URL);
168
+ }
169
+ }
170
+
171
+ export class WavespeedAIImageToImageTask extends WavespeedAITask implements ImageToImageTaskHelper {
172
+ constructor() {
173
+ super(WAVESPEEDAI_API_BASE_URL);
174
+ }
175
+
176
+ async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> {
177
+ return {
178
+ ...args,
179
+ inputs: args.parameters?.prompt,
180
+ image: base64FromBytes(
181
+ new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer())
182
+ ),
183
+ };
184
+ }
185
+ }
@@ -3,6 +3,8 @@ import { resolveProvider } from "../../lib/getInferenceProviderMapping.js";
3
3
  import { getProviderHelper } from "../../lib/getProviderHelper.js";
4
4
  import type { BaseArgs, Options } from "../../types.js";
5
5
  import { innerRequest } from "../../utils/request.js";
6
+ import type { ConversationalTaskHelper, TaskProviderHelper } from "../../providers/providerHelper.js";
7
+ import { AutoRouterConversationalTask } from "../../providers/providerHelper.js";
6
8
 
7
9
  /**
8
10
  * Use the chat completion endpoint to generate a response to a prompt, using OpenAI message completion API no stream
@@ -11,8 +13,14 @@ export async function chatCompletion(
11
13
  args: BaseArgs & ChatCompletionInput,
12
14
  options?: Options
13
15
  ): Promise<ChatCompletionOutput> {
14
- const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
15
- const providerHelper = getProviderHelper(provider, "conversational");
16
+ let providerHelper: ConversationalTaskHelper & TaskProviderHelper;
17
+ if (!args.provider || args.provider === "auto") {
18
+ // Special case: we have a dedicated auto-router for conversational models. No need to fetch provider mapping.
19
+ providerHelper = new AutoRouterConversationalTask();
20
+ } else {
21
+ const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
22
+ providerHelper = getProviderHelper(provider, "conversational");
23
+ }
16
24
  const { data: response } = await innerRequest<ChatCompletionOutput>(args, providerHelper, {
17
25
  ...options,
18
26
  task: "conversational",
package/src/types.ts CHANGED
@@ -66,6 +66,7 @@ export const INFERENCE_PROVIDERS = [
66
66
  "sambanova",
67
67
  "scaleway",
68
68
  "together",
69
+ "wavespeed",
69
70
  "zai-org",
70
71
  ] as const;
71
72
 
@@ -75,6 +76,37 @@ export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
75
76
 
76
77
  export type InferenceProviderOrPolicy = (typeof PROVIDERS_OR_POLICIES)[number];
77
78
 
79
+ /**
80
+ * The org namespace on the HF Hub i.e. hf.co/…
81
+ *
82
+ * Whenever possible, InferenceProvider should == org namespace
83
+ */
84
+ export const PROVIDERS_HUB_ORGS: Record<InferenceProvider, string> = {
85
+ baseten: "baseten",
86
+ "black-forest-labs": "black-forest-labs",
87
+ cerebras: "cerebras",
88
+ clarifai: "clarifai",
89
+ cohere: "CohereLabs",
90
+ "fal-ai": "fal",
91
+ "featherless-ai": "featherless-ai",
92
+ "fireworks-ai": "fireworks-ai",
93
+ groq: "groq",
94
+ "hf-inference": "hf-inference",
95
+ hyperbolic: "Hyperbolic",
96
+ nebius: "nebius",
97
+ novita: "novita",
98
+ nscale: "nscale",
99
+ openai: "openai",
100
+ ovhcloud: "ovhcloud",
101
+ publicai: "publicai",
102
+ replicate: "replicate",
103
+ sambanova: "sambanovasystems",
104
+ scaleway: "scaleway",
105
+ together: "togethercomputer",
106
+ wavespeed: "wavespeed",
107
+ "zai-org": "zai-org",
108
+ };
109
+
78
110
  export interface InferenceProviderMappingEntry {
79
111
  adapter?: string;
80
112
  adapterWeightsPath?: string;