@huggingface/tasks 0.14.0 → 0.15.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. package/dist/commonjs/index.d.ts +1 -0
  2. package/dist/commonjs/index.d.ts.map +1 -1
  3. package/dist/commonjs/index.js +1 -0
  4. package/dist/commonjs/inference-providers.d.ts +10 -0
  5. package/dist/commonjs/inference-providers.d.ts.map +1 -0
  6. package/dist/commonjs/inference-providers.js +16 -0
  7. package/dist/commonjs/snippets/curl.d.ts +8 -8
  8. package/dist/commonjs/snippets/curl.d.ts.map +1 -1
  9. package/dist/commonjs/snippets/curl.js +58 -30
  10. package/dist/commonjs/snippets/js.d.ts +11 -10
  11. package/dist/commonjs/snippets/js.d.ts.map +1 -1
  12. package/dist/commonjs/snippets/js.js +162 -53
  13. package/dist/commonjs/snippets/python.d.ts +12 -12
  14. package/dist/commonjs/snippets/python.d.ts.map +1 -1
  15. package/dist/commonjs/snippets/python.js +141 -71
  16. package/dist/commonjs/snippets/types.d.ts +1 -1
  17. package/dist/commonjs/snippets/types.d.ts.map +1 -1
  18. package/dist/esm/index.d.ts +1 -0
  19. package/dist/esm/index.d.ts.map +1 -1
  20. package/dist/esm/index.js +1 -0
  21. package/dist/esm/inference-providers.d.ts +10 -0
  22. package/dist/esm/inference-providers.d.ts.map +1 -0
  23. package/dist/esm/inference-providers.js +12 -0
  24. package/dist/esm/snippets/curl.d.ts +8 -8
  25. package/dist/esm/snippets/curl.d.ts.map +1 -1
  26. package/dist/esm/snippets/curl.js +58 -29
  27. package/dist/esm/snippets/js.d.ts +11 -10
  28. package/dist/esm/snippets/js.d.ts.map +1 -1
  29. package/dist/esm/snippets/js.js +159 -50
  30. package/dist/esm/snippets/python.d.ts +12 -12
  31. package/dist/esm/snippets/python.d.ts.map +1 -1
  32. package/dist/esm/snippets/python.js +140 -69
  33. package/dist/esm/snippets/types.d.ts +1 -1
  34. package/dist/esm/snippets/types.d.ts.map +1 -1
  35. package/package.json +1 -1
  36. package/src/index.ts +2 -0
  37. package/src/inference-providers.ts +16 -0
  38. package/src/snippets/curl.ts +72 -23
  39. package/src/snippets/js.ts +189 -56
  40. package/src/snippets/python.ts +154 -75
  41. package/src/snippets/types.ts +1 -1
@@ -1,11 +1,54 @@
1
+ import { openAIbaseUrl, type InferenceProvider } from "../inference-providers.js";
1
2
  import type { PipelineType } from "../pipelines.js";
2
3
  import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
3
4
  import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
4
5
  import { getModelInputSnippet } from "./inputs.js";
5
6
  import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
6
7
 
