@huggingface/tasks 0.14.0 → 0.15.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. package/dist/commonjs/index.d.ts +1 -0
  2. package/dist/commonjs/index.d.ts.map +1 -1
  3. package/dist/commonjs/index.js +1 -0
  4. package/dist/commonjs/inference-providers.d.ts +10 -0
  5. package/dist/commonjs/inference-providers.d.ts.map +1 -0
  6. package/dist/commonjs/inference-providers.js +16 -0
  7. package/dist/commonjs/snippets/curl.d.ts +8 -8
  8. package/dist/commonjs/snippets/curl.d.ts.map +1 -1
  9. package/dist/commonjs/snippets/curl.js +58 -30
  10. package/dist/commonjs/snippets/js.d.ts +11 -10
  11. package/dist/commonjs/snippets/js.d.ts.map +1 -1
  12. package/dist/commonjs/snippets/js.js +162 -53
  13. package/dist/commonjs/snippets/python.d.ts +12 -12
  14. package/dist/commonjs/snippets/python.d.ts.map +1 -1
  15. package/dist/commonjs/snippets/python.js +141 -71
  16. package/dist/commonjs/snippets/types.d.ts +1 -1
  17. package/dist/commonjs/snippets/types.d.ts.map +1 -1
  18. package/dist/esm/index.d.ts +1 -0
  19. package/dist/esm/index.d.ts.map +1 -1
  20. package/dist/esm/index.js +1 -0
  21. package/dist/esm/inference-providers.d.ts +10 -0
  22. package/dist/esm/inference-providers.d.ts.map +1 -0
  23. package/dist/esm/inference-providers.js +12 -0
  24. package/dist/esm/snippets/curl.d.ts +8 -8
  25. package/dist/esm/snippets/curl.d.ts.map +1 -1
  26. package/dist/esm/snippets/curl.js +58 -29
  27. package/dist/esm/snippets/js.d.ts +11 -10
  28. package/dist/esm/snippets/js.d.ts.map +1 -1
  29. package/dist/esm/snippets/js.js +159 -50
  30. package/dist/esm/snippets/python.d.ts +12 -12
  31. package/dist/esm/snippets/python.d.ts.map +1 -1
  32. package/dist/esm/snippets/python.js +140 -69
  33. package/dist/esm/snippets/types.d.ts +1 -1
  34. package/dist/esm/snippets/types.d.ts.map +1 -1
  35. package/package.json +1 -1
  36. package/src/index.ts +2 -0
  37. package/src/inference-providers.ts +16 -0
  38. package/src/snippets/curl.ts +72 -23
  39. package/src/snippets/js.ts +189 -56
  40. package/src/snippets/python.ts +154 -75
  41. package/src/snippets/types.ts +1 -1
@@ -1,17 +1,23 @@
1
+ import { HF_HUB_INFERENCE_PROXY_TEMPLATE, openAIbaseUrl, type InferenceProvider } from "../inference-providers.js";
1
2
  import type { PipelineType } from "../pipelines.js";
2
3
  import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
3
4
  import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
4
5
  import { getModelInputSnippet } from "./inputs.js";
5
6
  import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
6
7
 
7
- const snippetImportInferenceClient = (model: ModelDataMinimal, accessToken: string): string =>
8
- `from huggingface_hub import InferenceClient
9
- client = InferenceClient("${model.id}", token="${accessToken || "{API_TOKEN}"}")
10
- `;
8
+ const snippetImportInferenceClient = (accessToken: string, provider: InferenceProvider): string =>
9
+ `\
10
+ from huggingface_hub import InferenceClient
11
+
12
+ client = InferenceClient(
13
+ provider="${provider}",
14
+ api_key="${accessToken || "{API_TOKEN}"}"
15
+ )`;
11
16
 
