@huggingface/inference 3.4.0 → 3.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,487 @@
1
+ import { openAIbaseUrl, type SnippetInferenceProvider } from "@huggingface/tasks";
2
+ import type { PipelineType, WidgetType } from "@huggingface/tasks/src/pipelines.js";
3
+ import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingface/tasks/src/tasks/index.js";
4
+ import {
5
+ type InferenceSnippet,
6
+ type ModelDataMinimal,
7
+ getModelInputSnippet,
8
+ stringifyGenerationConfig,
9
+ stringifyMessages,
10
+ } from "@huggingface/tasks";
11
+
12
// Maps a widget/pipeline task name to the corresponding Python `huggingface_hub`
// InferenceClient method name (snake_case). `snippetBasic` only emits a
// `huggingface_hub` snippet for tasks present in this map; tasks missing here
// fall back to a raw `requests` snippet only.
const HFH_INFERENCE_CLIENT_METHODS: Partial<Record<WidgetType, string>> = {
	"audio-classification": "audio_classification",
	"audio-to-audio": "audio_to_audio",
	"automatic-speech-recognition": "automatic_speech_recognition",
	"text-to-speech": "text_to_speech",
	"image-classification": "image_classification",
	"image-segmentation": "image_segmentation",
	"image-to-image": "image_to_image",
	"image-to-text": "image_to_text",
	"object-detection": "object_detection",
	"text-to-image": "text_to_image",
	"text-to-video": "text_to_video",
	"zero-shot-image-classification": "zero_shot_image_classification",
	"document-question-answering": "document_question_answering",
	"visual-question-answering": "visual_question_answering",
	"feature-extraction": "feature_extraction",
	"fill-mask": "fill_mask",
	"question-answering": "question_answering",
	"sentence-similarity": "sentence_similarity",
	summarization: "summarization",
	"table-question-answering": "table_question_answering",
	"text-classification": "text_classification",
	"text-generation": "text_generation",
	"token-classification": "token_classification",
	translation: "translation",
	"zero-shot-classification": "zero_shot_classification",
	"tabular-classification": "tabular_classification",
	"tabular-regression": "tabular_regression",
};
41
+
42
+ const snippetImportInferenceClient = (accessToken: string, provider: SnippetInferenceProvider): string =>
43
+ `\
44
+ from huggingface_hub import InferenceClient
45
+
46
+ client = InferenceClient(
47
+ provider="${provider}",
48
+ api_key="${accessToken || "{API_TOKEN}"}"
49
+ )`;
50
+
51
+ export const snippetConversational = (
52
+ model: ModelDataMinimal,
53
+ accessToken: string,
54
+ provider: SnippetInferenceProvider,
55
+ providerModelId?: string,
56
+ opts?: {
57
+ streaming?: boolean;
58
+ messages?: ChatCompletionInputMessage[];
59
+ temperature?: GenerationParameters["temperature"];
60
+ max_tokens?: GenerationParameters["max_tokens"];
61
+ top_p?: GenerationParameters["top_p"];
62
+ }
63
+ ): InferenceSnippet[] => {
64
+ const streaming = opts?.streaming ?? true;
65
+ const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
66
+ const messages = opts?.messages ?? exampleMessages;
67
+ const messagesStr = stringifyMessages(messages, { attributeKeyQuotes: true });
68
+
69
+ const config = {
70
+ ...(opts?.temperature ? { temperature: opts.temperature } : undefined),
71
+ max_tokens: opts?.max_tokens ?? 500,
72
+ ...(opts?.top_p ? { top_p: opts.top_p } : undefined),
73
+ };
74
+ const configStr = stringifyGenerationConfig(config, {
75
+ indent: "\n\t",
76
+ attributeValueConnector: "=",
77
+ });
78
+
79
+ if (streaming) {
80
+ return [
81
+ {
82
+ client: "huggingface_hub",
83
+ content: `\
84
+ ${snippetImportInferenceClient(accessToken, provider)}
85
+
86
+ messages = ${messagesStr}
87
+
88
+ stream = client.chat.completions.create(
89
+ model="${model.id}",
90
+ messages=messages,
91
+ ${configStr}
92
+ stream=True
93
+ )
94
+
95
+ for chunk in stream:
96
+ print(chunk.choices[0].delta.content, end="")`,
97
+ },
98
+ {
99
+ client: "openai",
100
+ content: `\
101
+ from openai import OpenAI
102
+
103
+ client = OpenAI(
104
+ base_url="${openAIbaseUrl(provider)}",
105
+ api_key="${accessToken || "{API_TOKEN}"}"
106
+ )
107
+
108
+ messages = ${messagesStr}
109
+
110
+ stream = client.chat.completions.create(
111
+ model="${providerModelId ?? model.id}",
112
+ messages=messages,
113
+ ${configStr}
114
+ stream=True
115
+ )
116
+
117
+ for chunk in stream:
118
+ print(chunk.choices[0].delta.content, end="")`,
119
+ },
120
+ ];
121
+ } else {
122
+ return [
123
+ {
124
+ client: "huggingface_hub",
125
+ content: `\
126
+ ${snippetImportInferenceClient(accessToken, provider)}
127
+
128
+ messages = ${messagesStr}
129
+
130
+ completion = client.chat.completions.create(
131
+ model="${model.id}",
132
+ messages=messages,
133
+ ${configStr}
134
+ )
135
+
136
+ print(completion.choices[0].message)`,
137
+ },
138
+ {
139
+ client: "openai",
140
+ content: `\
141
+ from openai import OpenAI
142
+
143
+ client = OpenAI(
144
+ base_url="${openAIbaseUrl(provider)}",
145
+ api_key="${accessToken || "{API_TOKEN}"}"
146
+ )
147
+
148
+ messages = ${messagesStr}
149
+
150
+ completion = client.chat.completions.create(
151
+ model="${providerModelId ?? model.id}",
152
+ messages=messages,
153
+ ${configStr}
154
+ )
155
+
156
+ print(completion.choices[0].message)`,
157
+ },
158
+ ];
159
+ }
160
+ };
161
+
162
/**
 * Python `requests` snippet for zero-shot text classification.
 * The generated code references `API_URL` and `headers`, which are prepended
 * later by `getPythonInferenceSnippet` (for snippets with client === "requests").
 */
