@huggingface/inference 3.5.2 → 3.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/browser/index.cjs +1652 -0
- package/dist/browser/index.js +1652 -0
- package/dist/index.cjs +277 -971
- package/dist/index.js +268 -982
- package/dist/src/index.d.ts.map +1 -1
- package/dist/src/lib/makeRequestOptions.d.ts +16 -1
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/novita.d.ts.map +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts +4 -0
- package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -0
- package/dist/src/snippets/index.d.ts +1 -4
- package/dist/src/snippets/index.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/package.json +15 -6
- package/src/index.ts +1 -1
- package/src/lib/makeRequestOptions.ts +37 -10
- package/src/providers/fireworks-ai.ts +1 -1
- package/src/providers/hf-inference.ts +1 -1
- package/src/providers/nebius.ts +3 -3
- package/src/providers/novita.ts +7 -6
- package/src/providers/sambanova.ts +1 -1
- package/src/providers/together.ts +3 -3
- package/src/snippets/getInferenceSnippets.ts +398 -0
- package/src/snippets/index.ts +1 -5
- package/src/snippets/templates/js/fetch/basic.jinja +19 -0
- package/src/snippets/templates/js/fetch/basicAudio.jinja +19 -0
- package/src/snippets/templates/js/fetch/basicImage.jinja +19 -0
- package/src/snippets/templates/js/fetch/textToAudio.jinja +41 -0
- package/src/snippets/templates/js/fetch/textToImage.jinja +19 -0
- package/src/snippets/templates/js/fetch/zeroShotClassification.jinja +22 -0
- package/src/snippets/templates/js/huggingface.js/basic.jinja +11 -0
- package/src/snippets/templates/js/huggingface.js/basicAudio.jinja +13 -0
- package/src/snippets/templates/js/huggingface.js/basicImage.jinja +13 -0
- package/src/snippets/templates/js/huggingface.js/conversational.jinja +11 -0
- package/src/snippets/templates/js/huggingface.js/conversationalStream.jinja +19 -0
- package/src/snippets/templates/js/huggingface.js/textToImage.jinja +11 -0
- package/src/snippets/templates/js/huggingface.js/textToVideo.jinja +10 -0
- package/src/snippets/templates/js/openai/conversational.jinja +13 -0
- package/src/snippets/templates/js/openai/conversationalStream.jinja +22 -0
- package/src/snippets/templates/python/fal_client/textToImage.jinja +11 -0
- package/src/snippets/templates/python/huggingface_hub/basic.jinja +4 -0
- package/src/snippets/templates/python/huggingface_hub/basicAudio.jinja +1 -0
- package/src/snippets/templates/python/huggingface_hub/basicImage.jinja +1 -0
- package/src/snippets/templates/python/huggingface_hub/conversational.jinja +6 -0
- package/src/snippets/templates/python/huggingface_hub/conversationalStream.jinja +8 -0
- package/src/snippets/templates/python/huggingface_hub/documentQuestionAnswering.jinja +5 -0
- package/src/snippets/templates/python/huggingface_hub/imageToImage.jinja +6 -0
- package/src/snippets/templates/python/huggingface_hub/importInferenceClient.jinja +6 -0
- package/src/snippets/templates/python/huggingface_hub/textToImage.jinja +5 -0
- package/src/snippets/templates/python/huggingface_hub/textToVideo.jinja +4 -0
- package/src/snippets/templates/python/openai/conversational.jinja +13 -0
- package/src/snippets/templates/python/openai/conversationalStream.jinja +15 -0
- package/src/snippets/templates/python/requests/basic.jinja +7 -0
- package/src/snippets/templates/python/requests/basicAudio.jinja +7 -0
- package/src/snippets/templates/python/requests/basicImage.jinja +7 -0
- package/src/snippets/templates/python/requests/conversational.jinja +9 -0
- package/src/snippets/templates/python/requests/conversationalStream.jinja +16 -0
- package/src/snippets/templates/python/requests/documentQuestionAnswering.jinja +13 -0
- package/src/snippets/templates/python/requests/imageToImage.jinja +15 -0
- package/src/snippets/templates/python/requests/importRequests.jinja +10 -0
- package/src/snippets/templates/python/requests/tabular.jinja +9 -0
- package/src/snippets/templates/python/requests/textToAudio.jinja +23 -0
- package/src/snippets/templates/python/requests/textToImage.jinja +14 -0
- package/src/snippets/templates/python/requests/zeroShotClassification.jinja +8 -0
- package/src/snippets/templates/python/requests/zeroShotImageClassification.jinja +14 -0
- package/src/snippets/templates/sh/curl/basic.jinja +7 -0
- package/src/snippets/templates/sh/curl/basicAudio.jinja +5 -0
- package/src/snippets/templates/sh/curl/basicImage.jinja +5 -0
- package/src/snippets/templates/sh/curl/conversational.jinja +7 -0
- package/src/snippets/templates/sh/curl/conversationalStream.jinja +7 -0
- package/src/snippets/templates/sh/curl/zeroShotClassification.jinja +5 -0
- package/src/tasks/cv/textToVideo.ts +25 -5
- package/src/vendor/fetch-event-source/LICENSE +21 -0
- package/dist/src/snippets/curl.d.ts +0 -17
- package/dist/src/snippets/curl.d.ts.map +0 -1
- package/dist/src/snippets/js.d.ts +0 -21
- package/dist/src/snippets/js.d.ts.map +0 -1
- package/dist/src/snippets/python.d.ts +0 -4
- package/dist/src/snippets/python.d.ts.map +0 -1
- package/src/snippets/curl.ts +0 -177
- package/src/snippets/js.ts +0 -475
- package/src/snippets/python.ts +0 -563
|
@@ -0,0 +1,398 @@
|
|
|
1
|
+
import type { PipelineType, WidgetType } from "@huggingface/tasks/src/pipelines.js";
import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingface/tasks/src/tasks/index.js";
import {
	type InferenceSnippet,
	type InferenceSnippetLanguage,
	type ModelDataMinimal,
	inferenceSnippetLanguages,
	getModelInputSnippet,
} from "@huggingface/tasks";
import type { InferenceProvider, InferenceTask, RequestArgs } from "../types";
import { Template } from "@huggingface/jinja";
import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions";
import fs from "fs";
import path from "path";
import { existsSync as pathExists } from "node:fs";
import { fileURLToPath } from "node:url";
|
|
16
|
+
|
|
17
|
+
// Supported snippet clients, grouped by target language.
const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const;
const JS_CLIENTS = ["fetch", "huggingface.js", "openai"] as const;
const SH_CLIENTS = ["curl"] as const;