12
17
  export const snippetConversational = (
13
18
  model: ModelDataMinimal,
14
19
  accessToken: string,
20
+ provider: InferenceProvider,
15
21
  opts?: {
16
22
  streaming?: boolean;
17
23
  messages?: ChatCompletionInputMessage[];
@@ -39,9 +45,8 @@ export const snippetConversational = (
39
45
  return [
40
46
  {
41
47
  client: "huggingface_hub",
42
- content: `from huggingface_hub import InferenceClient
43
-
44
- client = InferenceClient(api_key="${accessToken || "{API_TOKEN}"}")
48
+ content: `\
49
+ ${snippetImportInferenceClient(accessToken, provider)}
45
50
 
46
51
  messages = ${messagesStr}
47
52
 
@@ -60,7 +65,7 @@ for chunk in stream:
60
65
  content: `from openai import OpenAI
61
66
 
62
67
  client = OpenAI(
63
- base_url="https://api-inference.huggingface.co/v1/",
68
+ base_url="${openAIbaseUrl(provider)}",
64
69
  api_key="${accessToken || "{API_TOKEN}"}"
65
70
  )
66
71
 
@@ -81,9 +86,8 @@ for chunk in stream:
81
86
  return [
82
87
  {
83
88
  client: "huggingface_hub",
84
- content: `from huggingface_hub import InferenceClient
85
-
86
- client = InferenceClient(api_key="${accessToken || "{API_TOKEN}"}")
89
+ content: `\
90
+ ${snippetImportInferenceClient(accessToken, provider)}
87
91
 
88
92
  messages = ${messagesStr}
89
93
 
@@ -100,7 +104,7 @@ print(completion.choices[0].message)`,
100
104
  content: `from openai import OpenAI
101
105
 
102
106
  client = OpenAI(
103
- base_url="https://api-inference.huggingface.co/v1/",
107
+ base_url="${openAIbaseUrl(provider)}",
104
108
  api_key="${accessToken || "{API_TOKEN}"}"
105
109
  )
106
110
 
@@ -118,8 +122,11 @@ print(completion.choices[0].message)`,
118
122
  }
119
123
  };
120
124
 
121
- export const snippetZeroShotClassification = (model: ModelDataMinimal): InferenceSnippet => ({
122
- content: `def query(payload):
125
+ export const snippetZeroShotClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
126
+ return [
127
+ {
128
+ client: "requests",
129
+ content: `def query(payload):
123
130
  response = requests.post(API_URL, headers=headers, json=payload)
124
131
  return response.json()
125
132
 
@@ -127,10 +134,15 @@ output = query({
127
134
  "inputs": ${getModelInputSnippet(model)},
128
135
  "parameters": {"candidate_labels": ["refund", "legal", "faq"]},
129
136
  })`,
130
- });
137
+ },
138
+ ];
139
+ };
131
140
 
132
- export const snippetZeroShotImageClassification = (model: ModelDataMinimal): InferenceSnippet => ({
133
- content: `def query(data):
141
+ export const snippetZeroShotImageClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
142
+ return [
143
+ {
144
+ client: "requests",
145
+ content: `def query(data):
134
146
  with open(data["image_path"], "rb") as f:
135
147
  img = f.read()
136
148
  payload={
@@ -144,40 +156,85 @@ output = query({
144
156
  "image_path": ${getModelInputSnippet(model)},
145
157
  "parameters": {"candidate_labels": ["cat", "dog", "llama"]},
146
158
  })`,
147
- });
159
+ },
160
+ ];
161
+ };
148
162
 
149
- export const snippetBasic = (model: ModelDataMinimal): InferenceSnippet => ({
150
- content: `def query(payload):
163
+ export const snippetBasic = (model: ModelDataMinimal): InferenceSnippet[] => {
164
+ return [
165
+ {
166
+ client: "requests",
167
+ content: `def query(payload):
151
168
  response = requests.post(API_URL, headers=headers, json=payload)
152
169
  return response.json()
153
170
 
154
171
  output = query({
155
172
  "inputs": ${getModelInputSnippet(model)},
156
173
  })`,
157
- });
158
-
159
- export const snippetFile = (model: ModelDataMinimal): InferenceSnippet => ({
160
- content: `def query(filename):
161
- with open(filename, "rb") as f:
162
- data = f.read()
163
- response = requests.post(API_URL, headers=headers, data=data)
164
- return response.json()
165
-
166
- output = query(${getModelInputSnippet(model)})`,
167
- });
168
-
169
- export const snippetTextToImage = (model: ModelDataMinimal, accessToken: string): InferenceSnippet[] => [
170
- {
171
- client: "huggingface_hub",
172
- content: `${snippetImportInferenceClient(model, accessToken)}
174
+ },
175
+ ];
176
+ };
177
+
178
+ export const snippetFile = (model: ModelDataMinimal): InferenceSnippet[] => {
179
+ return [
180
+ {
181
+ client: "requests",
182
+ content: `def query(filename):
183
+ with open(filename, "rb") as f:
184
+ data = f.read()
185
+ response = requests.post(API_URL, headers=headers, data=data)
186
+ return response.json()
187
+
188
+ output = query(${getModelInputSnippet(model)})`,
189
+ },
190
+ ];
191
+ };
192
+
193
+ export const snippetTextToImage = (
194
+ model: ModelDataMinimal,
195
+ accessToken: string,
196
+ provider: InferenceProvider
197
+ ): InferenceSnippet[] => {
198
+ return [
199
+ {
200
+ client: "huggingface_hub",
201
+ content: `\
202
+ ${snippetImportInferenceClient(accessToken, provider)}
203
+
173
204
  # output is a PIL.Image object
174
- image = client.text_to_image(${getModelInputSnippet(model)})`,
205
+ image = client.text_to_image(
206
+ ${getModelInputSnippet(model)},
207
+ model="${model.id}"
208
+ )`,
209
+ },
210
+ ...(provider === "fal-ai"
211
+ ? [
212
+ {
213
+ client: "fal-client",
214
+ content: `\
215
+ import fal_client
216
+
217
+ result = fal_client.subscribe(
218
+ # replace with correct id from fal.ai
219
+ "fal-ai/${model.id}",
220
+ arguments={
221
+ "prompt": ${getModelInputSnippet(model)},
175
222
  },
176
- {
177
- client: "requests",
178
- content: `def query(payload):
223
+ )
224
+ print(result)
225
+ `,
226
+ },
227
+ ]
228
+ : []),
229
+ ...(provider === "hf-inference"
230
+ ? [
231
+ {
232
+ client: "requests",
233
+ content: `\
234
+ def query(payload):
179
235
  response = requests.post(API_URL, headers=headers, json=payload)
180
236
  return response.content
237
+
181
238
  image_bytes = query({
182
239
  "inputs": ${getModelInputSnippet(model)},
183
240
  })
@@ -186,25 +243,35 @@ image_bytes = query({
186
243
  import io
187
244
  from PIL import Image
188
245
  image = Image.open(io.BytesIO(image_bytes))`,
189
- },
190
- ];
246
+ },
247
+ ]
248
+ : []),
249
+ ];
250
+ };
191
251
 
