@huggingface/tasks 0.14.0 → 0.15.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. package/dist/commonjs/index.d.ts +1 -0
  2. package/dist/commonjs/index.d.ts.map +1 -1
  3. package/dist/commonjs/index.js +1 -0
  4. package/dist/commonjs/inference-providers.d.ts +10 -0
  5. package/dist/commonjs/inference-providers.d.ts.map +1 -0
  6. package/dist/commonjs/inference-providers.js +16 -0
  7. package/dist/commonjs/snippets/curl.d.ts +8 -8
  8. package/dist/commonjs/snippets/curl.d.ts.map +1 -1
  9. package/dist/commonjs/snippets/curl.js +58 -30
  10. package/dist/commonjs/snippets/js.d.ts +11 -10
  11. package/dist/commonjs/snippets/js.d.ts.map +1 -1
  12. package/dist/commonjs/snippets/js.js +162 -53
  13. package/dist/commonjs/snippets/python.d.ts +12 -12
  14. package/dist/commonjs/snippets/python.d.ts.map +1 -1
  15. package/dist/commonjs/snippets/python.js +141 -71
  16. package/dist/commonjs/snippets/types.d.ts +1 -1
  17. package/dist/commonjs/snippets/types.d.ts.map +1 -1
  18. package/dist/esm/index.d.ts +1 -0
  19. package/dist/esm/index.d.ts.map +1 -1
  20. package/dist/esm/index.js +1 -0
  21. package/dist/esm/inference-providers.d.ts +10 -0
  22. package/dist/esm/inference-providers.d.ts.map +1 -0
  23. package/dist/esm/inference-providers.js +12 -0
  24. package/dist/esm/snippets/curl.d.ts +8 -8
  25. package/dist/esm/snippets/curl.d.ts.map +1 -1
  26. package/dist/esm/snippets/curl.js +58 -29
  27. package/dist/esm/snippets/js.d.ts +11 -10
  28. package/dist/esm/snippets/js.d.ts.map +1 -1
  29. package/dist/esm/snippets/js.js +159 -50
  30. package/dist/esm/snippets/python.d.ts +12 -12
  31. package/dist/esm/snippets/python.d.ts.map +1 -1
  32. package/dist/esm/snippets/python.js +140 -69
  33. package/dist/esm/snippets/types.d.ts +1 -1
  34. package/dist/esm/snippets/types.d.ts.map +1 -1
  35. package/package.json +1 -1
  36. package/src/index.ts +2 -0
  37. package/src/inference-providers.ts +16 -0
  38. package/src/snippets/curl.ts +72 -23
  39. package/src/snippets/js.ts +189 -56
  40. package/src/snippets/python.ts +154 -75
  41. package/src/snippets/types.ts +1 -1
@@ -1,9 +1,14 @@
1
+ import { HF_HUB_INFERENCE_PROXY_TEMPLATE, openAIbaseUrl } from "../inference-providers.js";
1
2
  import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
2
3
  import { getModelInputSnippet } from "./inputs.js";