7
- export const snippetBasic = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
8
- content: `async function query(data) {
8
+ const HFJS_METHODS: Record<string, string> = {
9
+ "text-classification": "textClassification",
10
+ "token-classification": "tokenClassification",
11
+ "table-question-answering": "tableQuestionAnswering",
12
+ "question-answering": "questionAnswering",
13
+ translation: "translation",
14
+ summarization: "summarization",
15
+ "feature-extraction": "featureExtraction",
16
+ "text-generation": "textGeneration",
17
+ "text2text-generation": "textGeneration",
18
+ "fill-mask": "fillMask",
19
+ "sentence-similarity": "sentenceSimilarity",
20
+ };
21
+
22
+ export const snippetBasic = (
23
+ model: ModelDataMinimal,
24
+ accessToken: string,
25
+ provider: InferenceProvider
26
+ ): InferenceSnippet[] => {
27
+ return [
28
+ ...(model.pipeline_tag && model.pipeline_tag in HFJS_METHODS
29
+ ? [
30
+ {
31
+ client: "huggingface.js",
32
+ content: `\
33
+ import { HfInference } from "@huggingface/inference";
34
+
35
+ const client = new HfInference("${accessToken || `{API_TOKEN}`}");
36
+
37
+ const output = await client.${HFJS_METHODS[model.pipeline_tag]}({
38
+ model: "${model.id}",
39
+ inputs: ${getModelInputSnippet(model)},
40
+ provider: "${provider}",
41
+ });
42
+
43
+ console.log(output)
44
+ `,
45
+ },
46
+ ]
47
+ : []),
48
+ {
49
+ client: "fetch",
50
+ content: `\
51
+ async function query(data) {
9
52
  const response = await fetch(
10
53
  "https://api-inference.huggingface.co/models/${model.id}",
11
54
  {
@@ -24,11 +67,14 @@ export const snippetBasic = (model: ModelDataMinimal, accessToken: string): Infe
24
67
  query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
25
68
  console.log(JSON.stringify(response));
26
69
  });`,
27
- });
70
+ },
71
+ ];
72
+ };
28
73
 
29
74
  export const snippetTextGeneration = (
30
75
  model: ModelDataMinimal,
31
76
  accessToken: string,
77
+ provider: InferenceProvider,
32
78
  opts?: {
33
79
  streaming?: boolean;
34
80
  messages?: ChatCompletionInputMessage[];
@@ -36,7 +82,7 @@ export const snippetTextGeneration = (
36
82
  max_tokens?: GenerationParameters["max_tokens"];
37
83
  top_p?: GenerationParameters["top_p"];
38
84
  }
39
- ): InferenceSnippet | InferenceSnippet[] => {
85
+ ): InferenceSnippet[] => {
40
86
  if (model.tags.includes("conversational")) {
41
87
  // Conversational model detected, so we display a code snippet that features the Messages API
42
88
  const streaming = opts?.streaming ?? true;
@@ -67,6 +113,7 @@ let out = "";
67
113
  const stream = client.chatCompletionStream({
68
114
  model: "${model.id}",
69
115
  messages: ${messagesStr},
116
+ provider: "${provider}",
70
117
  ${configStr}
71
118
  });
72
119
 
@@ -83,8 +130,8 @@ for await (const chunk of stream) {
83
130
  content: `import { OpenAI } from "openai";
84
131
 
85
132
  const client = new OpenAI({
86
- baseURL: "https://api-inference.huggingface.co/v1/",
87
- apiKey: "${accessToken || `{API_TOKEN}`}"
133
+ baseURL: "${openAIbaseUrl(provider)}",
134
+ apiKey: "${accessToken || `{API_TOKEN}`}"
88
135
  });
89
136
 
90
137
  let out = "";
@@ -116,6 +163,7 @@ const client = new HfInference("${accessToken || `{API_TOKEN}`}");
116
163
  const chatCompletion = await client.chatCompletion({
117
164
  model: "${model.id}",
118
165
  messages: ${messagesStr},
166
+ provider: "${provider}",
119
167
  ${configStr}
120
168
  });
121
169
 
@@ -126,8 +174,8 @@ console.log(chatCompletion.choices[0].message);`,
126
174
  content: `import { OpenAI } from "openai";
127
175
 
128
176
  const client = new OpenAI({
129
- baseURL: "https://api-inference.huggingface.co/v1/",
130
- apiKey: "${accessToken || `{API_TOKEN}`}"
177
+ baseURL: "${openAIbaseUrl(provider)}",
178
+ apiKey: "${accessToken || `{API_TOKEN}`}"
131
179
  });
132
180
 
133
181
  const chatCompletion = await client.chat.completions.create({
@@ -141,36 +189,66 @@ console.log(chatCompletion.choices[0].message);`,
141
189
  ];
142
190
  }
143
191
  } else {
144
- return snippetBasic(model, accessToken);
192
+ return snippetBasic(model, accessToken, provider);
145
193
  }
146
194
  };
147
195
 
148
- export const snippetZeroShotClassification = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
149
- content: `async function query(data) {
150
- const response = await fetch(
151
- "https://api-inference.huggingface.co/models/${model.id}",
196
+ export const snippetZeroShotClassification = (model: ModelDataMinimal, accessToken: string): InferenceSnippet[] => {
197
+ return [
152
198
  {
153
- headers: {
154
- Authorization: "Bearer ${accessToken || `{API_TOKEN}`}",
155
- "Content-Type": "application/json",
156
- },
157
- method: "POST",
158
- body: JSON.stringify(data),
199
+ client: "fetch",
200
+ content: `async function query(data) {
201
+ const response = await fetch(
202
+ "https://api-inference.huggingface.co/models/${model.id}",
203
+ {
204
+ headers: {
205
+ Authorization: "Bearer ${accessToken || `{API_TOKEN}`}",
206
+ "Content-Type": "application/json",
207
+ },
208
+ method: "POST",
209
+ body: JSON.stringify(data),
210
+ }
211
+ );
212
+ const result = await response.json();
213
+ return result;
159
214
  }
160
- );
161
- const result = await response.json();
162
- return result;
163
- }
215
+
216
+ query({"inputs": ${getModelInputSnippet(
217
+ model
218
+ )}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}).then((response) => {
219
+ console.log(JSON.stringify(response));
220
+ });`,
221
+ },
222
+ ];
223
+ };
164
224
 
165
- query({"inputs": ${getModelInputSnippet(
166
- model
167
- )}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}).then((response) => {
168
- console.log(JSON.stringify(response));
169
- });`,
170
- });
225
+ export const snippetTextToImage = (
226
+ model: ModelDataMinimal,
227
+ accessToken: string,
228
+ provider: InferenceProvider
229
+ ): InferenceSnippet[] => {
230
+ return [
231
+ {
232
+ client: "huggingface.js",
233
+ content: `\
234
+ import { HfInference } from "@huggingface/inference";
235
+
236
+ const client = new HfInference("${accessToken || `{API_TOKEN}`}");
171
237
 
172
- export const snippetTextToImage = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
173
- content: `async function query(data) {
238
+ const image = await client.textToImage({
239
+ model: "${model.id}",
240
+ inputs: ${getModelInputSnippet(model)},
241
+ parameters: { num_inference_steps: 5 },
242
+ provider: "${provider}",
243
+ });
244
+ /// Use the generated image (it's a Blob)
245
+ `,
246
+ },
247
+ ...(provider === "hf-inference"
248
+ ? [
249
+ {
250
+ client: "fetch",
251
+ content: `async function query(data) {
174
252
  const response = await fetch(
175
253
  "https://api-inference.huggingface.co/models/${model.id}",
176
254
  {
@@ -188,9 +266,20 @@ export const snippetTextToImage = (model: ModelDataMinimal, accessToken: string)
188
266
  query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
189
267
  // Use image
190
268
  });`,
191
- });
269
+ },
270
+ ]
271
+ : []),
272
+ ];
273
+ };
192
274
 