/** Union of every client name across all languages. */
type Client = (typeof SH_CLIENTS | typeof PYTHON_CLIENTS | typeof JS_CLIENTS)[number];

/** Mutable per-language client lists, in the order snippets are generated. */
const CLIENTS: Record<InferenceSnippetLanguage, Client[]> = {
	js: Array.from(JS_CLIENTS),
	python: Array.from(PYTHON_CLIENTS),
	sh: Array.from(SH_CLIENTS),
};
|
|
28
|
+
|
|
29
|
+
/** Builds the request payload for one task from the model's example input and the caller's options. */
type InputPreparationFn = (model: ModelDataMinimal, opts?: Record<string, unknown>) => object;

/** One request body, pre-rendered in every textual format a template may splice in. */
interface FormattedInputs {
	asObj: object;
	asCurlString: string;
	asJsonString: string;
	asPythonString: string;
	asTsString: string;
}

/** Variables exposed to the Jinja snippet templates. */
interface TemplateParams {
	accessToken?: string;
	authorizationHeader?: string;
	baseUrl?: string;
	fullUrl?: string;
	/** Inputs as prepared for the client-library call. */
	inputs?: FormattedInputs;
	/** Body sent to the provider (parsed from the built request; may differ from `inputs`). */
	providerInputs?: FormattedInputs;
	model?: ModelDataMinimal;
	provider?: InferenceProvider;
	providerModelId?: string;
	methodName?: string; // specific to snippetBasic
	importBase64?: boolean; // specific to snippetImportRequests
	importJson?: boolean; // specific to snippetImportRequests
}
|
|
44
|
+
|
|
45
|
+
// Helpers to find + load templates
|
|
46
|
+
|
|
47
|
+
/**
 * Locate the package root: the nearest ancestor directory of this module that contains a
 * `package.json`.
 *
 * Starts from the module's directory (ESM: derived from `import.meta.url`; CJS: `__dirname`).
 * Returns the filesystem root if no `package.json` is found.
 */