3
- const snippetImportInferenceClient = (model, accessToken) => `from huggingface_hub import InferenceClient
4
- client = InferenceClient("${model.id}", token="${accessToken || "{API_TOKEN}"}")
5
- `;
6
- export const snippetConversational = (model, accessToken, opts) => {
4
+ const snippetImportInferenceClient = (accessToken, provider) => `\
5
+ from huggingface_hub import InferenceClient
6
+
7
+ client = InferenceClient(
8
+ provider="${provider}",
9
+ api_key="${accessToken || "{API_TOKEN}"}"
10
+ )`;
11
+ export const snippetConversational = (model, accessToken, provider, opts) => {
7
12
  const streaming = opts?.streaming ?? true;
8
13
  const exampleMessages = getModelInputSnippet(model);
9
14
  const messages = opts?.messages ?? exampleMessages;
@@ -21,9 +26,8 @@ export const snippetConversational = (model, accessToken, opts) => {
21
26
  return [
22
27
  {
23
28
  client: "huggingface_hub",
24
- content: `from huggingface_hub import InferenceClient
25
-
26
- client = InferenceClient(api_key="${accessToken || "{API_TOKEN}"}")
29
+ content: `\
30
+ ${snippetImportInferenceClient(accessToken, provider)}
27
31
 
28
32
  messages = ${messagesStr}
29
33
 
@@ -42,7 +46,7 @@ for chunk in stream:
42
46
  content: `from openai import OpenAI
43
47
 
44
48
  client = OpenAI(
45
- base_url="https://api-inference.huggingface.co/v1/",
49
+ base_url="${openAIbaseUrl(provider)}",
46
50
  api_key="${accessToken || "{API_TOKEN}"}"
47
51
  )
48
52
 
@@ -64,9 +68,8 @@ for chunk in stream:
64
68
  return [
65
69
  {
66
70
  client: "huggingface_hub",
67
- content: `from huggingface_hub import InferenceClient
68
-
69
- client = InferenceClient(api_key="${accessToken || "{API_TOKEN}"}")
71
+ content: `\
72
+ ${snippetImportInferenceClient(accessToken, provider)}
70
73
 
71
74
  messages = ${messagesStr}
72
75
 
@@ -83,7 +86,7 @@ print(completion.choices[0].message)`,
83
86
  content: `from openai import OpenAI
84
87
 
85
88
  client = OpenAI(
86
- base_url="https://api-inference.huggingface.co/v1/",
89
+ base_url="${openAIbaseUrl(provider)}",
87
90
  api_key="${accessToken || "{API_TOKEN}"}"
88
91
  )
89
92
 
@@ -100,8 +103,11 @@ print(completion.choices[0].message)`,
100
103
  ];
101
104
  }
102
105
  };
103
- export const snippetZeroShotClassification = (model) => ({
104
- content: `def query(payload):
106
+ export const snippetZeroShotClassification = (model) => {
107
+ return [
108
+ {
109
+ client: "requests",
110
+ content: `def query(payload):
105
111
  response = requests.post(API_URL, headers=headers, json=payload)
106
112
  return response.json()
107
113
 
@@ -109,9 +115,14 @@ output = query({
109
115
  "inputs": ${getModelInputSnippet(model)},
110
116
  "parameters": {"candidate_labels": ["refund", "legal", "faq"]},
111
117
  })`,
112
- });
113
- export const snippetZeroShotImageClassification = (model) => ({
114
- content: `def query(data):
118
+ },
119
+ ];
120
+ };
121
+ export const snippetZeroShotImageClassification = (model) => {
122
+ return [
123
+ {
124
+ client: "requests",
125
+ content: `def query(data):
115
126
  with open(data["image_path"], "rb") as f:
116
127
  img = f.read()
117
128
  payload={
@@ -125,37 +136,78 @@ output = query({
125
136
  "image_path": ${getModelInputSnippet(model)},
126
137
  "parameters": {"candidate_labels": ["cat", "dog", "llama"]},
127
138
  })`,
128
- });
129
- export const snippetBasic = (model) => ({
130
- content: `def query(payload):
139
+ },
140
+ ];
141
+ };
142
+ export const snippetBasic = (model) => {
143
+ return [
144
+ {
145
+ client: "requests",
146
+ content: `def query(payload):
131
147
  response = requests.post(API_URL, headers=headers, json=payload)
132
148
  return response.json()
133
149
 
134
150
  output = query({
135
151
  "inputs": ${getModelInputSnippet(model)},
136
152
  })`,
137
- });
138
- export const snippetFile = (model) => ({
139
- content: `def query(filename):
140
- with open(filename, "rb") as f:
141
- data = f.read()
142
- response = requests.post(API_URL, headers=headers, data=data)
143
- return response.json()
153
+ },
154
+ ];
155
+ };
156
+ export const snippetFile = (model) => {
157
+ return [
158
+ {
159
+ client: "requests",
160
+ content: `def query(filename):
161
+ with open(filename, "rb") as f:
162
+ data = f.read()
163
+ response = requests.post(API_URL, headers=headers, data=data)
164
+ return response.json()
165
+
166
+ output = query(${getModelInputSnippet(model)})`,
167
+ },
168
+ ];
169
+ };
170
+ export const snippetTextToImage = (model, accessToken, provider) => {
171
+ return [
172
+ {
173
+ client: "huggingface_hub",
174
+ content: `\
175
+ ${snippetImportInferenceClient(accessToken, provider)}
144
176
 
145
- output = query(${getModelInputSnippet(model)})`,
146
- });
147
- export const snippetTextToImage = (model, accessToken) => [
148
- {
149
- client: "huggingface_hub",
150
- content: `${snippetImportInferenceClient(model, accessToken)}
151
177
  # output is a PIL.Image object
152
- image = client.text_to_image(${getModelInputSnippet(model)})`,
153
- },
154
- {
155
- client: "requests",
156
- content: `def query(payload):
178
+ image = client.text_to_image(
179
+ ${getModelInputSnippet(model)},
180
+ model="${model.id}"
181
+ )`,
182
+ },
183
+ ...(provider === "fal-ai"
184
+ ? [
185
+ {
186
+ client: "fal-client",
187
+ content: `\
188
+ import fal_client
189
+
190
+ result = fal_client.subscribe(
191
+ # replace with correct id from fal.ai
192
+ "fal-ai/${model.id}",
193
+ arguments={
194
+ "prompt": ${getModelInputSnippet(model)},
195
+ },
196
+ )
197
+ print(result)
198
+ `,
199
+ },
200
+ ]
201
+ : []),
202
+ ...(provider === "hf-inference"
203
+ ? [
204
+ {
205
+ client: "requests",
206
+ content: `\
207
+ def query(payload):
157
208
  response = requests.post(API_URL, headers=headers, json=payload)
158
209
  return response.content
210
+
159
211
  image_bytes = query({
160
212
  "inputs": ${getModelInputSnippet(model)},
161
213
  })
@@ -164,23 +216,33 @@ image_bytes = query({
164
216
  import io
165
217
  from PIL import Image
166
218
  image = Image.open(io.BytesIO(image_bytes))`,
167
- },
168
- ];
169
- export const snippetTabular = (model) => ({
170
- content: `def query(payload):
171
- response = requests.post(API_URL, headers=headers, json=payload)
172
- return response.content
173
- response = query({
174
- "inputs": {"data": ${getModelInputSnippet(model)}},
175
- })`,
176
- });
219
+ },
220
+ ]
221
+ : []),
222
+ ];
223
+ };
224
+ export const snippetTabular = (model) => {
225
+ return [
226
+ {
227
+ client: "requests",
228
+ content: `def query(payload):
229
+ response = requests.post(API_URL, headers=headers, json=payload)
230
+ return response.content
231
+ response = query({
232
+ "inputs": {"data": ${getModelInputSnippet(model)}},
233
+ })`,
234
+ },
235
+ ];
236
+ };
177
237
  export const snippetTextToAudio = (model) => {
178
238
  // Transformers TTS pipeline and api-inference-community (AIC) pipeline outputs are diverged
179
239
  // with the latest update to inference-api (IA).
180
240
  // Transformers IA returns a byte object (wav file), whereas AIC returns wav and sampling_rate.
181
241
  if (model.library_name === "transformers") {
182
- return {
183
- content: `def query(payload):
242
+ return [
243
+ {
244
+ client: "requests",
245
+ content: `def query(payload):
184
246
  response = requests.post(API_URL, headers=headers, json=payload)
185
247
  return response.content
186
248
 
@@ -190,11 +252,14 @@ audio_bytes = query({
190
252
  # You can access the audio with IPython.display for example
191
253
  from IPython.display import Audio
192
254
  Audio(audio_bytes)`,
193
- };
255
+ },
256
+ ];
194
257
  }
195
258
  else {
196
- return {
197
- content: `def query(payload):
259
+ return [
260
+ {
261
+ client: "requests",
262
+ content: `def query(payload):
198
263
  response = requests.post(API_URL, headers=headers, json=payload)
199
264
  return response.json()
200
265
 
@@ -204,11 +269,15 @@ audio, sampling_rate = query({
204
269
  # You can access the audio with IPython.display for example
205
270
  from IPython.display import Audio
206
271
  Audio(audio, rate=sampling_rate)`,
207
- };
272
+ },
273
+ ];
208
274
  }
209
275
  };
210
- export const snippetDocumentQuestionAnswering = (model) => ({
211
- content: `def query(payload):
276
+ export const snippetDocumentQuestionAnswering = (model) => {
277
+ return [
278
+ {
279
+ client: "requests",
280
+ content: `def query(payload):
212
281
  with open(payload["image"], "rb") as f:
213
282
  img = f.read()
214
283
  payload["image"] = base64.b64encode(img).decode("utf-8")
@@ -218,7 +287,9 @@ export const snippetDocumentQuestionAnswering = (model) => ({
218
287
  output = query({
219
288
  "inputs": ${getModelInputSnippet(model)},
220
289
  })`,
221
- });
290
+ },
291
+ ];
292
+ };
222
293
  export const pythonSnippets = {
223
294
  // Same order as in tasks/src/pipelines.ts
224
295
  "text-classification": snippetBasic,
@@ -249,23 +320,26 @@ export const pythonSnippets = {
249
320
  "image-to-text": snippetFile,
250
321
  "zero-shot-image-classification": snippetZeroShotImageClassification,
251
322
  };
252
- export function getPythonInferenceSnippet(model, accessToken, opts) {
323
+ export function getPythonInferenceSnippet(model, accessToken, provider, opts) {
253
324
  if (model.tags.includes("conversational")) {
254
325
  // Conversational model detected, so we display a code snippet that features the Messages API
255
- return snippetConversational(model, accessToken, opts);
326
+ return snippetConversational(model, accessToken, provider, opts);
256
327
  }
257
328
  else {
258
- let snippets = model.pipeline_tag && model.pipeline_tag in pythonSnippets
259
- ? pythonSnippets[model.pipeline_tag]?.(model, accessToken) ?? { content: "" }
260
- : { content: "" };
261
- snippets = Array.isArray(snippets) ? snippets : [snippets];
329
+ const snippets = model.pipeline_tag && model.pipeline_tag in pythonSnippets
330
+ ? pythonSnippets[model.pipeline_tag]?.(model, accessToken, provider) ?? []
331
+ : [];
332
+ const baseUrl = provider === "hf-inference"
333
+ ? `https://api-inference.huggingface.co/models/${model.id}`
334
+ : HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider);
262
335
  return snippets.map((snippet) => {
263
336
  return {
264
337
  ...snippet,
265
- content: snippet.content.includes("requests")
266
- ? `import requests
338
+ content: snippet.client === "requests"
339
+ ? `\
340
+ import requests
267
341
 
268
- API_URL = "https://api-inference.huggingface.co/models/${model.id}"
342
+ API_URL = "${baseUrl}"
269
343
  headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}
270
344
 
271
345
  ${snippet.content}`
@@ -274,6 +348,3 @@ ${snippet.content}`
274
348
  });
275
349
  }
276
350
  }
277
- export function hasPythonInferenceSnippet(model) {
278
- return !!model.pipeline_tag && model.pipeline_tag in pythonSnippets;
279
- }
@@ -7,6 +7,6 @@ import type { ModelData } from "../model-data.js";
7
7
  export type ModelDataMinimal = Pick<ModelData, "id" | "pipeline_tag" | "mask_token" | "library_name" | "config" | "tags" | "inference">;
8
8
  export interface InferenceSnippet {
9
9
  content: string;
10
- client?: string;
10
+ client: string;
11
11
  }
12
12
  //# sourceMappingURL=types.d.ts.map
@@ -1 +1 @@
1
- {"version":3,"file":"types.d.ts","sourceRoot":"","sources":["../../../src/snippets/types.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,SAAS,EAAE,MAAM,kBAAkB,CAAC;AAElD;;;;GAIG;AACH,MAAM,MAAM,gBAAgB,GAAG,IAAI,CAClC,SAAS,EACT,IAAI,GAAG,cAAc,GAAG,YAAY,GAAG,cAAc,GAAG,QAAQ,GAAG,MAAM,GAAG,WAAW,CACvF,CAAC;AAEF,MAAM,WAAW,gBAAgB;IAChC,OAAO,EAAE,MAAM,CAAC;IAChB,MAAM,CAAC,EAAE,MAAM,CAAC;CAChB"}
1
+ {"version":3,"file":"types.d.ts","sourceRoot":"","sources":["../../../src/snippets/types.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,SAAS,EAAE,MAAM,kBAAkB,CAAC;AAElD;;;;GAIG;AACH,MAAM,MAAM,gBAAgB,GAAG,IAAI,CAClC,SAAS,EACT,IAAI,GAAG,cAAc,GAAG,YAAY,GAAG,cAAc,GAAG,QAAQ,GAAG,MAAM,GAAG,WAAW,CACvF,CAAC;AAEF,MAAM,WAAW,gBAAgB;IAChC,OAAO,EAAE,MAAM,CAAC;IAChB,MAAM,EAAE,MAAM,CAAC;CACf"}
package/package.json CHANGED
@@ -1,7 +1,7 @@
1
1
  {
2
2
  "name": "@huggingface/tasks",
3
3
  "packageManager": "pnpm@8.10.5",
4
- "version": "0.14.0",
4
+ "version": "0.15.0",
5
5
  "description": "List of ML tasks for huggingface.co/tasks",
6
6
  "repository": "https://github.com/huggingface/huggingface.js.git",
7
7
  "publishConfig": {
package/src/index.ts CHANGED
@@ -58,3 +58,5 @@ export type { LocalApp, LocalAppKey, LocalAppSnippet } from "./local-apps.js";
58
58
 
59
59
  export { DATASET_LIBRARIES_UI_ELEMENTS } from "./dataset-libraries.js";
60
60
  export type { DatasetLibraryUiElement, DatasetLibraryKey } from "./dataset-libraries.js";
61
+
62
+ export * from "./inference-providers.js";
@@ -0,0 +1,16 @@
1
+ export const INFERENCE_PROVIDERS = ["hf-inference", "fal-ai", "replicate", "sambanova", "together"] as const;
2
+
3
+ export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
4
+
5
+ export const HF_HUB_INFERENCE_PROXY_TEMPLATE = `https://huggingface.co/api/inference-proxy/{{PROVIDER}}`;
6
+
7
+ /**
8
+ * URL to set as baseUrl in the OpenAI SDK.
9
+ *
10
+ * TODO(Expose this from HfInference in the future?)
11
+ */
12
+ export function openAIbaseUrl(provider: InferenceProvider): string {
13
+ return provider === "hf-inference"
14
+ ? "https://api-inference.huggingface.co/v1/"
15
+ : HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider);
16
+ }
@@ -1,20 +1,35 @@
1
+ import { HF_HUB_INFERENCE_PROXY_TEMPLATE, type InferenceProvider } from "../inference-providers.js";
1
2
  import type { PipelineType } from "../pipelines.js";
2
3
  import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
3
4
  import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
4
5
  import { getModelInputSnippet } from "./inputs.js";
5
6
  import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
6
7
 
7
- export const snippetBasic = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
8
- content: `curl https://api-inference.huggingface.co/models/${model.id} \\
8
+ export const snippetBasic = (
9
+ model: ModelDataMinimal,
10
+ accessToken: string,
11
+ provider: InferenceProvider
12
+ ): InferenceSnippet[] => {
13
+ if (provider !== "hf-inference") {
14
+ return [];
15
+ }
16
+ return [
17
+ {
18
+ client: "curl",
19
+ content: `\
20
+ curl https://api-inference.huggingface.co/models/${model.id} \\
9
21
  -X POST \\
10
22
  -d '{"inputs": ${getModelInputSnippet(model, true)}}' \\
11
23
  -H 'Content-Type: application/json' \\
12
24
  -H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`,
13
- });
25
+ },
26
+ ];
27
+ };
14
28
 
15
29
  export const snippetTextGeneration = (
16
30
  model: ModelDataMinimal,
17
31
  accessToken: string,
32
+ provider: InferenceProvider,
18
33
  opts?: {
19
34
  streaming?: boolean;
20
35
  messages?: ChatCompletionInputMessage[];
@@ -22,8 +37,13 @@ export const snippetTextGeneration = (
22
37
  max_tokens?: GenerationParameters["max_tokens"];
23
38
  top_p?: GenerationParameters["top_p"];
24
39
  }
25
- ): InferenceSnippet => {
40
+ ): InferenceSnippet[] => {
26
41
  if (model.tags.includes("conversational")) {
42
+ const baseUrl =
43
+ provider === "hf-inference"
44
+ ? `https://api-inference.huggingface.co/models/${model.id}/v1/chat/completions`
45
+ : HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider) + "/v1/chat/completions";
46
+
27
47
  // Conversational model detected, so we display a code snippet that features the Messages API
28
48
  const streaming = opts?.streaming ?? true;
29
49
  const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
@@ -34,8 +54,10 @@ export const snippetTextGeneration = (
34
54
  max_tokens: opts?.max_tokens ?? 500,
35
55
  ...(opts?.top_p ? { top_p: opts.top_p } : undefined),
36
56
  };
37
- return {
38
- content: `curl 'https://api-inference.huggingface.co/models/${model.id}/v1/chat/completions' \\
57
+ return [
58
+ {
59
+ client: "curl",
60
+ content: `curl '${baseUrl}' \\
39
61
  -H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}' \\
40
62
  -H 'Content-Type: application/json' \\
41
63
  --data '{
@@ -52,34 +74,64 @@ export const snippetTextGeneration = (
52
74
  })},
53
75
  "stream": ${!!streaming}
54
76
  }'`,
55
- };
77
+ },
78
+ ];
56
79
  } else {
57
- return snippetBasic(model, accessToken);
80
+ return snippetBasic(model, accessToken, provider);
58
81
  }
