@huggingface/tasks 0.14.0 → 0.15.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/commonjs/index.d.ts +1 -0
- package/dist/commonjs/index.d.ts.map +1 -1
- package/dist/commonjs/index.js +1 -0
- package/dist/commonjs/inference-providers.d.ts +10 -0
- package/dist/commonjs/inference-providers.d.ts.map +1 -0
- package/dist/commonjs/inference-providers.js +16 -0
- package/dist/commonjs/snippets/curl.d.ts +8 -8
- package/dist/commonjs/snippets/curl.d.ts.map +1 -1
- package/dist/commonjs/snippets/curl.js +58 -30
- package/dist/commonjs/snippets/js.d.ts +11 -10
- package/dist/commonjs/snippets/js.d.ts.map +1 -1
- package/dist/commonjs/snippets/js.js +162 -53
- package/dist/commonjs/snippets/python.d.ts +12 -12
- package/dist/commonjs/snippets/python.d.ts.map +1 -1
- package/dist/commonjs/snippets/python.js +141 -71
- package/dist/commonjs/snippets/types.d.ts +1 -1
- package/dist/commonjs/snippets/types.d.ts.map +1 -1
- package/dist/esm/index.d.ts +1 -0
- package/dist/esm/index.d.ts.map +1 -1
- package/dist/esm/index.js +1 -0
- package/dist/esm/inference-providers.d.ts +10 -0
- package/dist/esm/inference-providers.d.ts.map +1 -0
- package/dist/esm/inference-providers.js +12 -0
- package/dist/esm/snippets/curl.d.ts +8 -8
- package/dist/esm/snippets/curl.d.ts.map +1 -1
- package/dist/esm/snippets/curl.js +58 -29
- package/dist/esm/snippets/js.d.ts +11 -10
- package/dist/esm/snippets/js.d.ts.map +1 -1
- package/dist/esm/snippets/js.js +159 -50
- package/dist/esm/snippets/python.d.ts +12 -12
- package/dist/esm/snippets/python.d.ts.map +1 -1
- package/dist/esm/snippets/python.js +140 -69
- package/dist/esm/snippets/types.d.ts +1 -1
- package/dist/esm/snippets/types.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/index.ts +2 -0
- package/src/inference-providers.ts +16 -0
- package/src/snippets/curl.ts +72 -23
- package/src/snippets/js.ts +189 -56
- package/src/snippets/python.ts +154 -75
- package/src/snippets/types.ts +1 -1
|
@@ -1,9 +1,14 @@
|
|
|
1
|
+
import { HF_HUB_INFERENCE_PROXY_TEMPLATE, openAIbaseUrl } from "../inference-providers.js";
|
|
1
2
|
import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
|
|
2
3
|
import { getModelInputSnippet } from "./inputs.js";
|
|
3
|
-
const snippetImportInferenceClient = (
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
4
|
+
/**
 * Returns the Python preamble shared by every `huggingface_hub` snippet:
 * the `InferenceClient` import followed by a client bound to `provider`.
 *
 * @param accessToken Token written as `api_key`; any falsy value is replaced
 *   by the literal placeholder `{API_TOKEN}`.
 * @param provider Provider id written as the `provider` argument.
 * @returns The Python source, without a trailing newline.
 */
const snippetImportInferenceClient = (accessToken, provider) => {
    const apiKey = accessToken || "{API_TOKEN}";
    const pythonLines = [
        "from huggingface_hub import InferenceClient",
        "",
        "client = InferenceClient(",
        `provider="${provider}",`,
        `api_key="${apiKey}"`,
        ")",
    ];
    return pythonLines.join("\n");
};
|
|
11
|
+
export const snippetConversational = (model, accessToken, provider, opts) => {
|
|
7
12
|
const streaming = opts?.streaming ?? true;
|
|
8
13
|
const exampleMessages = getModelInputSnippet(model);
|
|
9
14
|
const messages = opts?.messages ?? exampleMessages;
|
|
@@ -21,9 +26,8 @@ export const snippetConversational = (model, accessToken, opts) => {
|
|
|
21
26
|
return [
|
|
22
27
|
{
|
|
23
28
|
client: "huggingface_hub",
|
|
24
|
-
content:
|
|
25
|
-
|
|
26
|
-
client = InferenceClient(api_key="${accessToken || "{API_TOKEN}"}")
|
|
29
|
+
content: `\
|
|
30
|
+
${snippetImportInferenceClient(accessToken, provider)}
|
|
27
31
|
|
|
28
32
|
messages = ${messagesStr}
|
|
29
33
|
|
|
@@ -42,7 +46,7 @@ for chunk in stream:
|
|
|
42
46
|
content: `from openai import OpenAI
|
|
43
47
|
|
|
44
48
|
client = OpenAI(
|
|
45
|
-
base_url="https://api-inference.huggingface.co/v1/",
|
|
49
|
+
base_url="${openAIbaseUrl(provider)}",
|
|
46
50
|
api_key="${accessToken || "{API_TOKEN}"}"
|
|
47
51
|
)
|
|
48
52
|
|
|
@@ -64,9 +68,8 @@ for chunk in stream:
|
|
|
64
68
|
return [
|
|
65
69
|
{
|
|
66
70
|
client: "huggingface_hub",
|
|
67
|
-
content:
|
|
68
|
-
|
|
69
|
-
client = InferenceClient(api_key="${accessToken || "{API_TOKEN}"}")
|
|
71
|
+
content: `\
|
|
72
|
+
${snippetImportInferenceClient(accessToken, provider)}
|
|
70
73
|
|
|
71
74
|
messages = ${messagesStr}
|
|
72
75
|
|
|
@@ -83,7 +86,7 @@ print(completion.choices[0].message)`,
|
|
|
83
86
|
content: `from openai import OpenAI
|
|
84
87
|
|
|
85
88
|
client = OpenAI(
|
|
86
|
-
base_url="https://api-inference.huggingface.co/v1/",
|
|
89
|
+
base_url="${openAIbaseUrl(provider)}",
|
|
87
90
|
api_key="${accessToken || "{API_TOKEN}"}"
|
|
88
91
|
)
|
|
89
92
|
|
|
@@ -100,8 +103,11 @@ print(completion.choices[0].message)`,
|
|
|
100
103
|
];
|
|
101
104
|
}
|
|
102
105
|
};
|
|
103
|
-
export const snippetZeroShotClassification = (model) =>
|
|
104
|
-
|
|
106
|
+
export const snippetZeroShotClassification = (model) => {
|
|
107
|
+
return [
|
|
108
|
+
{
|
|
109
|
+
client: "requests",
|
|
110
|
+
content: `def query(payload):
|
|
105
111
|
response = requests.post(API_URL, headers=headers, json=payload)
|
|
106
112
|
return response.json()
|
|
107
113
|
|
|
@@ -109,9 +115,14 @@ output = query({
|
|
|
109
115
|
"inputs": ${getModelInputSnippet(model)},
|
|
110
116
|
"parameters": {"candidate_labels": ["refund", "legal", "faq"]},
|
|
111
117
|
})`,
|
|
112
|
-
}
|
|
113
|
-
|
|
114
|
-
|
|
118
|
+
},
|
|
119
|
+
];
|
|
120
|
+
};
|
|
121
|
+
export const snippetZeroShotImageClassification = (model) => {
|
|
122
|
+
return [
|
|
123
|
+
{
|
|
124
|
+
client: "requests",
|
|
125
|
+
content: `def query(data):
|
|
115
126
|
with open(data["image_path"], "rb") as f:
|
|
116
127
|
img = f.read()
|
|
117
128
|
payload={
|
|
@@ -125,37 +136,78 @@ output = query({
|
|
|
125
136
|
"image_path": ${getModelInputSnippet(model)},
|
|
126
137
|
"parameters": {"candidate_labels": ["cat", "dog", "llama"]},
|
|
127
138
|
})`,
|
|
128
|
-
}
|
|
129
|
-
|
|
130
|
-
|
|
139
|
+
},
|
|
140
|
+
];
|
|
141
|
+
};
|
|
142
|
+
/**
 * Python `requests` snippet for tasks whose payload is a plain
 * `{"inputs": ...}` JSON body.
 *
 * The emitted code references `API_URL` and `headers` without defining them;
 * `getPythonInferenceSnippet` in this file prepends those definitions to every
 * snippet whose client is "requests".
 *
 * @param model Model whose example input is rendered by `getModelInputSnippet`.
 * @returns A single-element list holding the `requests` snippet.
 */
export const snippetBasic = (model) => {
    const exampleInput = getModelInputSnippet(model);
    return [
        {
            client: "requests",
            // The text below is emitted verbatim as Python, so the body of
            // `query` must carry real indentation to be a valid function body.
            content: `def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

output = query({
    "inputs": ${exampleInput},
})`,
        },
    ];
};
|
|
156
|
+
/**
 * Python `requests` snippet for tasks whose input is a local file that is
 * posted as the raw request body.
 *
 * The emitted code references `API_URL` and `headers` without defining them;
 * `getPythonInferenceSnippet` in this file prepends those definitions to every
 * snippet whose client is "requests".
 *
 * @param model Model whose example input (used as the `query` argument) is
 *   rendered by `getModelInputSnippet`.
 * @returns A single-element list holding the `requests` snippet.
 */
export const snippetFile = (model) => {
    const exampleInput = getModelInputSnippet(model);
    return [
        {
            client: "requests",
            // Emitted verbatim as Python: both the `def` body and the nested
            // `with` body need their own indentation level to parse.
            content: `def query(filename):
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.post(API_URL, headers=headers, data=data)
    return response.json()

output = query(${exampleInput})`,
        },
    ];
};
|
|
170
|
+
export const snippetTextToImage = (model, accessToken, provider) => {
|
|
171
|
+
return [
|
|
172
|
+
{
|
|
173
|
+
client: "huggingface_hub",
|
|
174
|
+
content: `\
|
|
175
|
+
${snippetImportInferenceClient(accessToken, provider)}
|
|
144
176
|
|
|
145
|
-
output = query(${getModelInputSnippet(model)})`,
|
|
146
|
-
});
|
|
147
|
-
export const snippetTextToImage = (model, accessToken) => [
|
|
148
|
-
{
|
|
149
|
-
client: "huggingface_hub",
|
|
150
|
-
content: `${snippetImportInferenceClient(model, accessToken)}
|
|
151
177
|
# output is a PIL.Image object
|
|
152
|
-
image = client.text_to_image(
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
178
|
+
image = client.text_to_image(
|
|
179
|
+
${getModelInputSnippet(model)},
|
|
180
|
+
model="${model.id}"
|
|
181
|
+
)`,
|
|
182
|
+
},
|
|
183
|
+
...(provider === "fal-ai"
|
|
184
|
+
? [
|
|
185
|
+
{
|
|
186
|
+
client: "fal-client",
|
|
187
|
+
content: `\
|
|
188
|
+
import fal_client
|
|
189
|
+
|
|
190
|
+
result = fal_client.subscribe(
|
|
191
|
+
# replace with correct id from fal.ai
|
|
192
|
+
"fal-ai/${model.id}",
|
|
193
|
+
arguments={
|
|
194
|
+
"prompt": ${getModelInputSnippet(model)},
|
|
195
|
+
},
|
|
196
|
+
)
|
|
197
|
+
print(result)
|
|
198
|
+
`,
|
|
199
|
+
},
|
|
200
|
+
]
|
|
201
|
+
: []),
|
|
202
|
+
...(provider === "hf-inference"
|
|
203
|
+
? [
|
|
204
|
+
{
|
|
205
|
+
client: "requests",
|
|
206
|
+
content: `\
|
|
207
|
+
def query(payload):
|
|
157
208
|
response = requests.post(API_URL, headers=headers, json=payload)
|
|
158
209
|
return response.content
|
|
210
|
+
|
|
159
211
|
image_bytes = query({
|
|
160
212
|
"inputs": ${getModelInputSnippet(model)},
|
|
161
213
|
})
|
|
@@ -164,23 +216,33 @@ image_bytes = query({
|
|
|
164
216
|
import io
|
|
165
217
|
from PIL import Image
|
|
166
218
|
image = Image.open(io.BytesIO(image_bytes))`,
|
|
167
|
-
|
|
168
|
-
]
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
219
|
+
},
|
|
220
|
+
]
|
|
221
|
+
: []),
|
|
222
|
+
];
|
|
223
|
+
};
|
|
224
|
+
/**
 * Python `requests` snippet for tabular tasks: the example input is wrapped as
 * `{"inputs": {"data": ...}}` and the raw response bytes are returned.
 *
 * The emitted code references `API_URL` and `headers` without defining them;
 * `getPythonInferenceSnippet` in this file prepends those definitions to every
 * snippet whose client is "requests".
 *
 * @param model Model whose example input is rendered by `getModelInputSnippet`.
 * @returns A single-element list holding the `requests` snippet.
 */
export const snippetTabular = (model) => {
    const exampleInput = getModelInputSnippet(model);
    return [
        {
            client: "requests",
            // Emitted verbatim as Python, so the body of `query` is indented;
            // the call that follows stays at top level.
            content: `def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.content
response = query({
    "inputs": {"data": ${exampleInput}},
})`,
        },
    ];
};
|
|
177
237
|
export const snippetTextToAudio = (model) => {
|
|
178
238
|
// Transformers TTS pipeline and api-inference-community (AIC) pipeline outputs are diverged
|
|
179
239
|
// with the latest update to inference-api (IA).
|
|
180
240
|
// Transformers IA returns a byte object (wav file), whereas AIC returns wav and sampling_rate.
|
|
181
241
|
if (model.library_name === "transformers") {
|
|
182
|
-
return
|
|
183
|
-
|
|
242
|
+
return [
|
|
243
|
+
{
|
|
244
|
+
client: "requests",
|
|
245
|
+
content: `def query(payload):
|
|
184
246
|
response = requests.post(API_URL, headers=headers, json=payload)
|
|
185
247
|
return response.content
|
|
186
248
|
|
|
@@ -190,11 +252,14 @@ audio_bytes = query({
|
|
|
190
252
|
# You can access the audio with IPython.display for example
|
|
191
253
|
from IPython.display import Audio
|
|
192
254
|
Audio(audio_bytes)`,
|
|
193
|
-
|
|
255
|
+
},
|
|
256
|
+
];
|
|
194
257
|
}
|
|
195
258
|
else {
|
|
196
|
-
return
|
|
197
|
-
|
|
259
|
+
return [
|
|
260
|
+
{
|
|
261
|
+
client: "requests",
|
|
262
|
+
content: `def query(payload):
|
|
198
263
|
response = requests.post(API_URL, headers=headers, json=payload)
|
|
199
264
|
return response.json()
|
|
200
265
|
|
|
@@ -204,11 +269,15 @@ audio, sampling_rate = query({
|
|
|
204
269
|
# You can access the audio with IPython.display for example
|
|
205
270
|
from IPython.display import Audio
|
|
206
271
|
Audio(audio, rate=sampling_rate)`,
|
|
207
|
-
|
|
272
|
+
},
|
|
273
|
+
];
|
|
208
274
|
}
|
|
209
275
|
};
|
|
210
|
-
export const snippetDocumentQuestionAnswering = (model) =>
|
|
211
|
-
|
|
276
|
+
export const snippetDocumentQuestionAnswering = (model) => {
|
|
277
|
+
return [
|
|
278
|
+
{
|
|
279
|
+
client: "requests",
|
|
280
|
+
content: `def query(payload):
|
|
212
281
|
with open(payload["image"], "rb") as f:
|
|
213
282
|
img = f.read()
|
|
214
283
|
payload["image"] = base64.b64encode(img).decode("utf-8")
|
|
@@ -218,7 +287,9 @@ export const snippetDocumentQuestionAnswering = (model) => ({
|
|
|
218
287
|
output = query({
|
|
219
288
|
"inputs": ${getModelInputSnippet(model)},
|
|
220
289
|
})`,
|
|
221
|
-
}
|
|
290
|
+
},
|
|
291
|
+
];
|
|
292
|
+
};
|
|
222
293
|
export const pythonSnippets = {
|
|
223
294
|
// Same order as in tasks/src/pipelines.ts
|
|
224
295
|
"text-classification": snippetBasic,
|
|
@@ -249,23 +320,26 @@ export const pythonSnippets = {
|
|
|
249
320
|
"image-to-text": snippetFile,
|
|
250
321
|
"zero-shot-image-classification": snippetZeroShotImageClassification,
|
|
251
322
|
};
|
|
252
|
-
export function getPythonInferenceSnippet(model, accessToken, opts) {
|
|
323
|
+
export function getPythonInferenceSnippet(model, accessToken, provider, opts) {
|
|
253
324
|
if (model.tags.includes("conversational")) {
|
|
254
325
|
// Conversational model detected, so we display a code snippet that features the Messages API
|
|
255
|
-
return snippetConversational(model, accessToken, opts);
|
|
326
|
+
return snippetConversational(model, accessToken, provider, opts);
|
|
256
327
|
}
|
|
257
328
|
else {
|
|
258
|
-
|
|
259
|
-
? pythonSnippets[model.pipeline_tag]?.(model, accessToken) ??
|
|
260
|
-
:
|
|
261
|
-
|
|
329
|
+
const snippets = model.pipeline_tag && model.pipeline_tag in pythonSnippets
|
|
330
|
+
? pythonSnippets[model.pipeline_tag]?.(model, accessToken, provider) ?? []
|
|
331
|
+
: [];
|
|
332
|
+
const baseUrl = provider === "hf-inference"
|
|
333
|
+
? `https://api-inference.huggingface.co/models/${model.id}`
|
|
334
|
+
: HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider);
|
|
262
335
|
return snippets.map((snippet) => {
|
|
263
336
|
return {
|
|
264
337
|
...snippet,
|
|
265
|
-
content: snippet.
|
|
266
|
-
?
|
|
338
|
+
content: snippet.client === "requests"
|
|
339
|
+
? `\
|
|
340
|
+
import requests
|
|
267
341
|
|
|
268
|
-
API_URL = "
|
|
342
|
+
API_URL = "${baseUrl}"
|
|
269
343
|
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}
|
|
270
344
|
|
|
271
345
|
${snippet.content}`
|
|
@@ -274,6 +348,3 @@ ${snippet.content}`
|
|
|
274
348
|
});
|
|
275
349
|
}
|
|
276
350
|
}
|
|
277
|
-
export function hasPythonInferenceSnippet(model) {
|
|
278
|
-
return !!model.pipeline_tag && model.pipeline_tag in pythonSnippets;
|
|
279
|
-
}
|
|
@@ -7,6 +7,6 @@ import type { ModelData } from "../model-data.js";
|
|
|
7
7
|
export type ModelDataMinimal = Pick<ModelData, "id" | "pipeline_tag" | "mask_token" | "library_name" | "config" | "tags" | "inference">;
|
|
8
8
|
/**
 * One generated code snippet together with the client it is written for.
 */
export interface InferenceSnippet {
    /** Source text of the snippet. */
    content: string;
    /** Name of the client the snippet targets, e.g. "curl" or "requests". */
    client: string;
}
|
|
12
12
|
//# sourceMappingURL=types.d.ts.map
|
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"types.d.ts","sourceRoot":"","sources":["../../../src/snippets/types.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,SAAS,EAAE,MAAM,kBAAkB,CAAC;AAElD;;;;GAIG;AACH,MAAM,MAAM,gBAAgB,GAAG,IAAI,CAClC,SAAS,EACT,IAAI,GAAG,cAAc,GAAG,YAAY,GAAG,cAAc,GAAG,QAAQ,GAAG,MAAM,GAAG,WAAW,CACvF,CAAC;AAEF,MAAM,WAAW,gBAAgB;IAChC,OAAO,EAAE,MAAM,CAAC;IAChB,MAAM,
|
|
1
|
+
{"version":3,"file":"types.d.ts","sourceRoot":"","sources":["../../../src/snippets/types.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,SAAS,EAAE,MAAM,kBAAkB,CAAC;AAElD;;;;GAIG;AACH,MAAM,MAAM,gBAAgB,GAAG,IAAI,CAClC,SAAS,EACT,IAAI,GAAG,cAAc,GAAG,YAAY,GAAG,cAAc,GAAG,QAAQ,GAAG,MAAM,GAAG,WAAW,CACvF,CAAC;AAEF,MAAM,WAAW,gBAAgB;IAChC,OAAO,EAAE,MAAM,CAAC;IAChB,MAAM,EAAE,MAAM,CAAC;CACf"}
|
package/package.json
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@huggingface/tasks",
|
|
3
3
|
"packageManager": "pnpm@8.10.5",
|
|
4
|
-
"version": "0.14.0",
|
|
4
|
+
"version": "0.15.0",
|
|
5
5
|
"description": "List of ML tasks for huggingface.co/tasks",
|
|
6
6
|
"repository": "https://github.com/huggingface/huggingface.js.git",
|
|
7
7
|
"publishConfig": {
|
package/src/index.ts
CHANGED
|
@@ -58,3 +58,5 @@ export type { LocalApp, LocalAppKey, LocalAppSnippet } from "./local-apps.js";
|
|
|
58
58
|
|
|
59
59
|
export { DATASET_LIBRARIES_UI_ELEMENTS } from "./dataset-libraries.js";
|
|
60
60
|
export type { DatasetLibraryUiElement, DatasetLibraryKey } from "./dataset-libraries.js";
|
|
61
|
+
|
|
62
|
+
export * from "./inference-providers.js";
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
/** Provider ids accepted by the snippet generators. */
export const INFERENCE_PROVIDERS = ["hf-inference", "fal-ai", "replicate", "sambanova", "together"] as const;

/** Union of the literal ids listed in `INFERENCE_PROVIDERS`. */
export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];

/** Proxy URL pattern; the `{{PROVIDER}}` placeholder is substituted with a provider id. */
export const HF_HUB_INFERENCE_PROXY_TEMPLATE = "https://huggingface.co/api/inference-proxy/{{PROVIDER}}";

/**
 * URL to set as baseUrl in the OpenAI SDK.
 *
 * "hf-inference" has its own fixed endpoint; every other provider uses the
 * proxy template with its id filled in.
 *
 * TODO(Expose this from HfInference in the future?)
 *
 * @param provider Provider id to build the URL for.
 * @returns The base URL string.
 */
export function openAIbaseUrl(provider: InferenceProvider): string {
  if (provider === "hf-inference") {
    return "https://api-inference.huggingface.co/v1/";
  }
  return HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider);
}
|
package/src/snippets/curl.ts
CHANGED
|
@@ -1,20 +1,35 @@
|
|
|
1
|
+
import { HF_HUB_INFERENCE_PROXY_TEMPLATE, type InferenceProvider } from "../inference-providers.js";
|
|
1
2
|
import type { PipelineType } from "../pipelines.js";
|
|
2
3
|
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
|
|
3
4
|
import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
|
|
4
5
|
import { getModelInputSnippet } from "./inputs.js";
|
|
5
6
|
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
|
|
6
7
|
|
|
7
|
-
export const snippetBasic = (
|
|
8
|
-
|
|
8
|
+
/**
 * cURL snippet for tasks whose payload is a plain `{"inputs": ...}` JSON body.
 *
 * Only the "hf-inference" provider is supported; any other provider yields an
 * empty list.
 *
 * @param model Model whose id forms the endpoint URL and whose example input
 *   is rendered by `getModelInputSnippet`.
 * @param accessToken Bearer token; a falsy value is replaced by the literal
 *   placeholder `{API_TOKEN}`.
 * @param provider Provider the snippet is requested for.
 * @returns Zero or one cURL snippet.
 */
export const snippetBasic = (
  model: ModelDataMinimal,
  accessToken: string,
  provider: InferenceProvider
): InferenceSnippet[] => {
  if (provider !== "hf-inference") {
    return [];
  }
  const bearer = accessToken || `{API_TOKEN}`;
  const commandParts = [
    `curl https://api-inference.huggingface.co/models/${model.id}`,
    "-X POST",
    `-d '{"inputs": ${getModelInputSnippet(model, true)}}'`,
    "-H 'Content-Type: application/json'",
    `-H 'Authorization: Bearer ${bearer}'`,
  ];
  // One argument per line, chained with a trailing backslash.
  return [{ client: "curl", content: commandParts.join(" \\\n") }];
};
|
|
14
28
|
|
|
15
29
|
export const snippetTextGeneration = (
|
|
16
30
|
model: ModelDataMinimal,
|
|
17
31
|
accessToken: string,
|
|
32
|
+
provider: InferenceProvider,
|
|
18
33
|
opts?: {
|
|
19
34
|
streaming?: boolean;
|
|
20
35
|
messages?: ChatCompletionInputMessage[];
|
|
@@ -22,8 +37,13 @@ export const snippetTextGeneration = (
|
|
|
22
37
|
max_tokens?: GenerationParameters["max_tokens"];
|
|
23
38
|
top_p?: GenerationParameters["top_p"];
|
|
24
39
|
}
|
|
25
|
-
): InferenceSnippet => {
|
|
40
|
+
): InferenceSnippet[] => {
|
|
26
41
|
if (model.tags.includes("conversational")) {
|
|
42
|
+
const baseUrl =
|
|
43
|
+
provider === "hf-inference"
|
|
44
|
+
? `https://api-inference.huggingface.co/models/${model.id}/v1/chat/completions`
|
|
45
|
+
: HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider) + "/v1/chat/completions";
|
|
46
|
+
|
|
27
47
|
// Conversational model detected, so we display a code snippet that features the Messages API
|
|
28
48
|
const streaming = opts?.streaming ?? true;
|
|
29
49
|
const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
|
|
@@ -34,8 +54,10 @@ export const snippetTextGeneration = (
|
|
|
34
54
|
max_tokens: opts?.max_tokens ?? 500,
|
|
35
55
|
...(opts?.top_p ? { top_p: opts.top_p } : undefined),
|
|
36
56
|
};
|
|
37
|
-
return
|
|
38
|
-
|
|
57
|
+
return [
|
|
58
|
+
{
|
|
59
|
+
client: "curl",
|
|
60
|
+
content: `curl '${baseUrl}' \\
|
|
39
61
|
-H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}' \\
|
|
40
62
|
-H 'Content-Type: application/json' \\
|
|
41
63
|
--data '{
|
|
@@ -52,34 +74,64 @@ export const snippetTextGeneration = (
|
|
|
52
74
|
})},
|
|
53
75
|
"stream": ${!!streaming}
|
|
54
76
|
}'`,
|
|
55
|
-
|
|
77
|
+
},
|
|
78
|
+
];
|
|
56
79
|
} else {
|
|
57
|
-
return snippetBasic(model, accessToken);
|
|
80
|
+
return snippetBasic(model, accessToken, provider);
|
|
58
81
|
}
|
|
59
82
|
};
|
|
60
83
|
|
|
61
|
-
export const snippetZeroShotClassification = (
|
|
62
|
-
|
|
84
|
+
/**
 * cURL snippet for zero-shot classification: the example input is sent with a
 * fixed `candidate_labels` parameter list.
 *
 * Only the "hf-inference" provider is supported; any other provider yields an
 * empty list.
 *
 * @param model Model whose id forms the endpoint URL and whose example input
 *   is rendered by `getModelInputSnippet`.
 * @param accessToken Bearer token; a falsy value is replaced by the literal
 *   placeholder `{API_TOKEN}`.
 * @param provider Provider the snippet is requested for.
 * @returns Zero or one cURL snippet.
 */
export const snippetZeroShotClassification = (
  model: ModelDataMinimal,
  accessToken: string,
  provider: InferenceProvider
): InferenceSnippet[] => {
  const snippets: InferenceSnippet[] = [];
  if (provider === "hf-inference") {
    const body = `{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}`;
    const token = accessToken || `{API_TOKEN}`;
    // " \\\n" is the shell line continuation between arguments.
    const content =
      `curl https://api-inference.huggingface.co/models/${model.id}` +
      " \\\n" +
      "-X POST" +
      " \\\n" +
      `-d '${body}'` +
      " \\\n" +
      "-H 'Content-Type: application/json'" +
      " \\\n" +
      `-H 'Authorization: Bearer ${token}'`;
    snippets.push({ client: "curl", content });
  }
  return snippets;
};
|
|
68
103
|
|
|
69
|
-
export const snippetFile = (
|
|
70
|
-
|
|
104
|
+
/**
 * cURL snippet for tasks whose input is a local file, uploaded with
 * `--data-binary '@<path>'`.
 *
 * Only the "hf-inference" provider is supported; any other provider yields an
 * empty list.
 *
 * @param model Model whose id forms the endpoint URL; the file path comes from
 *   `getModelInputSnippet(model, true, true)` (flag semantics are defined in
 *   `./inputs.js`, not visible here).
 * @param accessToken Bearer token; a falsy value is replaced by the literal
 *   placeholder `{API_TOKEN}`.
 * @param provider Provider the snippet is requested for.
 * @returns Zero or one cURL snippet.
 */
export const snippetFile = (
  model: ModelDataMinimal,
  accessToken: string,
  provider: InferenceProvider
): InferenceSnippet[] => {
  if (provider !== "hf-inference") {
    return [];
  }
  const endpoint = `https://api-inference.huggingface.co/models/${model.id}`;
  const filePath = getModelInputSnippet(model, true, true);
  const authorization = `Authorization: Bearer ${accessToken || `{API_TOKEN}`}`;
  const content = [`curl ${endpoint}`, "-X POST", `--data-binary '@${filePath}'`, `-H '${authorization}'`].join(
    " \\\n"
  );
  return [{ client: "curl", content }];
};
|
|
75
122
|
|
|
76
123
|
export const curlSnippets: Partial<
|
|
77
124
|
Record<
|
|
78
125
|
PipelineType,
|
|
79
|
-
(
|
|
126
|
+
(
|
|
127
|
+
model: ModelDataMinimal,
|
|
128
|
+
accessToken: string,
|
|
129
|
+
provider: InferenceProvider,
|
|
130
|
+
opts?: Record<string, unknown>
|
|
131
|
+
) => InferenceSnippet[]
|
|
80
132
|
>
|
|
81
133
|
> = {
|
|
82
|
-
// Same order as in
|
|
134
|
+
// Same order as in tasks/src/pipelines.ts
|
|
83
135
|
"text-classification": snippetBasic,
|
|
84
136
|
"token-classification": snippetBasic,
|
|
85
137
|
"table-question-answering": snippetBasic,
|
|
@@ -108,13 +160,10 @@ export const curlSnippets: Partial<
|
|
|
108
160
|
/**
 * Looks up the cURL snippet builder registered in `curlSnippets` for the
 * model's pipeline tag and runs it.
 *
 * @param model Model to generate snippets for.
 * @param accessToken Token forwarded to the builder.
 * @param provider Provider forwarded to the builder.
 * @param opts Optional builder-specific options, forwarded unchanged.
 * @returns The builder's snippets, or an empty list when the model has no
 *   pipeline tag, no builder is registered for it, or the builder returns a
 *   nullish value.
 */
export function getCurlInferenceSnippet(
  model: ModelDataMinimal,
  accessToken: string,
  provider: InferenceProvider,
  opts?: Record<string, unknown>
): InferenceSnippet[] {
  const task = model.pipeline_tag;
  if (!task || !(task in curlSnippets)) {
    return [];
  }
  const buildSnippets = curlSnippets[task];
  return buildSnippets?.(model, accessToken, provider, opts) ?? [];
}
|