const rootDirFinder = (): string => {
	let currentPath =
		typeof import.meta !== "undefined" && import.meta.url
			? // fileURLToPath decodes %-escapes and handles Windows drive letters, unlike URL#pathname.
				path.dirname(fileURLToPath(import.meta.url)) /// for ESM
			: __dirname; /// for CJS

	for (;;) {
		if (pathExists(path.join(currentPath, "package.json"))) {
			return currentPath;
		}
		const parentPath = path.dirname(currentPath);
		// dirname() is a fixed point only at the filesystem root ("/" on POSIX, e.g. "C:\\" on Windows),
		// so this terminates on every platform instead of looping forever when the root is not "/".
		if (parentPath === currentPath) {
			return currentPath;
		}
		currentPath = parentPath;
	}
};
|
|
63
|
+
|
|
64
|
+
/// Package root, resolved lazily on first use and then reused. It cannot change at runtime, and
/// several template lookups happen per snippet request, so re-walking the filesystem each time is waste.
let cachedRootDir: string | undefined;
const packageRootDir = (): string => {
	if (cachedRootDir === undefined) {
		cachedRootDir = rootDirFinder();
	}
	return cachedRootDir;
};

/** Absolute path of `src/snippets/templates/<language>/<client>/<templateName>.jinja`. */
const templatePath = (language: InferenceSnippetLanguage, client: Client, templateName: string): string =>
	path.join(packageRootDir(), "src", "snippets", "templates", language, client, `${templateName}.jinja`);

/** Whether a template file exists for this language / client / template name. */
const hasTemplate = (language: InferenceSnippetLanguage, client: Client, templateName: string): boolean =>
	pathExists(templatePath(language, client, templateName));
|
|
68
|
+
|
|
69
|
+
/**
 * Read and parse a snippet template, returning a renderer for it.
 *
 * The file is read and parsed once, here. The returned function only renders, so calling it
 * repeatedly does not re-tokenize the source. Throws if the file cannot be read or parsed.
 */
const loadTemplate = (
	language: InferenceSnippetLanguage,
	client: Client,
	templateName: string
): ((data: TemplateParams) => string) => {
	const source = fs.readFileSync(templatePath(language, client, templateName), "utf8");
	const template = new Template(source);
	return (data: TemplateParams) => template.render({ ...data });
};
|
|
77
|
+
|
|
78
|
+
// Python import headers, loaded once at module initialisation and prepended to generated snippets.
const [snippetImportPythonInferenceClient, snippetImportRequests] = [
	loadTemplate("python", "huggingface_hub", "importInferenceClient"),
	loadTemplate("python", "requests", "importRequests"),
];
|
|
80
|
+
|
|
81
|
+
// Needed for huggingface_hub basic snippets
|
|
82
|
+
|
|
83
|
+
/**
 * `huggingface_hub.InferenceClient` method name for each supported task. Every method name is
 * the task id with dashes replaced by underscores.
 */
const HF_PYTHON_METHODS: Partial<Record<WidgetType, string>> = Object.fromEntries(
	(
		[
			"audio-classification",
			"audio-to-audio",
			"automatic-speech-recognition",
			"document-question-answering",
			"feature-extraction",
			"fill-mask",
			"image-classification",
			"image-segmentation",
			"image-to-image",
			"image-to-text",
			"object-detection",
			"question-answering",
			"sentence-similarity",
			"summarization",
			"table-question-answering",
			"tabular-classification",
			"tabular-regression",
			"text-classification",
			"text-generation",
			"text-to-image",
			"text-to-speech",
			"text-to-video",
			"token-classification",
			"translation",
			"visual-question-answering",
			"zero-shot-classification",
			"zero-shot-image-classification",
		] as WidgetType[]
	).map((task) => [task, task.replace(/-/g, "_")])
);
|
|
112
|
+
|
|
113
|
+
// Needed for huggingface.js basic snippets
|
|
114
|
+
|
|
115
|
+
/**
 * `InferenceClient` (huggingface.js) method name for each task that has a basic JS snippet.
 * Tasks missing here get no huggingface.js basic snippet.
 */