export const snippetZeroShotClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
	return [
		{
			client: "requests",
			content: `\
def query(payload):
	response = requests.post(API_URL, headers=headers, json=payload)
	return response.json()

output = query({
	"inputs": ${getModelInputSnippet(model)},
	"parameters": {"candidate_labels": ["refund", "legal", "faq"]},
})`,
		},
	];
};
178
+
179
/**
 * Python `requests` snippet for zero-shot image classification: the image is
 * read from disk and sent base64-encoded alongside candidate labels.
 * `API_URL` / `headers` are prepended later by `getPythonInferenceSnippet`.
 */
export const snippetZeroShotImageClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
	return [
		{
			client: "requests",
			content: `\
def query(data):
	with open(data["image_path"], "rb") as f:
		img = f.read()
	payload={
		"parameters": data["parameters"],
		"inputs": base64.b64encode(img).decode("utf-8")
	}
	response = requests.post(API_URL, headers=headers, json=payload)
	return response.json()

output = query({
	"image_path": ${getModelInputSnippet(model)},
	"parameters": {"candidate_labels": ["cat", "dog", "llama"]},
})`,
		},
	];
};
201
+
202
/**
 * Default snippet for simple JSON-in / JSON-out tasks.
 * Emits a `huggingface_hub` snippet only when the model's pipeline_tag has a
 * matching InferenceClient method in HFH_INFERENCE_CLIENT_METHODS, plus a raw
 * `requests` snippet (whose API_URL/headers preamble is added later by
 * `getPythonInferenceSnippet`).
 */
export const snippetBasic = (
	model: ModelDataMinimal,
	accessToken: string,
	provider: SnippetInferenceProvider
): InferenceSnippet[] => {
	return [
		...(model.pipeline_tag && model.pipeline_tag in HFH_INFERENCE_CLIENT_METHODS
			? [
					{
						client: "huggingface_hub",
						content: `\
${snippetImportInferenceClient(accessToken, provider)}

result = client.${HFH_INFERENCE_CLIENT_METHODS[model.pipeline_tag]}(
	model="${model.id}",
	inputs=${getModelInputSnippet(model)},
	provider="${provider}",
)

print(result)
`,
					},
			  ]
			: []),
		{
			client: "requests",
			content: `\
def query(payload):
	response = requests.post(API_URL, headers=headers, json=payload)
	return response.json()

output = query({
	"inputs": ${getModelInputSnippet(model)},
})`,
		},
	];
};
239
+
240
/**
 * Python `requests` snippet for binary-input tasks (audio/image files): the
 * file is read from disk and POSTed as the raw request body.
 * `API_URL` / `headers` are prepended later by `getPythonInferenceSnippet`.
 */
export const snippetFile = (model: ModelDataMinimal): InferenceSnippet[] => {
	return [
		{
			client: "requests",
			content: `\
def query(filename):
	with open(filename, "rb") as f:
		data = f.read()
	response = requests.post(API_URL, headers=headers, data=data)
	return response.json()

output = query(${getModelInputSnippet(model)})`,
		},
	];
};
255
+
256
/**
 * Python snippets for text-to-image models.
 * Always emits a `huggingface_hub` snippet; additionally emits a provider-specific
 * snippet: `fal-client` when provider === "fal-ai", raw `requests` when
 * provider === "hf-inference".
 */