193
- export const snippetTextToAudio = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => {
275
+ export const snippetTextToAudio = (
276
+ model: ModelDataMinimal,
277
+ accessToken: string,
278
+ provider: InferenceProvider
279
+ ): InferenceSnippet[] => {
280
+ if (provider !== "hf-inference") {
281
+ return [];
282
+ }
194
283
  const commonSnippet = `async function query(data) {
195
284
  const response = await fetch(
196
285
  "https://api-inference.huggingface.co/models/${model.id}",
@@ -204,22 +293,27 @@ export const snippetTextToAudio = (model: ModelDataMinimal, accessToken: string)
204
293
  }
205
294
  );`;
206
295
  if (model.library_name === "transformers") {
207
- return {
208
- content:
209
- commonSnippet +
210
- `
296
+ return [
297
+ {
298
+ client: "fetch",
299
+ content:
300
+ commonSnippet +
301
+ `
211
302
  const result = await response.blob();
212
303
  return result;
213
304
  }
214
305
  query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
215
306
  // Returns a byte object of the Audio wavform. Use it directly!
216
307
  });`,
217
- };
308
+ },
309
+ ];
218
310
  } else {
219
- return {
220
- content:
221
- commonSnippet +
222
- `
311
+ return [
312
+ {
313
+ client: "fetch",
314
+ content:
315
+ commonSnippet +
316
+ `
223
317
  const result = await response.json();
224
318
  return result;
225
319
  }
@@ -227,12 +321,51 @@ export const snippetTextToAudio = (model: ModelDataMinimal, accessToken: string)
227
321
  query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
228
322
  console.log(JSON.stringify(response));
229
323
  });`,
230
- };
324
+ },
325
+ ];
231
326
  }