const HF_JS_METHODS: Partial<Record<WidgetType, string>> = Object.fromEntries([
	["automatic-speech-recognition", "automaticSpeechRecognition"],
	["feature-extraction", "featureExtraction"],
	["fill-mask", "fillMask"],
	["image-classification", "imageClassification"],
	["question-answering", "questionAnswering"],
	["sentence-similarity", "sentenceSimilarity"],
	["summarization", "summarization"],
	["table-question-answering", "tableQuestionAnswering"],
	["text-classification", "textClassification"],
	["text-generation", "textGeneration"],
	// text2text models are served through the same text-generation method.
	["text2text-generation", "textGeneration"],
	["token-classification", "tokenClassification"],
	["translation", "translation"],
] as Array<[WidgetType, string]>);
|
|
130
|
+
|
|
131
|
+
// Snippet generators
|
|
132
|
+
/**
 * Build a snippet generator for one template family.
 *
 * @param templateName - Template file name (without `.jinja`) looked up for each client.
 * @param inputPreparationFn - Optional builder for the request payload. When absent, the payload
 *   is `{ inputs: getModelInputSnippet(model) }`.
 * @returns A function that renders every available client snippet for a model. Text-generation and
 *   image-text-to-text models tagged "conversational" are switched to the conversational templates
 *   and inputs, for that call only.
 */
const snippetGenerator = (templateName: string, inputPreparationFn?: InputPreparationFn) => {
	return (
		model: ModelDataMinimal,
		accessToken: string,
		provider: InferenceProvider,
		providerModelId?: string,
		opts?: Record<string, unknown>
	): InferenceSnippet[] => {
		/// Hacky: hard-code conversational templates here.
		/// Resolved into per-call locals. Reassigning the captured `templateName` / `inputPreparationFn`
		/// would make the override stick for every later call through this same generator.
		const isConversational =
			!!model.pipeline_tag &&
			["text-generation", "image-text-to-text"].includes(model.pipeline_tag) &&
			model.tags.includes("conversational");
		const effectiveTemplateName = isConversational
			? opts?.streaming
				? "conversationalStream"
				: "conversational"
			: templateName;
		const effectiveInputPreparationFn: InputPreparationFn | undefined = isConversational
			? prepareConversationalInput
			: inputPreparationFn;

		/// Prepare inputs + make request
		const inputs = effectiveInputPreparationFn
			? effectiveInputPreparationFn(model, opts)
			: { inputs: getModelInputSnippet(model) };
		const request = makeRequestOptionsFromResolvedModel(
			providerModelId ?? model.id,
			{ accessToken: accessToken, provider: provider, ...inputs } as RequestArgs,
			{
				chatCompletion: effectiveTemplateName.includes("conversational"),
				task: model.pipeline_tag as InferenceTask,
			}
		);

		/// Parse request.info.body if not a binary.
		/// This is the body sent to the provider. Important for snippets with raw payload (e.g curl, requests, etc.)
		let providerInputs = inputs;
		const bodyAsObj = request.info.body;
		if (typeof bodyAsObj === "string") {
			try {
				providerInputs = JSON.parse(bodyAsObj);
			} catch (e) {
				// Best effort: keep the locally prepared inputs if the body is not JSON.
				console.error("Failed to parse body as JSON", e);
			}
		}

		/// Prepare template injection data (shared by all clients; per-client fields are added below)
		const params: TemplateParams = {
			accessToken,
			authorizationHeader: (request.info.headers as Record<string, string>)?.Authorization,
			baseUrl: removeSuffix(request.url, "/chat/completions"),
			fullUrl: request.url,
			inputs: {
				asObj: inputs,
				asCurlString: formatBody(inputs, "curl"),
				asJsonString: formatBody(inputs, "json"),
				asPythonString: formatBody(inputs, "python"),
				asTsString: formatBody(inputs, "ts"),
			},
			providerInputs: {
				asObj: providerInputs,
				asCurlString: formatBody(providerInputs, "curl"),
				asJsonString: formatBody(providerInputs, "json"),
				asPythonString: formatBody(providerInputs, "python"),
				asTsString: formatBody(providerInputs, "ts"),
			},
			model,
			provider,
			providerModelId: providerModelId ?? model.id,
		};

		/// Iterate over clients => check if a snippet exists => generate
		return inferenceSnippetLanguages
			.map((language) => {
				return CLIENTS[language]
					.map((client) => {
						if (!hasTemplate(language, client, effectiveTemplateName)) {
							return;
						}
						const template = loadTemplate(language, client, effectiveTemplateName);

						/// Per-client copy so fields like `methodName` never leak into another client's render.
						const clientParams: TemplateParams = { ...params };
						if (client === "huggingface_hub" && effectiveTemplateName.includes("basic")) {
							if (!(model.pipeline_tag && model.pipeline_tag in HF_PYTHON_METHODS)) {
								return;
							}
							clientParams.methodName = HF_PYTHON_METHODS[model.pipeline_tag];
						}

						if (client === "huggingface.js" && effectiveTemplateName.includes("basic")) {
							if (!(model.pipeline_tag && model.pipeline_tag in HF_JS_METHODS)) {
								return;
							}
							clientParams.methodName = HF_JS_METHODS[model.pipeline_tag];
						}

						/// Generate snippet
						let snippet = template(clientParams).trim();
						if (!snippet) {
							return;
						}

						/// Add import section separately
						if (client === "huggingface_hub") {
							const importSection = snippetImportPythonInferenceClient({ ...clientParams });
							snippet = `${importSection}\n\n${snippet}`;
						} else if (client === "requests") {
							const importSection = snippetImportRequests({
								...clientParams,
								importBase64: snippet.includes("base64"),
								importJson: snippet.includes("json."),
							});
							snippet = `${importSection}\n\n${snippet}`;
						}

						/// Snippet is ready!
						return { language, client: client as string, content: snippet };
					})
					.filter((snippet): snippet is InferenceSnippet => snippet !== undefined);
			})
			.flat();
	};
};
|
|
245
|
+
|
|
246
|
+
/** The example input for this task is serialized JSON; decode it into the request object. */
const prepareDocumentQuestionAnsweringInput = (model: ModelDataMinimal): object =>
	JSON.parse(getModelInputSnippet(model) as string);