export const snippetTextToImage = (
	model: ModelDataMinimal,
	accessToken: string,
	provider: SnippetInferenceProvider,
	providerModelId?: string
): InferenceSnippet[] => {
	return [
		{
			client: "huggingface_hub",
			content: `\
${snippetImportInferenceClient(accessToken, provider)}

# output is a PIL.Image object
image = client.text_to_image(
	${getModelInputSnippet(model)},
	model="${model.id}"
)`,
		},
		...(provider === "fal-ai"
			? [
					{
						client: "fal-client",
						content: `\
import fal_client

result = fal_client.subscribe(
	"${providerModelId ?? model.id}",
	arguments={
		"prompt": ${getModelInputSnippet(model)},
	},
)
print(result)
`,
					},
			  ]
			: []),
		...(provider === "hf-inference"
			? [
					{
						client: "requests",
						content: `\
def query(payload):
	response = requests.post(API_URL, headers=headers, json=payload)
	return response.content

image_bytes = query({
	"inputs": ${getModelInputSnippet(model)},
})

# You can access the image with PIL.Image for example
import io
from PIL import Image
image = Image.open(io.BytesIO(image_bytes))`,
					},
			  ]
			: []),
	];
};
314
+
315
/**
 * Python snippet for text-to-video models.
 * Only emitted for providers that support the task here ("fal-ai", "replicate");
 * returns an empty list for any other provider.
 */
export const snippetTextToVideo = (
	model: ModelDataMinimal,
	accessToken: string,
	provider: SnippetInferenceProvider
): InferenceSnippet[] => {
	return ["fal-ai", "replicate"].includes(provider)
		? [
				{
					client: "huggingface_hub",
					content: `\
${snippetImportInferenceClient(accessToken, provider)}

video = client.text_to_video(
	${getModelInputSnippet(model)},
	model="${model.id}"
)`,
				},
		  ]
		: [];
};
335
+
336
/**
 * Python `requests` snippet for tabular classification/regression: the rows are
 * wrapped under "inputs.data". `API_URL` / `headers` are prepended later by
 * `getPythonInferenceSnippet`.
 * NOTE(review): the generated query returns `response.content` (raw bytes)
 * rather than `response.json()` — presumably intentional, but verify against
 * the tabular endpoints' response format.
 */
export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet[] => {
	return [
		{
			client: "requests",
			content: `\
def query(payload):
	response = requests.post(API_URL, headers=headers, json=payload)
	return response.content

response = query({
	"inputs": {"data": ${getModelInputSnippet(model)}},
})`,
		},
	];
};
351
+
352
/**
 * Python `requests` snippets for text-to-speech / text-to-audio models.
 * The response shape depends on the backing library: transformers models return
 * raw wav bytes, while api-inference-community models return (audio, sampling_rate).
 * `API_URL` / `headers` are prepended later by `getPythonInferenceSnippet`.
 */
export const snippetTextToAudio = (model: ModelDataMinimal): InferenceSnippet[] => {
	// Transformers TTS pipeline and api-inference-community (AIC) pipeline outputs are diverged
	// with the latest update to inference-api (IA).
	// Transformers IA returns a byte object (wav file), whereas AIC returns wav and sampling_rate.
	if (model.library_name === "transformers") {
		return [
			{
				client: "requests",
				content: `\
def query(payload):
	response = requests.post(API_URL, headers=headers, json=payload)
	return response.content

audio_bytes = query({
	"inputs": ${getModelInputSnippet(model)},
})
# You can access the audio with IPython.display for example
from IPython.display import Audio
Audio(audio_bytes)`,
			},
		];
	} else {
		return [
			{
				client: "requests",
				content: `\
def query(payload):
	response = requests.post(API_URL, headers=headers, json=payload)
	return response.json()

audio, sampling_rate = query({
	"inputs": ${getModelInputSnippet(model)},
})
# You can access the audio with IPython.display for example
from IPython.display import Audio
Audio(audio, rate=sampling_rate)`,
			},
		];
	}
};
392
+
393
/**
 * Python `requests` snippet for document question answering: the referenced
 * image file is read from disk and replaced in-place by its base64 encoding
 * before the payload is POSTed.
 * `API_URL` / `headers` are prepended later by `getPythonInferenceSnippet`.
 */
export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): InferenceSnippet[] => {
	return [
		{
			client: "requests",
			content: `\
def query(payload):
	with open(payload["image"], "rb") as f:
		img = f.read()
	payload["image"] = base64.b64encode(img).decode("utf-8")
	response = requests.post(API_URL, headers=headers, json=payload)
	return response.json()

output = query({
	"inputs": ${getModelInputSnippet(model)},
})`,
		},
	];
};
411
+
412
/**
 * Snippet generator per pipeline task. Generators are invoked with
 * (model, accessToken, provider, providerModelId); the trailing `opts`
 * parameter is only used by the conversational path in
 * `getPythonInferenceSnippet`. Tasks absent from this map produce no snippet.
 */