192
- export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet => ({
193
- content: `def query(payload):
194
- response = requests.post(API_URL, headers=headers, json=payload)
195
- return response.content
196
- response = query({
197
- "inputs": {"data": ${getModelInputSnippet(model)}},
198
- })`,
199
- });
252
+ export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet[] => {
253
+ return [
254
+ {
255
+ client: "requests",
256
+ content: `def query(payload):
257
+ response = requests.post(API_URL, headers=headers, json=payload)
258
+ return response.content
259
+ response = query({
260
+ "inputs": {"data": ${getModelInputSnippet(model)}},
261
+ })`,
262
+ },
263
+ ];
264
+ };
200
265
 
201
- export const snippetTextToAudio = (model: ModelDataMinimal): InferenceSnippet => {
266
+ export const snippetTextToAudio = (model: ModelDataMinimal): InferenceSnippet[] => {
202
267
  // Transformers TTS pipeline and api-inference-community (AIC) pipeline outputs are diverged
203
268
  // with the latest update to inference-api (IA).
204
269
  // Transformers IA returns a byte object (wav file), whereas AIC returns wav and sampling_rate.
205
270
  if (model.library_name === "transformers") {
206
- return {
207
- content: `def query(payload):
271
+ return [
272
+ {
273
+ client: "requests",
274
+ content: `def query(payload):
208
275
  response = requests.post(API_URL, headers=headers, json=payload)
209
276
  return response.content
210
277
 
@@ -214,10 +281,13 @@ audio_bytes = query({
214
281
  # You can access the audio with IPython.display for example
215
282
  from IPython.display import Audio
216
283
  Audio(audio_bytes)`,
217
- };
284
+ },
285
+ ];
218
286
  } else {
219
- return {
220
- content: `def query(payload):
287
+ return [
288
+ {
289
+ client: "requests",
290
+ content: `def query(payload):
221
291
  response = requests.post(API_URL, headers=headers, json=payload)
222
292
  return response.json()
223
293
 
@@ -227,12 +297,16 @@ audio, sampling_rate = query({
227
297
  # You can access the audio with IPython.display for example
228
298
  from IPython.display import Audio
229
299
  Audio(audio, rate=sampling_rate)`,
230
- };
300
+ },
301
+ ];
231
302
  }
232
303
  };
233
304
 
234
- export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): InferenceSnippet => ({
235
- content: `def query(payload):
305
+ export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): InferenceSnippet[] => {
306
+ return [
307
+ {
308
+ client: "requests",
309
+ content: `def query(payload):
236
310
  with open(payload["image"], "rb") as f:
237
311
  img = f.read()
238
312
  payload["image"] = base64.b64encode(img).decode("utf-8")
@@ -242,7 +316,9 @@ export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): Infer
242
316
  output = query({
243
317
  "inputs": ${getModelInputSnippet(model)},
244
318
  })`,
245
- });
319
+ },
320
+ ];
321
+ };
246
322
 
247
323
  export const pythonSnippets: Partial<
248
324
  Record<
@@ -250,8 +326,9 @@ export const pythonSnippets: Partial<
250
326
  (
251
327
  model: ModelDataMinimal,
252
328
  accessToken: string,
329
+ provider: InferenceProvider,
253
330
  opts?: Record<string, unknown>
254
- ) => InferenceSnippet | InferenceSnippet[]
331
+ ) => InferenceSnippet[]
255
332
  >
256
333
  > = {
257
334
  // Same order as in tasks/src/pipelines.ts
@@ -287,35 +364,37 @@ export const pythonSnippets: Partial<
287
364
  export function getPythonInferenceSnippet(
288
365
  model: ModelDataMinimal,
289
366
  accessToken: string,
367
+ provider: InferenceProvider,
290
368
  opts?: Record<string, unknown>
291
- ): InferenceSnippet | InferenceSnippet[] {
369
+ ): InferenceSnippet[] {
292
370
  if (model.tags.includes("conversational")) {
293
371
  // Conversational model detected, so we display a code snippet that features the Messages API
294
- return snippetConversational(model, accessToken, opts);
372
+ return snippetConversational(model, accessToken, provider, opts);
295
373
  } else {
296
- let snippets =
374
+ const snippets =
297
375
  model.pipeline_tag && model.pipeline_tag in pythonSnippets
298
- ? pythonSnippets[model.pipeline_tag]?.(model, accessToken) ?? { content: "" }
299
- : { content: "" };
376
+ ? pythonSnippets[model.pipeline_tag]?.(model, accessToken, provider) ?? []
377
+ : [];
300
378
 
301
- snippets = Array.isArray(snippets) ? snippets : [snippets];
379
+ const baseUrl =
380
+ provider === "hf-inference"
381
+ ? `https://api-inference.huggingface.co/models/${model.id}`
382
+ : HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider);
302
383
 
303
384
  return snippets.map((snippet) => {
304
385
  return {
305
386
  ...snippet,
306
- content: snippet.content.includes("requests")
307
- ? `import requests
387
+ content:
388
+ snippet.client === "requests"
389
+ ? `\
390
+ import requests
308
391
 
309
- API_URL = "https://api-inference.huggingface.co/models/${model.id}"
392
+ API_URL = "${baseUrl}"
310
393
  headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}
311
394
 
312
395
  ${snippet.content}`
313
- : snippet.content,
396
+ : snippet.content,
314
397
  };
315
398
  });
316
399
  }
317
400
  }
318
-
319
- export function hasPythonInferenceSnippet(model: ModelDataMinimal): boolean {
320
- return !!model.pipeline_tag && model.pipeline_tag in pythonSnippets;
321
- }
@@ -12,5 +12,5 @@ export type ModelDataMinimal = Pick<
12
12
 
13
13
  export interface InferenceSnippet {
14
14
  content: string;
15
- client?: string; // for instance: `client` could be `huggingface_hub` or `openai` client for Python snippets
15
+ client: string; // for instance: `client` could be `huggingface_hub` or `openai` client for Python snippets
16
16
  }