|
|
249
|
+
|
|
250
|
+
/** Split the JSON example input into the image (`inputs`) and the prompt (`parameters.prompt`). */
const prepareImageToImageInput = (model: ModelDataMinimal): object => {
	const { image, prompt } = JSON.parse(getModelInputSnippet(model) as string);
	return { inputs: image, parameters: { prompt } };
};
|
|
254
|
+
|
|
255
|
+
/**
 * Build a chat-completion payload.
 *
 * @param model - Used only to obtain example messages when `opts.messages` is not given.
 * @param opts - Optional overrides. `temperature` and `top_p` are included only when provided.
 *   `max_tokens` defaults to 500.
 */
const prepareConversationalInput = (
	model: ModelDataMinimal,
	opts?: {
		streaming?: boolean;
		messages?: ChatCompletionInputMessage[];
		temperature?: GenerationParameters["temperature"];
		max_tokens?: GenerationParameters["max_new_tokens"];
		top_p?: GenerationParameters["top_p"];
	}
): object => {
	return {
		messages: opts?.messages ?? getModelInputSnippet(model),
		// Null-check rather than truthiness, so explicit zeros (e.g. `temperature: 0`) are kept.
		...(opts?.temperature != null ? { temperature: opts?.temperature } : undefined),
		max_tokens: opts?.max_tokens ?? 500,
		...(opts?.top_p != null ? { top_p: opts?.top_p } : undefined),
	};
};
|
|
272
|
+
|
|
273
|
+
/**
 * Snippet generator per pipeline. Each row is [pipeline, template name, optional input builder].
 * Every row gets its own generator instance.
 */
const snippets: Partial<
	Record<
		PipelineType,
		(
			model: ModelDataMinimal,
			accessToken: string,
			provider: InferenceProvider,
			providerModelId?: string,
			opts?: Record<string, unknown>
		) => InferenceSnippet[]
	>
> = Object.fromEntries(
	(
		[
			["audio-classification", "basicAudio"],
			["audio-to-audio", "basicAudio"],
			["automatic-speech-recognition", "basicAudio"],
			["document-question-answering", "documentQuestionAnswering", prepareDocumentQuestionAnsweringInput],
			["feature-extraction", "basic"],
			["fill-mask", "basic"],
			["image-classification", "basicImage"],
			["image-segmentation", "basicImage"],
			["image-text-to-text", "conversational"],
			["image-to-image", "imageToImage", prepareImageToImageInput],
			["image-to-text", "basicImage"],
			["object-detection", "basicImage"],
			["question-answering", "basic"],
			["sentence-similarity", "basic"],
			["summarization", "basic"],
			["tabular-classification", "tabular"],
			["tabular-regression", "tabular"],
			["table-question-answering", "basic"],
			["text-classification", "basic"],
			["text-generation", "basic"],
			["text-to-audio", "textToAudio"],
			["text-to-image", "textToImage"],
			["text-to-speech", "textToAudio"],
			["text-to-video", "textToVideo"],
			["text2text-generation", "basic"],
			["token-classification", "basic"],
			["translation", "basic"],
			["zero-shot-classification", "zeroShotClassification"],
			["zero-shot-image-classification", "zeroShotImageClassification"],
		] as Array<[PipelineType, string, InputPreparationFn?]>
	).map(([pipeline, name, prepareInputs]) => [pipeline, snippetGenerator(name, prepareInputs)] as const)
);
|
|
315
|
+
|
|
316
|
+
/**
 * Render every available inference snippet (all languages / clients) for a model.
 *
 * @returns An empty array when the model has no pipeline tag or the pipeline has no generator.
 */