232
327
  };
233
328
 
234
- export const snippetFile = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
235
- content: `async function query(filename) {
329
+ export const snippetAutomaticSpeechRecognition = (
330
+ model: ModelDataMinimal,
331
+ accessToken: string,
332
+ provider: InferenceProvider
333
+ ): InferenceSnippet[] => {
334
+ return [
335
+ {
336
+ client: "huggingface.js",
337
+ content: `\
338
+ import { HfInference } from "@huggingface/inference";
339
+
340
+ const client = new HfInference("${accessToken || `{API_TOKEN}`}");
341
+
342
+ const data = fs.readFileSync(${getModelInputSnippet(model)});
343
+
344
+ const output = await client.automaticSpeechRecognition({
345
+ data,
346
+ model: "${model.id}",
347
+ provider: "${provider}",
348
+ });
349
+
350
+ console.log(output);
351
+ `,
352
+ },
353
+ ...(provider === "hf-inference" ? snippetFile(model, accessToken, provider) : []),
354
+ ];
355
+ };
356
+
357
+ export const snippetFile = (
358
+ model: ModelDataMinimal,
359
+ accessToken: string,
360
+ provider: InferenceProvider
361
+ ): InferenceSnippet[] => {
362
+ if (provider !== "hf-inference") {
363
+ return [];
364
+ }
365
+ return [
366
+ {
367
+ client: "fetch",
368
+ content: `async function query(filename) {
236
369
  const data = fs.readFileSync(filename);
237
370
  const response = await fetch(
238
371
  "https://api-inference.huggingface.co/models/${model.id}",
@@ -252,7 +385,9 @@ export const snippetFile = (model: ModelDataMinimal, accessToken: string): Infer
252
385
  query(${getModelInputSnippet(model)}).then((response) => {
253
386
  console.log(JSON.stringify(response));
254
387
  });`,
255
- });
388
+ },
389
+ ];
390
+ };
256
391
 
257
392
  export const jsSnippets: Partial<
258
393
  Record<
@@ -260,11 +395,12 @@ export const jsSnippets: Partial<
260
395
  (
261
396
  model: ModelDataMinimal,
262
397
  accessToken: string,
398
+ provider: InferenceProvider,
263
399
  opts?: Record<string, unknown>
264
- ) => InferenceSnippet | InferenceSnippet[]
400
+ ) => InferenceSnippet[]
265
401
  >
266
402
  > = {
267
- // Same order as in js/src/lib/interfaces/Types.ts
403
+ // Same order as in tasks/src/pipelines.ts
268
404
  "text-classification": snippetBasic,
269
405
  "token-classification": snippetBasic,
270
406
  "table-question-answering": snippetBasic,
@@ -278,7 +414,7 @@ export const jsSnippets: Partial<
278
414
  "text2text-generation": snippetBasic,
279
415
  "fill-mask": snippetBasic,
280
416
  "sentence-similarity": snippetBasic,
281
- "automatic-speech-recognition": snippetFile,
417
+ "automatic-speech-recognition": snippetAutomaticSpeechRecognition,
282
418
  "text-to-image": snippetTextToImage,
283
419
  "text-to-speech": snippetTextToAudio,
284
420
  "text-to-audio": snippetTextToAudio,
@@ -293,13 +429,10 @@ export const jsSnippets: Partial<
293
429
  export function getJsInferenceSnippet(
294
430
  model: ModelDataMinimal,
295
431
  accessToken: string,
432
+ provider: InferenceProvider,
296
433
  opts?: Record<string, unknown>
297
- ): InferenceSnippet | InferenceSnippet[] {
434
+ ): InferenceSnippet[] {
298
435
  return model.pipeline_tag && model.pipeline_tag in jsSnippets
299
- ? jsSnippets[model.pipeline_tag]?.(model, accessToken, opts) ?? { content: "" }
300
- : { content: "" };
301
- }
302
-
303
- export function hasJsInferenceSnippet(model: ModelDataMinimal): boolean {
304
- return !!model.pipeline_tag && model.pipeline_tag in jsSnippets;
436
+ ? jsSnippets[model.pipeline_tag]?.(model, accessToken, provider, opts) ?? []
437
+ : [];
305
438
  }