59
82
  };
60
83
 
61
- export const snippetZeroShotClassification = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
62
- content: `curl https://api-inference.huggingface.co/models/${model.id} \\
84
+ export const snippetZeroShotClassification = (
85
+ model: ModelDataMinimal,
86
+ accessToken: string,
87
+ provider: InferenceProvider
88
+ ): InferenceSnippet[] => {
89
+ if (provider !== "hf-inference") {
90
+ return [];
91
+ }
92
+ return [
93
+ {
94
+ client: "curl",
95
+ content: `curl https://api-inference.huggingface.co/models/${model.id} \\
63
96
  -X POST \\
64
97
  -d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\
65
98
  -H 'Content-Type: application/json' \\
66
99
  -H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`,
67
- });
100
+ },
101
+ ];
102
+ };
68
103
 
69
- export const snippetFile = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
70
- content: `curl https://api-inference.huggingface.co/models/${model.id} \\
104
+ export const snippetFile = (
105
+ model: ModelDataMinimal,
106
+ accessToken: string,
107
+ provider: InferenceProvider
108
+ ): InferenceSnippet[] => {
109
+ if (provider !== "hf-inference") {
110
+ return [];
111
+ }
112
+ return [
113
+ {
114
+ client: "curl",
115
+ content: `curl https://api-inference.huggingface.co/models/${model.id} \\
71
116
  -X POST \\
72
117
  --data-binary '@${getModelInputSnippet(model, true, true)}' \\
73
118
  -H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`,
74
- });
119
+ },
120
+ ];
121
+ };
75
122
 