export function getInferenceSnippets(
	model: ModelDataMinimal,
	accessToken: string,
	provider: InferenceProvider,
	providerModelId?: string,
	opts?: Record<string, unknown>
): InferenceSnippet[] {
	const pipelineTag = model.pipeline_tag;
	if (!pipelineTag || !(pipelineTag in snippets)) {
		return [];
	}
	const generate = snippets[pipelineTag];
	return generate?.(model, accessToken, provider, providerModelId, opts) ?? [];
}
|
|
327
|
+
|
|
328
|
+
// String manipulation helpers
|
|
329
|
+
|
|
330
|
+
/**
 * Render a request body for splicing into a template.
 *
 * - `json`: pretty-printed JSON (4 spaces) with the outer braces removed.
 * - `curl`: the `json` form, indented 4 more spaces.
 * - `python`: one `key=value,` keyword-argument line per top-level entry, indented 4 spaces.
 *   Values are Python literals (`True` / `False` / `None`, never JSON's `true` / `false` / `null`).
 *   Entries whose value is `undefined` are skipped, as `JSON.stringify` does.
 * - `ts`: a TypeScript object literal with the outer braces removed.
 *
 * @throws Error for an unknown format.
 */
function formatBody(obj: object, format: "curl" | "json" | "python" | "ts"): string {
	switch (format) {
		case "curl":
			return indentString(formatBody(obj, "json"));

		case "json":
			/// Hacky: remove outer brackets to make it extendable in templates
			return JSON.stringify(obj, null, 4).split("\n").slice(1, -1).join("\n");

		case "python":
			return indentString(
				Object.entries(obj)
					.filter(([, value]) => value !== undefined)
					.map(([key, value]) => `${key}=${formatPythonValue(value)},`)
					.join("\n")
			);

		case "ts":
			/// Hacky: remove outer brackets to make it extendable in templates
			return formatTsObject(obj).split("\n").slice(1, -1).join("\n");

		default:
			throw new Error(`Unsupported format: ${format}`);
	}
}

/**
 * Render JSON-compatible data as a Python literal, laid out like `JSON.stringify(value, null, 4)`.
 * `null` / `undefined` / non-finite numbers become `None`, and booleans become `True` / `False`.
 * Object entries with `undefined` values are omitted.
 */
function formatPythonValue(value: unknown, depth = 0): string {
	const indent = " ".repeat(4 * depth);
	const childIndent = " ".repeat(4 * (depth + 1));

	if (value === null || value === undefined) {
		return "None";
	}
	if (typeof value === "boolean") {
		return value ? "True" : "False";
	}
	if (typeof value === "number") {
		// JSON.stringify maps NaN / ±Infinity to null; mirror that as None.
		return Number.isFinite(value) ? JSON.stringify(value) : "None";
	}
	if (Array.isArray(value)) {
		if (value.length === 0) {
			return "[]";
		}
		const items = value.map((item) => `${childIndent}${formatPythonValue(item, depth + 1)}`);
		return `[\n${items.join(",\n")}\n${indent}]`;
	}
	if (typeof value === "object") {
		const entries = Object.entries(value).filter(([, v]) => v !== undefined);
		if (entries.length === 0) {
			return "{}";
		}
		const lines = entries.map(
			([key, v]) => `${childIndent}${JSON.stringify(key)}: ${formatPythonValue(v, depth + 1)}`
		);
		return `{\n${lines.join(",\n")}\n${indent}}`;
	}
	// Strings: JSON string escapes are also valid Python string escapes.
	return JSON.stringify(value);
}
|
|
357
|
+
|
|
358
|
+
/**
 * Render a value as a TypeScript literal: 4-space indentation, trailing commas, and bare keys
 * where the key is a valid identifier.
 *
 * @param obj - JSON-like value to render.
 * @param depth - Current nesting level (controls indentation); defaults to 0.
 */