export const pythonSnippets: Partial<
	Record<
		PipelineType,
		(
			model: ModelDataMinimal,
			accessToken: string,
			provider: SnippetInferenceProvider,
			providerModelId?: string,
			opts?: Record<string, unknown>
		) => InferenceSnippet[]
	>
> = {
	// Same order as in tasks/src/pipelines.ts
	"text-classification": snippetBasic,
	"token-classification": snippetBasic,
	"table-question-answering": snippetBasic,
	"question-answering": snippetBasic,
	"zero-shot-classification": snippetZeroShotClassification,
	translation: snippetBasic,
	summarization: snippetBasic,
	"feature-extraction": snippetBasic,
	"text-generation": snippetBasic,
	"text2text-generation": snippetBasic,
	"image-text-to-text": snippetConversational,
	"fill-mask": snippetBasic,
	"sentence-similarity": snippetBasic,
	"automatic-speech-recognition": snippetFile,
	"text-to-image": snippetTextToImage,
	"text-to-video": snippetTextToVideo,
	"text-to-speech": snippetTextToAudio,
	"text-to-audio": snippetTextToAudio,
	"audio-to-audio": snippetFile,
	"audio-classification": snippetFile,
	"image-classification": snippetFile,
	"tabular-regression": snippetTabular,
	"tabular-classification": snippetTabular,
	"object-detection": snippetFile,
	"image-segmentation": snippetFile,
	"document-question-answering": snippetDocumentQuestionAnswering,
	"image-to-text": snippetFile,
	"zero-shot-image-classification": snippetZeroShotImageClassification,
};
454
+
455
/**
 * Entry point: build the Python inference snippets for a model.
 * Conversational models (tag "conversational") always use the Messages API
 * snippets; all other models are dispatched through `pythonSnippets` by
 * pipeline_tag. For snippets with client === "requests", the common
 * `import requests` / API_URL / headers preamble is prepended here.
 *
 * @param opts forwarded only to the conversational snippet generator
 */
export function getPythonInferenceSnippet(
	model: ModelDataMinimal,
	accessToken: string,
	provider: SnippetInferenceProvider,
	providerModelId?: string,
	opts?: Record<string, unknown>
): InferenceSnippet[] {
	if (model.tags.includes("conversational")) {
		// Conversational model detected, so we display a code snippet that features the Messages API
		return snippetConversational(model, accessToken, provider, providerModelId, opts);
	} else {
		const snippets =
			model.pipeline_tag && model.pipeline_tag in pythonSnippets
				? pythonSnippets[model.pipeline_tag]?.(model, accessToken, provider, providerModelId) ?? []
				: [];

		return snippets.map((snippet) => {
			return {
				...snippet,
				content:
					snippet.client === "requests"
						? `\
import requests

API_URL = "${openAIbaseUrl(provider)}"
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}

${snippet.content}`
						: snippet.content,
			};
		});
	}
}
package/src/types.ts CHANGED
@@ -1,4 +1,4 @@
1
- import type { ChatCompletionInput, FeatureExtractionInput, PipelineType } from "@huggingface/tasks";
1
+ import type { ChatCompletionInput, PipelineType } from "@huggingface/tasks";
2
2
 
3
3
  /**
4
4
  * HF model id, like "meta-llama/Llama-3.3-70B-Instruct"
@@ -30,6 +30,7 @@ export type InferenceTask = Exclude<PipelineType, "other">;
30
30
 
31
31
  export const INFERENCE_PROVIDERS = [
32
32
  "black-forest-labs",
33
+ "cerebras",
33
34
  "cohere",
34
35
  "fal-ai",
35
36
  "fireworks-ai",
@@ -37,6 +38,7 @@ export const INFERENCE_PROVIDERS = [
37
38
  "hyperbolic",
38
39
  "nebius",
39
40
  "novita",
41
+ "openai",
40
42
  "replicate",
41
43
  "sambanova",
42
44
  "together",
@@ -87,7 +89,6 @@ export type RequestArgs = BaseArgs &
87
89
  | { text: string }
88
90
  | { audio_url: string }
89
91
  | ChatCompletionInput
90
- | FeatureExtractionInput
91
92
  ) & {
92
93
  parameters?: Record<string, unknown>;
93
94
  };
@@ -97,6 +98,7 @@ export interface ProviderConfig {
97
98
  makeBody: (params: BodyParams) => Record<string, unknown>;
98
99
  makeHeaders: (params: HeaderParams) => Record<string, string>;
99
100
  makeUrl: (params: UrlParams) => string;
101
+ clientSideRoutingOnly?: boolean;
100
102
  }
101
103
 
102
104
  export interface HeaderParams {