76
123
  export const curlSnippets: Partial<
77
124
  Record<
78
125
  PipelineType,
79
- (model: ModelDataMinimal, accessToken: string, opts?: Record<string, unknown>) => InferenceSnippet
126
+ (
127
+ model: ModelDataMinimal,
128
+ accessToken: string,
129
+ provider: InferenceProvider,
130
+ opts?: Record<string, unknown>
131
+ ) => InferenceSnippet[]
80
132
  >
81
133
  > = {
82
- // Same order as in js/src/lib/interfaces/Types.ts
134
+ // Same order as in tasks/src/pipelines.ts
83
135
  "text-classification": snippetBasic,
84
136
  "token-classification": snippetBasic,
85
137
  "table-question-answering": snippetBasic,
@@ -108,13 +160,10 @@ export const curlSnippets: Partial<
108
160
  export function getCurlInferenceSnippet(
109
161
  model: ModelDataMinimal,
110
162
  accessToken: string,
163
+ provider: InferenceProvider,
111
164
  opts?: Record<string, unknown>
112
- ): InferenceSnippet {
165
+ ): InferenceSnippet[] {
113
166
  return model.pipeline_tag && model.pipeline_tag in curlSnippets
114
- ? curlSnippets[model.pipeline_tag]?.(model, accessToken, opts) ?? { content: "" }
115
- : { content: "" };
116
- }
117
-
118
- export function hasCurlInferenceSnippet(model: Pick<ModelDataMinimal, "pipeline_tag">): boolean {
119
- return !!model.pipeline_tag && model.pipeline_tag in curlSnippets;
167
+ ? curlSnippets[model.pipeline_tag]?.(model, accessToken, provider, opts) ?? []
168
+ : [];
120
169
  }