function formatTsObject(obj: unknown, depth = 0): string {
	/// Case int, boolean, string, etc.
	if (typeof obj !== "object" || obj === null) {
		return JSON.stringify(obj);
	}

	const indent = " ".repeat(4 * depth);
	const childIndent = " ".repeat(4 * (depth + 1));

	/// Case array
	if (Array.isArray(obj)) {
		if (obj.length === 0) {
			return "[]";
		}
		const items = obj.map((item) => `${childIndent}${formatTsObject(item, depth + 1)},`).join("\n");
		return `[\n${items}\n${indent}]`;
	}

	/// Case mapping
	const entries = Object.entries(obj);
	if (entries.length === 0) {
		return "{}";
	}
	const lines = entries
		.map(([key, value]) => {
			// Bare identifiers stay unquoted; any other key becomes a properly escaped string literal.
			const keyStr = /^[a-zA-Z_$][a-zA-Z0-9_$]*$/.test(key) ? key : JSON.stringify(key);
			return `${childIndent}${keyStr}: ${formatTsObject(value, depth + 1)},`;
		})
		.join("\n");
	return `{\n${lines}\n${indent}}`;
}
|
|
388
|
+
|
|
389
|
+
/** Prefix every line of `str` (split on "\n") with four spaces. */
function indentString(str: string): string {
	const pad = "    ";
	return pad + str.split("\n").join(`\n${pad}`);
}
|
|
395
|
+
|
|
396
|
+
/**
 * Return `str` without a trailing `suffix`, or `str` unchanged when it does not end with it.
 * An empty suffix leaves `str` unchanged (the naive `slice(0, -0)` would return "").
 */
function removeSuffix(str: string, suffix: string): string {
	if (suffix.length === 0 || !str.endsWith(suffix)) {
		return str;
	}
	return str.slice(0, str.length - suffix.length);
}
|
package/src/snippets/index.ts
CHANGED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
{# fetch() snippet: POST `{ inputs: ... }` as JSON to fullUrl and log the parsed JSON response. #}
async function query(data) {
	const response = await fetch(
		"{{ fullUrl }}",
		{
			headers: {
				Authorization: "{{ authorizationHeader }}",
				"Content-Type": "application/json",
			},
			method: "POST",
			body: JSON.stringify(data),
		}
	);
	const result = await response.json();
	return result;
}

query({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {
	console.log(JSON.stringify(response));
});
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
{# fetch() snippet for audio tasks: read the local audio file and POST its raw bytes. #}
{# Sending JSON.stringify(...) under an audio Content-Type would upload the file name, not the audio. #}
import fs from "node:fs";

async function query(filename) {
	const data = fs.readFileSync(filename);
	const response = await fetch(
		"{{ fullUrl }}",
		{
			headers: {
				Authorization: "{{ authorizationHeader }}",
				"Content-Type": "audio/flac",
			},
			method: "POST",
			body: data,
		}
	);
	const result = await response.json();
	return result;
}

query({{ providerInputs.asObj.inputs }}).then((response) => {
	console.log(JSON.stringify(response));
});
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
{# fetch() snippet for image tasks: read the local image file and POST its raw bytes. #}
{# Sending JSON.stringify(...) under an image Content-Type would upload the file name, not the image. #}
import fs from "node:fs";

async function query(filename) {
	const data = fs.readFileSync(filename);
	const response = await fetch(
		"{{ fullUrl }}",
		{
			headers: {
				Authorization: "{{ authorizationHeader }}",
				"Content-Type": "image/jpeg",
			},
			method: "POST",
			body: data,
		}
	);
	const result = await response.json();
	return result;
}

query({{ providerInputs.asObj.inputs }}).then((response) => {
	console.log(JSON.stringify(response));
});
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
{# fetch() snippet for text-to-audio: transformers models return audio bytes (read as a Blob); others return JSON. #}
{% if model.library_name == "transformers" %}
async function query(data) {
	const response = await fetch(
		"{{ fullUrl }}",
		{
			headers: {
				Authorization: "{{ authorizationHeader }}",
				"Content-Type": "application/json",
			},
			method: "POST",
			body: JSON.stringify(data),
		}
	);
	const result = await response.blob();
	return result;
}

query({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {
	// Returns a byte object of the Audio waveform. Use it directly!
});
{% else %}
async function query(data) {
	const response = await fetch(
		"{{ fullUrl }}",
		{
			headers: {
				Authorization: "{{ authorizationHeader }}",
				"Content-Type": "application/json",
			},
			method: "POST",
			body: JSON.stringify(data),
		}
	);
	const result = await response.json();
	return result;
}

query({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {
	console.log(JSON.stringify(response));
});
{% endif %}
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
{# fetch() snippet for text-to-image: POST the prompt as JSON; the response body is read as a Blob. #}
async function query(data) {
	const response = await fetch(
		"{{ fullUrl }}",
		{
			headers: {
				Authorization: "{{ authorizationHeader }}",
				"Content-Type": "application/json",
			},
			method: "POST",
			body: JSON.stringify(data),
		}
	);
	const result = await response.blob();
	return result;
}

query({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {
	// Use image
});
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
{# fetch() snippet for zero-shot classification; candidate_labels below are fixed example labels. #}
async function query(data) {
	const response = await fetch(
		"{{ fullUrl }}",
		{
			headers: {
				Authorization: "{{ authorizationHeader }}",
				"Content-Type": "application/json",
			},
			method: "POST",
			body: JSON.stringify(data),
		}
	);
	const result = await response.json();
	return result;
}

query({
	inputs: {{ providerInputs.asObj.inputs }},
	parameters: { candidate_labels: ["refund", "legal", "faq"] }
}).then((response) => {
	console.log(JSON.stringify(response));
});
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
{# huggingface.js snippet: calls the task-specific client method (methodName) with the example inputs. #}
import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient("{{ accessToken }}");

const output = await client.{{ methodName }}({
	model: "{{ model.id }}",
	inputs: {{ inputs.asObj.inputs }},
	provider: "{{ provider }}",
});

console.log(output);
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
{# huggingface.js snippet for audio tasks: the local file is read with fs and passed as raw `data`. #}
import { InferenceClient } from "@huggingface/inference";
import fs from "node:fs";

const client = new InferenceClient("{{ accessToken }}");

const data = fs.readFileSync({{inputs.asObj.inputs}});

const output = await client.{{ methodName }}({
	data,
	model: "{{ model.id }}",
	provider: "{{ provider }}",
});

console.log(output);
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
{# huggingface.js snippet for image tasks: the local file is read with fs and passed as raw `data`. #}
import { InferenceClient } from "@huggingface/inference";
import fs from "node:fs";

const client = new InferenceClient("{{ accessToken }}");

const data = fs.readFileSync({{inputs.asObj.inputs}});

const output = await client.{{ methodName }}({
	data,
	model: "{{ model.id }}",
	provider: "{{ provider }}",
});

console.log(output);
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
{# huggingface.js chat snippet; inputs.asTsString is the request fields without outer braces, spliced into the call. #}
import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient("{{ accessToken }}");

const chatCompletion = await client.chatCompletion({
	provider: "{{ provider }}",
	model: "{{ model.id }}",
{{ inputs.asTsString }}
});

console.log(chatCompletion.choices[0].message);
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
{# huggingface.js streaming chat snippet; accumulates streamed deltas into `out`. #}
{# Chunks without delta.content (e.g. the final one) contribute "" rather than the string "undefined". #}
import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient("{{ accessToken }}");

let out = "";

const stream = await client.chatCompletionStream({
	provider: "{{ provider }}",
	model: "{{ model.id }}",
{{ inputs.asTsString }}
});

for await (const chunk of stream) {
	if (chunk.choices && chunk.choices.length > 0) {
		const newContent = chunk.choices[0].delta.content ?? "";
		out += newContent;
		console.log(newContent);
	}
}
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
{# huggingface.js text-to-image snippet; num_inference_steps is a fixed example value. #}
import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient("{{ accessToken }}");

const image = await client.textToImage({
	provider: "{{ provider }}",
	model: "{{ model.id }}",
	inputs: {{ inputs.asObj.inputs }},
	parameters: { num_inference_steps: 5 },
});
/// Use the generated image (it's a Blob)
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
{# huggingface.js text-to-video snippet; the result variable is named for what it holds (a video). #}
import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient("{{ accessToken }}");

const video = await client.textToVideo({
	provider: "{{ provider }}",
	model: "{{ model.id }}",
	inputs: {{ inputs.asObj.inputs }},
});
// Use the generated video (it's a Blob)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
{# OpenAI JS client chat snippet: baseUrl is the provider URL with any "/chat/completions" suffix removed. #}
{# providerModelId is the provider-side model id (falls back to the Hub model id). #}
import { OpenAI } from "openai";

const client = new OpenAI({
	baseURL: "{{ baseUrl }}",
	apiKey: "{{ accessToken }}",
});

const chatCompletion = await client.chat.completions.create({
	model: "{{ providerModelId }}",
{{ inputs.asTsString }}
});

console.log(chatCompletion.choices[0].message);
|