@huggingface/inference 3.7.0 → 3.7.1
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- package/dist/index.cjs +1152 -839
- package/dist/index.js +1154 -841
- package/dist/src/lib/getProviderHelper.d.ts +37 -0
- package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
- package/dist/src/lib/makeRequestOptions.d.ts +0 -2
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/black-forest-labs.d.ts +14 -18
- package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
- package/dist/src/providers/cerebras.d.ts +4 -2
- package/dist/src/providers/cerebras.d.ts.map +1 -1
- package/dist/src/providers/cohere.d.ts +5 -2
- package/dist/src/providers/cohere.d.ts.map +1 -1
- package/dist/src/providers/fal-ai.d.ts +50 -3
- package/dist/src/providers/fal-ai.d.ts.map +1 -1
- package/dist/src/providers/fireworks-ai.d.ts +5 -2
- package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
- package/dist/src/providers/hf-inference.d.ts +125 -2
- package/dist/src/providers/hf-inference.d.ts.map +1 -1
- package/dist/src/providers/hyperbolic.d.ts +31 -2
- package/dist/src/providers/hyperbolic.d.ts.map +1 -1
- package/dist/src/providers/nebius.d.ts +20 -18
- package/dist/src/providers/nebius.d.ts.map +1 -1
- package/dist/src/providers/novita.d.ts +21 -18
- package/dist/src/providers/novita.d.ts.map +1 -1
- package/dist/src/providers/openai.d.ts +4 -2
- package/dist/src/providers/openai.d.ts.map +1 -1
- package/dist/src/providers/providerHelper.d.ts +182 -0
- package/dist/src/providers/providerHelper.d.ts.map +1 -0
- package/dist/src/providers/replicate.d.ts +23 -19
- package/dist/src/providers/replicate.d.ts.map +1 -1
- package/dist/src/providers/sambanova.d.ts +4 -2
- package/dist/src/providers/sambanova.d.ts.map +1 -1
- package/dist/src/providers/together.d.ts +32 -2
- package/dist/src/providers/together.d.ts.map +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/audio/utils.d.ts +2 -1
- package/dist/src/tasks/audio/utils.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts +0 -2
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts +0 -2
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/index.d.ts +6 -6
- package/dist/src/tasks/index.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
- package/dist/src/types.d.ts +3 -13
- package/dist/src/types.d.ts.map +1 -1
- package/package.json +3 -3
- package/src/lib/getProviderHelper.ts +270 -0
- package/src/lib/makeRequestOptions.ts +34 -91
- package/src/providers/black-forest-labs.ts +73 -22
- package/src/providers/cerebras.ts +6 -27
- package/src/providers/cohere.ts +9 -28
- package/src/providers/fal-ai.ts +195 -77
- package/src/providers/fireworks-ai.ts +8 -29
- package/src/providers/hf-inference.ts +555 -34
- package/src/providers/hyperbolic.ts +107 -29
- package/src/providers/nebius.ts +65 -29
- package/src/providers/novita.ts +68 -32
- package/src/providers/openai.ts +6 -32
- package/src/providers/providerHelper.ts +354 -0
- package/src/providers/replicate.ts +124 -34
- package/src/providers/sambanova.ts +5 -30
- package/src/providers/together.ts +92 -28
- package/src/snippets/getInferenceSnippets.ts +16 -9
- package/src/snippets/templates.exported.ts +1 -1
- package/src/tasks/audio/audioClassification.ts +4 -7
- package/src/tasks/audio/audioToAudio.ts +3 -26
- package/src/tasks/audio/automaticSpeechRecognition.ts +4 -3
- package/src/tasks/audio/textToSpeech.ts +5 -29
- package/src/tasks/audio/utils.ts +2 -1
- package/src/tasks/custom/request.ts +0 -2
- package/src/tasks/custom/streamingRequest.ts +0 -2
- package/src/tasks/cv/imageClassification.ts +3 -7
- package/src/tasks/cv/imageSegmentation.ts +3 -8
- package/src/tasks/cv/imageToImage.ts +3 -6
- package/src/tasks/cv/imageToText.ts +3 -6
- package/src/tasks/cv/objectDetection.ts +3 -18
- package/src/tasks/cv/textToImage.ts +9 -137
- package/src/tasks/cv/textToVideo.ts +11 -62
- package/src/tasks/cv/zeroShotImageClassification.ts +3 -7
- package/src/tasks/index.ts +6 -6
- package/src/tasks/multimodal/documentQuestionAnswering.ts +3 -19
- package/src/tasks/multimodal/visualQuestionAnswering.ts +3 -11
- package/src/tasks/nlp/chatCompletion.ts +5 -20
- package/src/tasks/nlp/chatCompletionStream.ts +1 -2
- package/src/tasks/nlp/featureExtraction.ts +3 -18
- package/src/tasks/nlp/fillMask.ts +3 -16
- package/src/tasks/nlp/questionAnswering.ts +3 -22
- package/src/tasks/nlp/sentenceSimilarity.ts +3 -7
- package/src/tasks/nlp/summarization.ts +3 -6
- package/src/tasks/nlp/tableQuestionAnswering.ts +3 -27
- package/src/tasks/nlp/textClassification.ts +3 -8
- package/src/tasks/nlp/textGeneration.ts +12 -79
- package/src/tasks/nlp/tokenClassification.ts +3 -18
- package/src/tasks/nlp/translation.ts +3 -6
- package/src/tasks/nlp/zeroShotClassification.ts +3 -16
- package/src/tasks/tabular/tabularClassification.ts +3 -6
- package/src/tasks/tabular/tabularRegression.ts +3 -6
- package/src/types.ts +3 -14
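
The headline change in 3.7.1 is a refactor of the provider layer, visible in the file list above: the per-provider `*_CONFIG` objects (with `makeBaseUrl`/`makeBody`/`makeHeaders`/`makeUrl` functions) are replaced by a class hierarchy rooted in `TaskProviderHelper`, introduced in the new files `src/providers/providerHelper.ts` and `src/lib/getProviderHelper.ts`. The sketch below illustrates the pattern in simplified form; the provider name "acme", its URL, and the reduced `UrlParams` shape are hypothetical, not part of the package.

// Minimal sketch of the TaskProviderHelper pattern (see the diff below).
// "acme", its base URL, and the trimmed-down UrlParams are illustrative only.
const HF_ROUTER_URL = "https://router.huggingface.co";

interface UrlParams {
  authMethod: "provider-key" | "hf-token";
  model: string;
}

abstract class TaskProviderHelper {
  constructor(
    readonly provider: string,
    readonly baseUrl: string,
    readonly clientSideRoutingOnly: boolean = false
  ) {}

  // Requests authenticated with an HF token are routed through the HF router;
  // a direct provider key hits the provider's own base URL (mirrors
  // makeBaseUrl in the diff below).
  makeBaseUrl(params: UrlParams): string {
    return params.authMethod !== "provider-key" ? `${HF_ROUTER_URL}/${this.provider}` : this.baseUrl;
  }

  abstract makeRoute(params: UrlParams): string;

  makeUrl(params: UrlParams): string {
    const route = this.makeRoute(params).replace(/^\/+/, "");
    return `${this.makeBaseUrl(params)}/${route}`;
  }
}

// A conversational provider now only declares its name, base URL, and route.
class AcmeConversationalTask extends TaskProviderHelper {
  constructor() {
    super("acme", "https://api.acme.example");
  }
  makeRoute(): string {
    return "v1/chat/completions";
  }
}

const helper = new AcmeConversationalTask();
console.log(helper.makeUrl({ authMethod: "provider-key", model: "acme/chat-1" }));
// -> https://api.acme.example/v1/chat/completions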
package/dist/index.js
CHANGED
|
@@ -41,442 +41,1088 @@ __export(tasks_exports, {
|
|
|
41
41
|
zeroShotImageClassification: () => zeroShotImageClassification
|
|
42
42
|
});
|
|
43
43
|
|
|
44
|
+
// package.json
|
|
45
|
+
var name = "@huggingface/inference";
|
|
46
|
+
var version = "3.7.1";
|
|
47
|
+
|
|
44
48
|
// src/config.ts
|
|
45
49
|
var HF_HUB_URL = "https://huggingface.co";
|
|
46
50
|
var HF_ROUTER_URL = "https://router.huggingface.co";
|
|
47
51
|
var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
|
|
48
52
|
|
|
49
|
-
// src/
|
|
50
|
-
var
|
|
51
|
-
|
|
52
|
-
|
|
53
|
+
// src/lib/InferenceOutputError.ts
|
|
54
|
+
var InferenceOutputError = class extends TypeError {
|
|
55
|
+
constructor(message) {
|
|
56
|
+
super(
|
|
57
|
+
`Invalid inference output: ${message}. Use the 'request' method with the same parameters to do a custom call with no type checking.`
|
|
58
|
+
);
|
|
59
|
+
this.name = "InferenceOutputError";
|
|
60
|
+
}
|
|
53
61
|
};
|
|
54
|
-
|
|
55
|
-
|
|
62
|
+
|
|
63
|
+
// src/utils/delay.ts
|
|
64
|
+
function delay(ms) {
|
|
65
|
+
return new Promise((resolve) => {
|
|
66
|
+
setTimeout(() => resolve(), ms);
|
|
67
|
+
});
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
// src/utils/pick.ts
|
|
71
|
+
function pick(o, props) {
|
|
72
|
+
return Object.assign(
|
|
73
|
+
{},
|
|
74
|
+
...props.map((prop) => {
|
|
75
|
+
if (o[prop] !== void 0) {
|
|
76
|
+
return { [prop]: o[prop] };
|
|
77
|
+
}
|
|
78
|
+
})
|
|
79
|
+
);
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
// src/utils/typedInclude.ts
|
|
83
|
+
function typedInclude(arr, v) {
|
|
84
|
+
return arr.includes(v);
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
// src/utils/omit.ts
|
|
88
|
+
function omit(o, props) {
|
|
89
|
+
const propsArr = Array.isArray(props) ? props : [props];
|
|
90
|
+
const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
|
|
91
|
+
return pick(o, letsKeep);
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
// src/utils/toArray.ts
|
|
95
|
+
function toArray(obj) {
|
|
96
|
+
if (Array.isArray(obj)) {
|
|
97
|
+
return obj;
|
|
98
|
+
}
|
|
99
|
+
return [obj];
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
// src/providers/providerHelper.ts
|
|
103
|
+
var TaskProviderHelper = class {
|
|
104
|
+
constructor(provider, baseUrl, clientSideRoutingOnly = false) {
|
|
105
|
+
this.provider = provider;
|
|
106
|
+
this.baseUrl = baseUrl;
|
|
107
|
+
this.clientSideRoutingOnly = clientSideRoutingOnly;
|
|
108
|
+
}
|
|
109
|
+
/**
|
|
110
|
+
* Prepare the base URL for the request
|
|
111
|
+
*/
|
|
112
|
+
makeBaseUrl(params) {
|
|
113
|
+
return params.authMethod !== "provider-key" ? `${HF_ROUTER_URL}/${this.provider}` : this.baseUrl;
|
|
114
|
+
}
|
|
115
|
+
/**
|
|
116
|
+
* Prepare the body for the request
|
|
117
|
+
*/
|
|
118
|
+
makeBody(params) {
|
|
119
|
+
if ("data" in params.args && !!params.args.data) {
|
|
120
|
+
return params.args.data;
|
|
121
|
+
}
|
|
122
|
+
return JSON.stringify(this.preparePayload(params));
|
|
123
|
+
}
|
|
124
|
+
/**
|
|
125
|
+
* Prepare the URL for the request
|
|
126
|
+
*/
|
|
127
|
+
makeUrl(params) {
|
|
128
|
+
const baseUrl = this.makeBaseUrl(params);
|
|
129
|
+
const route = this.makeRoute(params).replace(/^\/+/, "");
|
|
130
|
+
return `${baseUrl}/${route}`;
|
|
131
|
+
}
|
|
132
|
+
/**
|
|
133
|
+
* Prepare the headers for the request
|
|
134
|
+
*/
|
|
135
|
+
prepareHeaders(params, isBinary) {
|
|
136
|
+
const headers = { Authorization: `Bearer ${params.accessToken}` };
|
|
137
|
+
if (!isBinary) {
|
|
138
|
+
headers["Content-Type"] = "application/json";
|
|
139
|
+
}
|
|
140
|
+
return headers;
|
|
141
|
+
}
|
|
56
142
|
};
|
|
57
|
-
var
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
}
|
|
61
|
-
|
|
143
|
+
var BaseConversationalTask = class extends TaskProviderHelper {
|
|
144
|
+
constructor(provider, baseUrl, clientSideRoutingOnly = false) {
|
|
145
|
+
super(provider, baseUrl, clientSideRoutingOnly);
|
|
146
|
+
}
|
|
147
|
+
makeRoute() {
|
|
148
|
+
return "v1/chat/completions";
|
|
149
|
+
}
|
|
150
|
+
preparePayload(params) {
|
|
151
|
+
return {
|
|
152
|
+
...params.args,
|
|
153
|
+
model: params.model
|
|
154
|
+
};
|
|
155
|
+
}
|
|
156
|
+
async getResponse(response) {
|
|
157
|
+
if (typeof response === "object" && Array.isArray(response?.choices) && typeof response?.created === "number" && typeof response?.id === "string" && typeof response?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
|
|
158
|
+
(response.system_fingerprint === void 0 || response.system_fingerprint === null || typeof response.system_fingerprint === "string") && typeof response?.usage === "object") {
|
|
159
|
+
return response;
|
|
160
|
+
}
|
|
161
|
+
throw new InferenceOutputError("Expected ChatCompletionOutput");
|
|
62
162
|
}
|
|
63
163
|
};
|
|
64
|
-
var
|
|
65
|
-
|
|
164
|
+
var BaseTextGenerationTask = class extends TaskProviderHelper {
|
|
165
|
+
constructor(provider, baseUrl, clientSideRoutingOnly = false) {
|
|
166
|
+
super(provider, baseUrl, clientSideRoutingOnly);
|
|
167
|
+
}
|
|
168
|
+
preparePayload(params) {
|
|
169
|
+
return {
|
|
170
|
+
...params.args,
|
|
171
|
+
model: params.model
|
|
172
|
+
};
|
|
173
|
+
}
|
|
174
|
+
makeRoute() {
|
|
175
|
+
return "v1/completions";
|
|
176
|
+
}
|
|
177
|
+
async getResponse(response) {
|
|
178
|
+
const res = toArray(response);
|
|
179
|
+
if (Array.isArray(res) && res.length > 0 && res.every(
|
|
180
|
+
(x) => typeof x === "object" && !!x && "generated_text" in x && typeof x.generated_text === "string"
|
|
181
|
+
)) {
|
|
182
|
+
return res[0];
|
|
183
|
+
}
|
|
184
|
+
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
|
|
185
|
+
}
|
|
66
186
|
};
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
187
|
+
|
|
188
|
+
// src/providers/black-forest-labs.ts
|
|
189
|
+
var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
|
|
190
|
+
var BlackForestLabsTextToImageTask = class extends TaskProviderHelper {
|
|
191
|
+
constructor() {
|
|
192
|
+
super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
|
|
193
|
+
}
|
|
194
|
+
preparePayload(params) {
|
|
195
|
+
return {
|
|
196
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
197
|
+
...params.args.parameters,
|
|
198
|
+
prompt: params.args.inputs
|
|
199
|
+
};
|
|
200
|
+
}
|
|
201
|
+
prepareHeaders(params, binary) {
|
|
202
|
+
const headers = {
|
|
203
|
+
Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`
|
|
204
|
+
};
|
|
205
|
+
if (!binary) {
|
|
206
|
+
headers["Content-Type"] = "application/json";
|
|
207
|
+
}
|
|
208
|
+
return headers;
|
|
209
|
+
}
|
|
210
|
+
makeRoute(params) {
|
|
211
|
+
if (!params) {
|
|
212
|
+
throw new Error("Params are required");
|
|
213
|
+
}
|
|
214
|
+
return `/v1/${params.model}`;
|
|
215
|
+
}
|
|
216
|
+
async getResponse(response, url, headers, outputType) {
|
|
217
|
+
const urlObj = new URL(response.polling_url);
|
|
218
|
+
for (let step = 0; step < 5; step++) {
|
|
219
|
+
await delay(1e3);
|
|
220
|
+
console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
|
|
221
|
+
urlObj.searchParams.set("attempt", step.toString(10));
|
|
222
|
+
const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
|
|
223
|
+
if (!resp.ok) {
|
|
224
|
+
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
|
|
225
|
+
}
|
|
226
|
+
const payload = await resp.json();
|
|
227
|
+
if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
|
|
228
|
+
if (outputType === "url") {
|
|
229
|
+
return payload.result.sample;
|
|
230
|
+
}
|
|
231
|
+
const image = await fetch(payload.result.sample);
|
|
232
|
+
return await image.blob();
|
|
233
|
+
}
|
|
234
|
+
}
|
|
235
|
+
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
|
|
236
|
+
}
|
|
72
237
|
};
|
|
73
238
|
|
|
74
239
|
// src/providers/cerebras.ts
|
|
75
|
-
var
|
|
76
|
-
|
|
77
|
-
|
|
240
|
+
var CerebrasConversationalTask = class extends BaseConversationalTask {
|
|
241
|
+
constructor() {
|
|
242
|
+
super("cerebras", "https://api.cerebras.ai");
|
|
243
|
+
}
|
|
78
244
|
};
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
245
|
+
|
|
246
|
+
// src/providers/cohere.ts
|
|
247
|
+
var CohereConversationalTask = class extends BaseConversationalTask {
|
|
248
|
+
constructor() {
|
|
249
|
+
super("cohere", "https://api.cohere.com");
|
|
250
|
+
}
|
|
251
|
+
makeRoute() {
|
|
252
|
+
return "/compatibility/v1/chat/completions";
|
|
253
|
+
}
|
|
254
|
+
};
|
|
255
|
+
|
|
256
|
+
// src/lib/isUrl.ts
|
|
257
|
+
function isUrl(modelOrUrl) {
|
|
258
|
+
return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
|
|
259
|
+
}
|
|
260
|
+
|
|
261
|
+
// src/providers/fal-ai.ts
|
|
262
|
+
var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
|
|
263
|
+
var FalAITask = class extends TaskProviderHelper {
|
|
264
|
+
constructor(url) {
|
|
265
|
+
super("fal-ai", url || "https://fal.run");
|
|
266
|
+
}
|
|
267
|
+
preparePayload(params) {
|
|
268
|
+
return params.args;
|
|
269
|
+
}
|
|
270
|
+
makeRoute(params) {
|
|
271
|
+
return `/${params.model}`;
|
|
272
|
+
}
|
|
273
|
+
prepareHeaders(params, binary) {
|
|
274
|
+
const headers = {
|
|
275
|
+
Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}`
|
|
276
|
+
};
|
|
277
|
+
if (!binary) {
|
|
278
|
+
headers["Content-Type"] = "application/json";
|
|
279
|
+
}
|
|
280
|
+
return headers;
|
|
281
|
+
}
|
|
282
|
+
};
|
|
283
|
+
var FalAITextToImageTask = class extends FalAITask {
|
|
284
|
+
preparePayload(params) {
|
|
285
|
+
return {
|
|
286
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
287
|
+
...params.args.parameters,
|
|
288
|
+
sync_mode: true,
|
|
289
|
+
prompt: params.args.inputs
|
|
290
|
+
};
|
|
291
|
+
}
|
|
292
|
+
async getResponse(response, outputType) {
|
|
293
|
+
if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images.length > 0 && "url" in response.images[0] && typeof response.images[0].url === "string") {
|
|
294
|
+
if (outputType === "url") {
|
|
295
|
+
return response.images[0].url;
|
|
296
|
+
}
|
|
297
|
+
const urlResponse = await fetch(response.images[0].url);
|
|
298
|
+
return await urlResponse.blob();
|
|
299
|
+
}
|
|
300
|
+
throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
|
|
301
|
+
}
|
|
302
|
+
};
|
|
303
|
+
var FalAITextToVideoTask = class extends FalAITask {
|
|
304
|
+
constructor() {
|
|
305
|
+
super("https://queue.fal.run");
|
|
306
|
+
}
|
|
307
|
+
makeRoute(params) {
|
|
308
|
+
if (params.authMethod !== "provider-key") {
|
|
309
|
+
return `/${params.model}?_subdomain=queue`;
|
|
310
|
+
}
|
|
311
|
+
return `/${params.model}`;
|
|
312
|
+
}
|
|
313
|
+
preparePayload(params) {
|
|
314
|
+
return {
|
|
315
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
316
|
+
...params.args.parameters,
|
|
317
|
+
prompt: params.args.inputs
|
|
318
|
+
};
|
|
319
|
+
}
|
|
320
|
+
async getResponse(response, url, headers) {
|
|
321
|
+
if (!url || !headers) {
|
|
322
|
+
throw new InferenceOutputError("URL and headers are required for text-to-video task");
|
|
323
|
+
}
|
|
324
|
+
const requestId = response.request_id;
|
|
325
|
+
if (!requestId) {
|
|
326
|
+
throw new InferenceOutputError("No request ID found in the response");
|
|
327
|
+
}
|
|
328
|
+
let status = response.status;
|
|
329
|
+
const parsedUrl = new URL(url);
|
|
330
|
+
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
|
|
331
|
+
const modelId = new URL(response.response_url).pathname;
|
|
332
|
+
const queryParams = parsedUrl.search;
|
|
333
|
+
const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
|
|
334
|
+
const resultUrl = `${baseUrl}${modelId}${queryParams}`;
|
|
335
|
+
while (status !== "COMPLETED") {
|
|
336
|
+
await delay(500);
|
|
337
|
+
const statusResponse = await fetch(statusUrl, { headers });
|
|
338
|
+
if (!statusResponse.ok) {
|
|
339
|
+
throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
|
|
340
|
+
}
|
|
341
|
+
try {
|
|
342
|
+
status = (await statusResponse.json()).status;
|
|
343
|
+
} catch (error) {
|
|
344
|
+
throw new InferenceOutputError("Failed to parse status response from fal-ai API");
|
|
345
|
+
}
|
|
346
|
+
}
|
|
347
|
+
const resultResponse = await fetch(resultUrl, { headers });
|
|
348
|
+
let result;
|
|
349
|
+
try {
|
|
350
|
+
result = await resultResponse.json();
|
|
351
|
+
} catch (error) {
|
|
352
|
+
throw new InferenceOutputError("Failed to parse result response from fal-ai API");
|
|
353
|
+
}
|
|
354
|
+
if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
|
|
355
|
+
const urlResponse = await fetch(result.video.url);
|
|
356
|
+
return await urlResponse.blob();
|
|
357
|
+
} else {
|
|
358
|
+
throw new InferenceOutputError(
|
|
359
|
+
"Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
|
|
360
|
+
);
|
|
361
|
+
}
|
|
362
|
+
}
|
|
84
363
|
};
|
|
85
|
-
var
|
|
86
|
-
|
|
364
|
+
var FalAIAutomaticSpeechRecognitionTask = class extends FalAITask {
|
|
365
|
+
prepareHeaders(params, binary) {
|
|
366
|
+
const headers = super.prepareHeaders(params, binary);
|
|
367
|
+
headers["Content-Type"] = "application/json";
|
|
368
|
+
return headers;
|
|
369
|
+
}
|
|
370
|
+
async getResponse(response) {
|
|
371
|
+
const res = response;
|
|
372
|
+
if (typeof res?.text !== "string") {
|
|
373
|
+
throw new InferenceOutputError(
|
|
374
|
+
`Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
|
|
375
|
+
);
|
|
376
|
+
}
|
|
377
|
+
return { text: res.text };
|
|
378
|
+
}
|
|
87
379
|
};
|
|
88
|
-
var
|
|
89
|
-
|
|
380
|
+
var FalAITextToSpeechTask = class extends FalAITask {
|
|
381
|
+
preparePayload(params) {
|
|
382
|
+
return {
|
|
383
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
384
|
+
...params.args.parameters,
|
|
385
|
+
lyrics: params.args.inputs
|
|
386
|
+
};
|
|
387
|
+
}
|
|
388
|
+
async getResponse(response) {
|
|
389
|
+
const res = response;
|
|
390
|
+
if (typeof res?.audio?.url !== "string") {
|
|
391
|
+
throw new InferenceOutputError(
|
|
392
|
+
`Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
|
|
393
|
+
);
|
|
394
|
+
}
|
|
395
|
+
try {
|
|
396
|
+
const urlResponse = await fetch(res.audio.url);
|
|
397
|
+
if (!urlResponse.ok) {
|
|
398
|
+
throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`);
|
|
399
|
+
}
|
|
400
|
+
return await urlResponse.blob();
|
|
401
|
+
} catch (error) {
|
|
402
|
+
throw new InferenceOutputError(
|
|
403
|
+
`Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${error instanceof Error ? error.message : String(error)}`
|
|
404
|
+
);
|
|
405
|
+
}
|
|
406
|
+
}
|
|
90
407
|
};
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
408
|
+
|
|
409
|
+
// src/providers/fireworks-ai.ts
|
|
410
|
+
var FireworksConversationalTask = class extends BaseConversationalTask {
|
|
411
|
+
constructor() {
|
|
412
|
+
super("fireworks-ai", "https://api.fireworks.ai");
|
|
413
|
+
}
|
|
414
|
+
makeRoute() {
|
|
415
|
+
return "/inference/v1/chat/completions";
|
|
416
|
+
}
|
|
96
417
|
};
|
|
97
418
|
|
|
98
|
-
// src/providers/
|
|
99
|
-
var
|
|
100
|
-
|
|
101
|
-
|
|
419
|
+
// src/providers/hf-inference.ts
|
|
420
|
+
var HFInferenceTask = class extends TaskProviderHelper {
|
|
421
|
+
constructor() {
|
|
422
|
+
super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
|
|
423
|
+
}
|
|
424
|
+
preparePayload(params) {
|
|
425
|
+
return params.args;
|
|
426
|
+
}
|
|
427
|
+
makeUrl(params) {
|
|
428
|
+
if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
|
|
429
|
+
return params.model;
|
|
430
|
+
}
|
|
431
|
+
return super.makeUrl(params);
|
|
432
|
+
}
|
|
433
|
+
makeRoute(params) {
|
|
434
|
+
if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
|
|
435
|
+
return `pipeline/${params.task}/${params.model}`;
|
|
436
|
+
}
|
|
437
|
+
return `models/${params.model}`;
|
|
438
|
+
}
|
|
439
|
+
async getResponse(response) {
|
|
440
|
+
return response;
|
|
441
|
+
}
|
|
102
442
|
};
|
|
103
|
-
var
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
443
|
+
var HFInferenceTextToImageTask = class extends HFInferenceTask {
|
|
444
|
+
async getResponse(response, url, headers, outputType) {
|
|
445
|
+
if (!response) {
|
|
446
|
+
throw new InferenceOutputError("response is undefined");
|
|
447
|
+
}
|
|
448
|
+
if (typeof response == "object") {
|
|
449
|
+
if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
|
|
450
|
+
const base64Data = response.data[0].b64_json;
|
|
451
|
+
if (outputType === "url") {
|
|
452
|
+
return `data:image/jpeg;base64,${base64Data}`;
|
|
453
|
+
}
|
|
454
|
+
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
|
|
455
|
+
return await base64Response.blob();
|
|
456
|
+
}
|
|
457
|
+
if ("output" in response && Array.isArray(response.output)) {
|
|
458
|
+
if (outputType === "url") {
|
|
459
|
+
return response.output[0];
|
|
460
|
+
}
|
|
461
|
+
const urlResponse = await fetch(response.output[0]);
|
|
462
|
+
const blob = await urlResponse.blob();
|
|
463
|
+
return blob;
|
|
464
|
+
}
|
|
465
|
+
}
|
|
466
|
+
if (response instanceof Blob) {
|
|
467
|
+
if (outputType === "url") {
|
|
468
|
+
const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
|
|
469
|
+
return `data:image/jpeg;base64,${b64}`;
|
|
470
|
+
}
|
|
471
|
+
return response;
|
|
472
|
+
}
|
|
473
|
+
throw new InferenceOutputError("Expected a Blob ");
|
|
474
|
+
}
|
|
475
|
+
};
|
|
476
|
+
var HFInferenceConversationalTask = class extends HFInferenceTask {
|
|
477
|
+
makeUrl(params) {
|
|
478
|
+
let url;
|
|
479
|
+
if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
|
|
480
|
+
url = params.model.trim();
|
|
481
|
+
} else {
|
|
482
|
+
url = `${this.makeBaseUrl(params)}/models/${params.model}`;
|
|
483
|
+
}
|
|
484
|
+
url = url.replace(/\/+$/, "");
|
|
485
|
+
if (url.endsWith("/v1")) {
|
|
486
|
+
url += "/chat/completions";
|
|
487
|
+
} else if (!url.endsWith("/chat/completions")) {
|
|
488
|
+
url += "/v1/chat/completions";
|
|
489
|
+
}
|
|
490
|
+
return url;
|
|
491
|
+
}
|
|
492
|
+
preparePayload(params) {
|
|
493
|
+
return {
|
|
494
|
+
...params.args,
|
|
495
|
+
model: params.model
|
|
496
|
+
};
|
|
497
|
+
}
|
|
498
|
+
async getResponse(response) {
|
|
499
|
+
return response;
|
|
500
|
+
}
|
|
108
501
|
};
|
|
109
|
-
var
|
|
110
|
-
|
|
502
|
+
var HFInferenceTextGenerationTask = class extends HFInferenceTask {
|
|
503
|
+
async getResponse(response) {
|
|
504
|
+
const res = toArray(response);
|
|
505
|
+
if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) {
|
|
506
|
+
return res?.[0];
|
|
507
|
+
}
|
|
508
|
+
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
|
|
509
|
+
}
|
|
111
510
|
};
|
|
112
|
-
var
|
|
113
|
-
|
|
511
|
+
var HFInferenceAudioClassificationTask = class extends HFInferenceTask {
|
|
512
|
+
async getResponse(response) {
|
|
513
|
+
if (Array.isArray(response) && response.every(
|
|
514
|
+
(x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
|
|
515
|
+
)) {
|
|
516
|
+
return response;
|
|
517
|
+
}
|
|
518
|
+
throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
|
|
519
|
+
}
|
|
114
520
|
};
|
|
115
|
-
var
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
makeUrl: makeUrl3
|
|
521
|
+
var HFInferenceAutomaticSpeechRecognitionTask = class extends HFInferenceTask {
|
|
522
|
+
async getResponse(response) {
|
|
523
|
+
return response;
|
|
524
|
+
}
|
|
120
525
|
};
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
526
|
+
var HFInferenceAudioToAudioTask = class extends HFInferenceTask {
|
|
527
|
+
async getResponse(response) {
|
|
528
|
+
if (!Array.isArray(response)) {
|
|
529
|
+
throw new InferenceOutputError("Expected Array");
|
|
530
|
+
}
|
|
531
|
+
if (!response.every((elem) => {
|
|
532
|
+
return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
|
|
533
|
+
})) {
|
|
534
|
+
throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
|
|
535
|
+
}
|
|
536
|
+
return response;
|
|
537
|
+
}
|
|
538
|
+
};
|
|
539
|
+
var HFInferenceDocumentQuestionAnsweringTask = class extends HFInferenceTask {
|
|
540
|
+
async getResponse(response) {
|
|
541
|
+
if (Array.isArray(response) && response.every(
|
|
542
|
+
(elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
|
|
543
|
+
)) {
|
|
544
|
+
return response[0];
|
|
545
|
+
}
|
|
546
|
+
throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
|
|
547
|
+
}
|
|
548
|
+
};
|
|
549
|
+
var HFInferenceFeatureExtractionTask = class extends HFInferenceTask {
|
|
550
|
+
async getResponse(response) {
|
|
551
|
+
const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
|
|
552
|
+
if (curDepth > maxDepth)
|
|
553
|
+
return false;
|
|
554
|
+
if (arr.every((x) => Array.isArray(x))) {
|
|
555
|
+
return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
|
|
556
|
+
} else {
|
|
557
|
+
return arr.every((x) => typeof x === "number");
|
|
558
|
+
}
|
|
559
|
+
};
|
|
560
|
+
if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) {
|
|
561
|
+
return response;
|
|
562
|
+
}
|
|
563
|
+
throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
|
|
564
|
+
}
|
|
565
|
+
};
|
|
566
|
+
var HFInferenceImageClassificationTask = class extends HFInferenceTask {
|
|
567
|
+
async getResponse(response) {
|
|
568
|
+
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
|
|
569
|
+
return response;
|
|
570
|
+
}
|
|
571
|
+
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
572
|
+
}
|
|
573
|
+
};
|
|
574
|
+
var HFInferenceImageSegmentationTask = class extends HFInferenceTask {
|
|
575
|
+
async getResponse(response) {
|
|
576
|
+
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number")) {
|
|
577
|
+
return response;
|
|
578
|
+
}
|
|
579
|
+
throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
|
|
580
|
+
}
|
|
581
|
+
};
|
|
582
|
+
var HFInferenceImageToTextTask = class extends HFInferenceTask {
|
|
583
|
+
async getResponse(response) {
|
|
584
|
+
if (typeof response?.generated_text !== "string") {
|
|
585
|
+
throw new InferenceOutputError("Expected {generated_text: string}");
|
|
586
|
+
}
|
|
587
|
+
return response;
|
|
588
|
+
}
|
|
589
|
+
};
|
|
590
|
+
var HFInferenceImageToImageTask = class extends HFInferenceTask {
|
|
591
|
+
async getResponse(response) {
|
|
592
|
+
if (response instanceof Blob) {
|
|
593
|
+
return response;
|
|
594
|
+
}
|
|
595
|
+
throw new InferenceOutputError("Expected Blob");
|
|
596
|
+
}
|
|
597
|
+
};
|
|
598
|
+
var HFInferenceObjectDetectionTask = class extends HFInferenceTask {
|
|
599
|
+
async getResponse(response) {
|
|
600
|
+
if (Array.isArray(response) && response.every(
|
|
601
|
+
(x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
|
|
602
|
+
)) {
|
|
603
|
+
return response;
|
|
604
|
+
}
|
|
605
|
+
throw new InferenceOutputError(
|
|
606
|
+
"Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
|
|
127
607
|
);
|
|
128
|
-
this.name = "InferenceOutputError";
|
|
129
608
|
}
|
|
130
609
|
};
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
}
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
var makeBaseUrl4 = (task) => {
|
|
148
|
-
return task === "text-to-video" ? FAL_AI_API_BASE_URL_QUEUE : FAL_AI_API_BASE_URL;
|
|
610
|
+
var HFInferenceZeroShotImageClassificationTask = class extends HFInferenceTask {
|
|
611
|
+
async getResponse(response) {
|
|
612
|
+
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
|
|
613
|
+
return response;
|
|
614
|
+
}
|
|
615
|
+
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
616
|
+
}
|
|
617
|
+
};
|
|
618
|
+
var HFInferenceTextClassificationTask = class extends HFInferenceTask {
|
|
619
|
+
async getResponse(response) {
|
|
620
|
+
const output = response?.[0];
|
|
621
|
+
if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
|
|
622
|
+
return output;
|
|
623
|
+
}
|
|
624
|
+
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
625
|
+
}
|
|
149
626
|
};
|
|
150
|
-
var
|
|
151
|
-
|
|
627
|
+
var HFInferenceQuestionAnsweringTask = class extends HFInferenceTask {
|
|
628
|
+
async getResponse(response) {
|
|
629
|
+
if (Array.isArray(response) ? response.every(
|
|
630
|
+
(elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
|
|
631
|
+
) : typeof response === "object" && !!response && typeof response.answer === "string" && typeof response.end === "number" && typeof response.score === "number" && typeof response.start === "number") {
|
|
632
|
+
return Array.isArray(response) ? response[0] : response;
|
|
633
|
+
}
|
|
634
|
+
throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
|
|
635
|
+
}
|
|
152
636
|
};
|
|
153
|
-
var
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
637
|
+
var HFInferenceFillMaskTask = class extends HFInferenceTask {
|
|
638
|
+
async getResponse(response) {
|
|
639
|
+
if (Array.isArray(response) && response.every(
|
|
640
|
+
(x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
|
|
641
|
+
)) {
|
|
642
|
+
return response;
|
|
643
|
+
}
|
|
644
|
+
throw new InferenceOutputError(
|
|
645
|
+
"Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
|
|
646
|
+
);
|
|
647
|
+
}
|
|
157
648
|
};
|
|
158
|
-
var
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
};
|
|
165
|
-
var FAL_AI_CONFIG = {
|
|
166
|
-
makeBaseUrl: makeBaseUrl4,
|
|
167
|
-
makeBody: makeBody4,
|
|
168
|
-
makeHeaders: makeHeaders4,
|
|
169
|
-
makeUrl: makeUrl4
|
|
170
|
-
};
|
|
171
|
-
async function pollFalResponse(res, url, headers) {
|
|
172
|
-
const requestId = res.request_id;
|
|
173
|
-
if (!requestId) {
|
|
174
|
-
throw new InferenceOutputError("No request ID found in the response");
|
|
175
|
-
}
|
|
176
|
-
let status = res.status;
|
|
177
|
-
const parsedUrl = new URL(url);
|
|
178
|
-
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
|
|
179
|
-
const modelId = new URL(res.response_url).pathname;
|
|
180
|
-
const queryParams = parsedUrl.search;
|
|
181
|
-
const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
|
|
182
|
-
const resultUrl = `${baseUrl}${modelId}${queryParams}`;
|
|
183
|
-
while (status !== "COMPLETED") {
|
|
184
|
-
await delay(500);
|
|
185
|
-
const statusResponse = await fetch(statusUrl, { headers });
|
|
186
|
-
if (!statusResponse.ok) {
|
|
187
|
-
throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
|
|
649
|
+
var HFInferenceZeroShotClassificationTask = class extends HFInferenceTask {
|
|
650
|
+
async getResponse(response) {
|
|
651
|
+
if (Array.isArray(response) && response.every(
|
|
652
|
+
(x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
|
|
653
|
+
)) {
|
|
654
|
+
return response;
|
|
188
655
|
}
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
656
|
+
throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
|
|
657
|
+
}
|
|
658
|
+
};
|
|
659
|
+
var HFInferenceSentenceSimilarityTask = class extends HFInferenceTask {
|
|
660
|
+
async getResponse(response) {
|
|
661
|
+
if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
|
|
662
|
+
return response;
|
|
193
663
|
}
|
|
664
|
+
throw new InferenceOutputError("Expected Array<number>");
|
|
194
665
|
}
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
666
|
+
};
|
|
667
|
+
var HFInferenceTableQuestionAnsweringTask = class extends HFInferenceTask {
|
|
668
|
+
static validate(elem) {
|
|
669
|
+
return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
|
|
670
|
+
(coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
|
|
671
|
+
);
|
|
201
672
|
}
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
673
|
+
async getResponse(response) {
|
|
674
|
+
if (Array.isArray(response) && Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response)) {
|
|
675
|
+
return Array.isArray(response) ? response[0] : response;
|
|
676
|
+
}
|
|
206
677
|
throw new InferenceOutputError(
|
|
207
|
-
"Expected {
|
|
678
|
+
"Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
|
|
208
679
|
);
|
|
209
680
|
}
|
|
210
|
-
}
|
|
211
|
-
|
|
212
|
-
// src/providers/fireworks-ai.ts
|
|
213
|
-
var FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai";
|
|
214
|
-
var makeBaseUrl5 = () => {
|
|
215
|
-
return FIREWORKS_AI_API_BASE_URL;
|
|
216
|
-
};
|
|
217
|
-
var makeBody5 = (params) => {
|
|
218
|
-
return {
|
|
219
|
-
...params.args,
|
|
220
|
-
...params.chatCompletion ? { model: params.model } : void 0
|
|
221
|
-
};
|
|
222
681
|
};
|
|
223
|
-
var
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
682
|
+
var HFInferenceTokenClassificationTask = class extends HFInferenceTask {
|
|
683
|
+
async getResponse(response) {
|
|
684
|
+
if (Array.isArray(response) && response.every(
|
|
685
|
+
(x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
|
|
686
|
+
)) {
|
|
687
|
+
return response;
|
|
688
|
+
}
|
|
689
|
+
throw new InferenceOutputError(
|
|
690
|
+
"Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
|
|
691
|
+
);
|
|
229
692
|
}
|
|
230
|
-
return `${params.baseUrl}/inference`;
|
|
231
693
|
};
|
|
232
|
-
var
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
694
|
+
var HFInferenceTranslationTask = class extends HFInferenceTask {
|
|
695
|
+
async getResponse(response) {
|
|
696
|
+
if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) {
|
|
697
|
+
return response?.length === 1 ? response?.[0] : response;
|
|
698
|
+
}
|
|
699
|
+
throw new InferenceOutputError("Expected Array<{translation_text: string}>");
|
|
700
|
+
}
|
|
237
701
|
};
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
702
|
+
var HFInferenceSummarizationTask = class extends HFInferenceTask {
|
|
703
|
+
async getResponse(response) {
|
|
704
|
+
if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) {
|
|
705
|
+
return response?.[0];
|
|
706
|
+
}
|
|
707
|
+
throw new InferenceOutputError("Expected Array<{summary_text: string}>");
|
|
708
|
+
}
|
|
242
709
|
};
|
|
243
|
-
var
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
};
|
|
710
|
+
var HFInferenceTextToSpeechTask = class extends HFInferenceTask {
|
|
711
|
+
async getResponse(response) {
|
|
712
|
+
return response;
|
|
713
|
+
}
|
|
248
714
|
};
|
|
249
|
-
var
|
|
250
|
-
|
|
715
|
+
var HFInferenceTabularClassificationTask = class extends HFInferenceTask {
|
|
716
|
+
async getResponse(response) {
|
|
717
|
+
if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
|
|
718
|
+
return response;
|
|
719
|
+
}
|
|
720
|
+
throw new InferenceOutputError("Expected Array<number>");
|
|
721
|
+
}
|
|
251
722
|
};
|
|
252
|
-
var
|
|
253
|
-
|
|
254
|
-
|
|
723
|
+
var HFInferenceVisualQuestionAnsweringTask = class extends HFInferenceTask {
|
|
724
|
+
async getResponse(response) {
|
|
725
|
+
if (Array.isArray(response) && response.every(
|
|
726
|
+
(elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
|
|
727
|
+
)) {
|
|
728
|
+
return response[0];
|
|
729
|
+
}
|
|
730
|
+
throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
|
|
255
731
|
}
|
|
256
|
-
|
|
257
|
-
|
|
732
|
+
};
|
|
733
|
+
var HFInferenceTabularRegressionTask = class extends HFInferenceTask {
|
|
734
|
+
async getResponse(response) {
|
|
735
|
+
if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
|
|
736
|
+
return response;
|
|
737
|
+
}
|
|
738
|
+
throw new InferenceOutputError("Expected Array<number>");
|
|
258
739
|
}
|
|
259
|
-
return `${params.baseUrl}/models/${params.model}`;
|
|
260
740
|
};
|
|
261
|
-
var
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
makeUrl: makeUrl6
|
|
741
|
+
var HFInferenceTextToAudioTask = class extends HFInferenceTask {
|
|
742
|
+
async getResponse(response) {
|
|
743
|
+
return response;
|
|
744
|
+
}
|
|
266
745
|
};
|
|
267
746
|
|
|
268
747
|
// src/providers/hyperbolic.ts
|
|
269
748
|
var HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
|
|
270
|
-
var
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
return {
|
|
275
|
-
...params.args,
|
|
276
|
-
...params.task === "text-to-image" ? { model_name: params.model } : { model: params.model }
|
|
277
|
-
};
|
|
278
|
-
};
|
|
279
|
-
var makeHeaders7 = (params) => {
|
|
280
|
-
return { Authorization: `Bearer ${params.accessToken}` };
|
|
749
|
+
var HyperbolicConversationalTask = class extends BaseConversationalTask {
|
|
750
|
+
constructor() {
|
|
751
|
+
super("hyperbolic", HYPERBOLIC_API_BASE_URL);
|
|
752
|
+
}
|
|
281
753
|
};
|
|
282
|
-
var
|
|
283
|
-
|
|
284
|
-
|
|
754
|
+
var HyperbolicTextGenerationTask = class extends BaseTextGenerationTask {
|
|
755
|
+
constructor() {
|
|
756
|
+
super("hyperbolic", HYPERBOLIC_API_BASE_URL);
|
|
757
|
+
}
|
|
758
|
+
makeRoute() {
|
|
759
|
+
return "v1/chat/completions";
|
|
760
|
+
}
|
|
761
|
+
preparePayload(params) {
|
|
762
|
+
return {
|
|
763
|
+
messages: [{ content: params.args.inputs, role: "user" }],
|
|
764
|
+
...params.args.parameters ? {
|
|
765
|
+
max_tokens: params.args.parameters.max_new_tokens,
|
|
766
|
+
...omit(params.args.parameters, "max_new_tokens")
|
|
767
|
+
} : void 0,
|
|
768
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
769
|
+
model: params.model
|
|
770
|
+
};
|
|
771
|
+
}
|
|
772
|
+
async getResponse(response) {
|
|
773
|
+
if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
|
|
774
|
+
const completion = response.choices[0];
|
|
775
|
+
return {
|
|
776
|
+
generated_text: completion.message.content
|
|
777
|
+
};
|
|
778
|
+
}
|
|
779
|
+
throw new InferenceOutputError("Expected Hyperbolic text generation response format");
|
|
285
780
|
}
|
|
286
|
-
return `${params.baseUrl}/v1/chat/completions`;
|
|
287
781
|
};
|
|
288
|
-
var
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
782
|
+
var HyperbolicTextToImageTask = class extends TaskProviderHelper {
|
|
783
|
+
constructor() {
|
|
784
|
+
super("hyperbolic", HYPERBOLIC_API_BASE_URL);
|
|
785
|
+
}
|
|
786
|
+
makeRoute(params) {
|
|
787
|
+
return `/v1/images/generations`;
|
|
788
|
+
}
|
|
789
|
+
preparePayload(params) {
|
|
790
|
+
return {
|
|
791
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
792
|
+
...params.args.parameters,
|
|
793
|
+
prompt: params.args.inputs,
|
|
794
|
+
model_name: params.model
|
|
795
|
+
};
|
|
796
|
+
}
|
|
797
|
+
async getResponse(response, url, headers, outputType) {
|
|
798
|
+
if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images[0] && typeof response.images[0].image === "string") {
|
|
799
|
+
if (outputType === "url") {
|
|
800
|
+
return `data:image/jpeg;base64,${response.images[0].image}`;
|
|
801
|
+
}
|
|
802
|
+
return fetch(`data:image/jpeg;base64,${response.images[0].image}`).then((res) => res.blob());
|
|
803
|
+
}
|
|
804
|
+
throw new InferenceOutputError("Expected Hyperbolic text-to-image response format");
|
|
805
|
+
}
|
|
293
806
|
};
|
|
294
807
|
|
|
295
808
|
// src/providers/nebius.ts
|
|
296
809
|
var NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
|
|
297
|
-
var
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
return {
|
|
302
|
-
...params.args,
|
|
303
|
-
model: params.model
|
|
304
|
-
};
|
|
810
|
+
var NebiusConversationalTask = class extends BaseConversationalTask {
|
|
811
|
+
constructor() {
|
|
812
|
+
super("nebius", NEBIUS_API_BASE_URL);
|
|
813
|
+
}
|
|
305
814
|
};
|
|
306
|
-
var
|
|
307
|
-
|
|
815
|
+
var NebiusTextGenerationTask = class extends BaseTextGenerationTask {
|
|
816
|
+
constructor() {
|
|
817
|
+
super("nebius", NEBIUS_API_BASE_URL);
|
|
818
|
+
}
|
|
308
819
|
};
|
|
309
|
-
var
|
|
310
|
-
|
|
311
|
-
|
|
820
|
+
var NebiusTextToImageTask = class extends TaskProviderHelper {
|
|
821
|
+
constructor() {
|
|
822
|
+
super("nebius", NEBIUS_API_BASE_URL);
|
|
823
|
+
}
|
|
824
|
+
preparePayload(params) {
|
|
825
|
+
return {
|
|
826
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
827
|
+
...params.args.parameters,
|
|
828
|
+
response_format: "b64_json",
|
|
829
|
+
prompt: params.args.inputs,
|
|
830
|
+
model: params.model
|
|
831
|
+
};
|
|
312
832
|
}
|
|
313
|
-
|
|
314
|
-
return
|
|
833
|
+
makeRoute(params) {
|
|
834
|
+
return "v1/images/generations";
|
|
315
835
|
}
|
|
316
|
-
|
|
317
|
-
|
|
836
|
+
async getResponse(response, url, headers, outputType) {
|
|
837
|
+
if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
|
|
838
|
+
const base64Data = response.data[0].b64_json;
|
|
839
|
+
if (outputType === "url") {
|
|
840
|
+
return `data:image/jpeg;base64,${base64Data}`;
|
|
841
|
+
}
|
|
842
|
+
return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
|
|
843
|
+
}
|
|
844
|
+
throw new InferenceOutputError("Expected Nebius text-to-image response format");
|
|
318
845
|
}
|
|
319
|
-
return params.baseUrl;
|
|
320
|
-
};
|
|
321
|
-
var NEBIUS_CONFIG = {
|
|
322
|
-
makeBaseUrl: makeBaseUrl8,
|
|
323
|
-
makeBody: makeBody8,
|
|
324
|
-
makeHeaders: makeHeaders8,
|
|
325
|
-
makeUrl: makeUrl8
|
|
326
846
|
};
|
|
327
847
|
|
|
328
848
|
// src/providers/novita.ts
|
|
329
849
|
var NOVITA_API_BASE_URL = "https://api.novita.ai";
|
|
330
|
-
var
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
};
|
|
338
|
-
};
|
|
339
|
-
var makeHeaders9 = (params) => {
|
|
340
|
-
return { Authorization: `Bearer ${params.accessToken}` };
|
|
850
|
+
var NovitaTextGenerationTask = class extends BaseTextGenerationTask {
|
|
851
|
+
constructor() {
|
|
852
|
+
super("novita", NOVITA_API_BASE_URL);
|
|
853
|
+
}
|
|
854
|
+
makeRoute() {
|
|
855
|
+
return "/v3/openai/chat/completions";
|
|
856
|
+
}
|
|
341
857
|
};
|
|
342
|
-
var
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
}
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
return `${params.baseUrl}/v3/hf/${params.model}`;
|
|
858
|
+
var NovitaConversationalTask = class extends BaseConversationalTask {
|
|
859
|
+
constructor() {
|
|
860
|
+
super("novita", NOVITA_API_BASE_URL);
|
|
861
|
+
}
|
|
862
|
+
makeRoute() {
|
|
863
|
+
return "/v3/openai/chat/completions";
|
|
349
864
|
}
|
|
350
|
-
return params.baseUrl;
|
|
351
865
|
};
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
866
|
+
|
|
867
|
+
// src/providers/openai.ts
|
|
868
|
+
var OPENAI_API_BASE_URL = "https://api.openai.com";
|
|
869
|
+
var OpenAIConversationalTask = class extends BaseConversationalTask {
|
|
870
|
+
constructor() {
|
|
871
|
+
super("openai", OPENAI_API_BASE_URL, true);
|
|
872
|
+
}
|
|
357
873
|
};
|
|
358
874
|
|
|
359
875
|
// src/providers/replicate.ts
|
|
360
|
-
var
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
}
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
876
|
+
var ReplicateTask = class extends TaskProviderHelper {
|
|
877
|
+
constructor(url) {
|
|
878
|
+
super("replicate", url || "https://api.replicate.com");
|
|
879
|
+
}
|
|
880
|
+
makeRoute(params) {
|
|
881
|
+
if (params.model.includes(":")) {
|
|
882
|
+
return "v1/predictions";
|
|
883
|
+
}
|
|
884
|
+
return `v1/models/${params.model}/predictions`;
|
|
885
|
+
}
|
|
886
|
+
preparePayload(params) {
|
|
887
|
+
return {
|
|
888
|
+
input: {
|
|
889
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
890
|
+
...params.args.parameters,
|
|
891
|
+
prompt: params.args.inputs
|
|
892
|
+
},
|
|
893
|
+
version: params.model.includes(":") ? params.model.split(":")[1] : void 0
|
|
894
|
+
};
|
|
895
|
+
}
|
|
896
|
+
prepareHeaders(params, binary) {
|
|
897
|
+
const headers = { Authorization: `Bearer ${params.accessToken}`, Prefer: "wait" };
|
|
898
|
+
if (!binary) {
|
|
899
|
+
headers["Content-Type"] = "application/json";
|
|
900
|
+
}
|
|
901
|
+
return headers;
|
|
902
|
+
}
|
|
903
|
+
makeUrl(params) {
|
|
904
|
+
const baseUrl = this.makeBaseUrl(params);
|
|
905
|
+
if (params.model.includes(":")) {
|
|
906
|
+
return `${baseUrl}/v1/predictions`;
|
|
907
|
+
}
|
|
908
|
+
return `${baseUrl}/v1/models/${params.model}/predictions`;
|
|
909
|
+
}
|
|
369
910
|
};
|
|
370
|
-
var
|
|
371
|
-
|
|
911
|
+
var ReplicateTextToImageTask = class extends ReplicateTask {
|
|
912
|
+
async getResponse(res, url, headers, outputType) {
|
|
913
|
+
if (typeof res === "object" && "output" in res && Array.isArray(res.output) && res.output.length > 0 && typeof res.output[0] === "string") {
|
|
914
|
+
if (outputType === "url") {
|
|
915
|
+
return res.output[0];
|
|
916
|
+
}
|
|
917
|
+
const urlResponse = await fetch(res.output[0]);
|
|
918
|
+
return await urlResponse.blob();
|
|
919
|
+
}
|
|
920
|
+
throw new InferenceOutputError("Expected Replicate text-to-image response format");
|
|
921
|
+
}
|
|
372
922
|
};
|
|
373
|
-
var
|
|
374
|
-
|
|
375
|
-
|
|
923
|
+
var ReplicateTextToSpeechTask = class extends ReplicateTask {
|
|
924
|
+
preparePayload(params) {
|
|
925
|
+
const payload = super.preparePayload(params);
|
|
926
|
+
const input = payload["input"];
|
|
927
|
+
if (typeof input === "object" && input !== null && "prompt" in input) {
|
|
928
|
+
const inputObj = input;
|
|
929
|
+
inputObj["text"] = inputObj["prompt"];
|
|
930
|
+
delete inputObj["prompt"];
|
|
931
|
+
}
|
|
932
|
+
return payload;
|
|
933
|
+
}
|
|
934
|
+
async getResponse(response) {
|
|
935
|
+
if (response instanceof Blob) {
|
|
936
|
+
return response;
|
|
937
|
+
}
|
|
938
|
+
if (response && typeof response === "object") {
|
|
939
|
+
if ("output" in response) {
|
|
940
|
+
if (typeof response.output === "string") {
|
|
941
|
+
const urlResponse = await fetch(response.output);
|
|
942
|
+
return await urlResponse.blob();
|
|
943
|
+
} else if (Array.isArray(response.output)) {
|
|
944
|
+
const urlResponse = await fetch(response.output[0]);
|
|
945
|
+
return await urlResponse.blob();
|
|
946
|
+
}
|
|
947
|
+
}
|
|
948
|
+
}
|
|
949
|
+
throw new InferenceOutputError("Expected Blob or object with output");
|
|
376
950
|
}
|
|
377
|
-
return `${params.baseUrl}/v1/models/${params.model}/predictions`;
|
|
378
951
|
};
|
|
379
|
-
var
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
952
|
+
var ReplicateTextToVideoTask = class extends ReplicateTask {
|
|
953
|
+
async getResponse(response) {
|
|
954
|
+
if (typeof response === "object" && !!response && "output" in response && typeof response.output === "string" && isUrl(response.output)) {
|
|
955
|
+
const urlResponse = await fetch(response.output);
|
|
956
|
+
return await urlResponse.blob();
|
|
957
|
+
}
|
|
958
|
+
throw new InferenceOutputError("Expected { output: string }");
|
|
959
|
+
}
|
|
384
      960 |   };
 385  961 |
 386  962 |   // src/providers/sambanova.ts
 387      | - var
 388      | -
 389      | -
 390      | - };
 391      | - var makeBody11 = (params) => {
 392      | -   return {
 393      | -     ...params.args,
 394      | -     ...params.chatCompletion ? { model: params.model } : void 0
 395      | -   };
 396      | - };
 397      | - var makeHeaders11 = (params) => {
 398      | -   return { Authorization: `Bearer ${params.accessToken}` };
 399      | - };
 400      | - var makeUrl11 = (params) => {
 401      | -   if (params.chatCompletion) {
 402      | -     return `${params.baseUrl}/v1/chat/completions`;
      963 | + var SambanovaConversationalTask = class extends BaseConversationalTask {
      964 | +   constructor() {
      965 | +     super("sambanova", "https://api.sambanova.ai");
 403  966 |     }
 404      | -   return params.baseUrl;
 405      | - };
 406      | - var SAMBANOVA_CONFIG = {
 407      | -   makeBaseUrl: makeBaseUrl11,
 408      | -   makeBody: makeBody11,
 409      | -   makeHeaders: makeHeaders11,
 410      | -   makeUrl: makeUrl11
 411  967 |   };
 412  968 |
 413  969 |   // src/providers/together.ts
 414  970 |   var TOGETHER_API_BASE_URL = "https://api.together.xyz";
 415      | - var
 416      | -
 417      | -
 418      | -
 419      | -   return {
 420      | -     ...params.args,
 421      | -     model: params.model
 422      | -   };
 423      | - };
 424      | - var makeHeaders12 = (params) => {
 425      | -   return { Authorization: `Bearer ${params.accessToken}` };
      971 | + var TogetherConversationalTask = class extends BaseConversationalTask {
      972 | +   constructor() {
      973 | +     super("together", TOGETHER_API_BASE_URL);
      974 | +   }
 426  975 |   };
 427      | - var
 428      | -
 429      | -
      976 | + var TogetherTextGenerationTask = class extends BaseTextGenerationTask {
      977 | +   constructor() {
      978 | +     super("together", TOGETHER_API_BASE_URL);
 430  979 |     }
 431      | -
 432      | -   return
      980 | +   preparePayload(params) {
      981 | +     return {
      982 | +       model: params.model,
      983 | +       ...params.args,
      984 | +       prompt: params.args.inputs
      985 | +     };
 433  986 |     }
 434      | -
 435      | -
      987 | +   async getResponse(response) {
      988 | +     if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
      989 | +       const completion = response.choices[0];
      990 | +       return {
      991 | +         generated_text: completion.text
      992 | +       };
      993 | +     }
      994 | +     throw new InferenceOutputError("Expected Together text generation response format");
 436  995 |     }
 437      | -   return params.baseUrl;
 438  996 |   };
 439      | - var
 440      | -
 441      | -
 442      | -
 443      | -
      997 | + var TogetherTextToImageTask = class extends TaskProviderHelper {
      998 | +   constructor() {
      999 | +     super("together", TOGETHER_API_BASE_URL);
     1000 | +   }
     1001 | +   makeRoute() {
     1002 | +     return "v1/images/generations";
     1003 | +   }
     1004 | +   preparePayload(params) {
     1005 | +     return {
     1006 | +       ...omit(params.args, ["inputs", "parameters"]),
     1007 | +       ...params.args.parameters,
     1008 | +       prompt: params.args.inputs,
     1009 | +       response_format: "base64",
     1010 | +       model: params.model
     1011 | +     };
     1012 | +   }
     1013 | +   async getResponse(response, outputType) {
     1014 | +     if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
     1015 | +       const base64Data = response.data[0].b64_json;
     1016 | +       if (outputType === "url") {
     1017 | +         return `data:image/jpeg;base64,${base64Data}`;
     1018 | +       }
     1019 | +       return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
     1020 | +     }
     1021 | +     throw new InferenceOutputError("Expected Together text-to-image response format");
     1022 | +   }
 444 1023 |   };
 445 1024 |
 446      | - // src/
 447      | - var
 448      | -
 449      | -
 450      | - }
 451      | -
 452      | -
 453      | -
     1025 | + // src/lib/getProviderHelper.ts
     1026 | + var PROVIDERS = {
     1027 | +   "black-forest-labs": {
     1028 | +     "text-to-image": new BlackForestLabsTextToImageTask()
     1029 | +   },
     1030 | +   cerebras: {
     1031 | +     conversational: new CerebrasConversationalTask()
     1032 | +   },
     1033 | +   cohere: {
     1034 | +     conversational: new CohereConversationalTask()
     1035 | +   },
     1036 | +   "fal-ai": {
     1037 | +     "text-to-image": new FalAITextToImageTask(),
     1038 | +     "text-to-speech": new FalAITextToSpeechTask(),
     1039 | +     "text-to-video": new FalAITextToVideoTask(),
     1040 | +     "automatic-speech-recognition": new FalAIAutomaticSpeechRecognitionTask()
     1041 | +   },
     1042 | +   "hf-inference": {
     1043 | +     "text-to-image": new HFInferenceTextToImageTask(),
     1044 | +     conversational: new HFInferenceConversationalTask(),
     1045 | +     "text-generation": new HFInferenceTextGenerationTask(),
     1046 | +     "text-classification": new HFInferenceTextClassificationTask(),
     1047 | +     "question-answering": new HFInferenceQuestionAnsweringTask(),
     1048 | +     "audio-classification": new HFInferenceAudioClassificationTask(),
     1049 | +     "automatic-speech-recognition": new HFInferenceAutomaticSpeechRecognitionTask(),
     1050 | +     "fill-mask": new HFInferenceFillMaskTask(),
     1051 | +     "feature-extraction": new HFInferenceFeatureExtractionTask(),
     1052 | +     "image-classification": new HFInferenceImageClassificationTask(),
     1053 | +     "image-segmentation": new HFInferenceImageSegmentationTask(),
     1054 | +     "document-question-answering": new HFInferenceDocumentQuestionAnsweringTask(),
     1055 | +     "image-to-text": new HFInferenceImageToTextTask(),
     1056 | +     "object-detection": new HFInferenceObjectDetectionTask(),
     1057 | +     "audio-to-audio": new HFInferenceAudioToAudioTask(),
     1058 | +     "zero-shot-image-classification": new HFInferenceZeroShotImageClassificationTask(),
     1059 | +     "zero-shot-classification": new HFInferenceZeroShotClassificationTask(),
     1060 | +     "image-to-image": new HFInferenceImageToImageTask(),
     1061 | +     "sentence-similarity": new HFInferenceSentenceSimilarityTask(),
     1062 | +     "table-question-answering": new HFInferenceTableQuestionAnsweringTask(),
     1063 | +     "tabular-classification": new HFInferenceTabularClassificationTask(),
     1064 | +     "text-to-speech": new HFInferenceTextToSpeechTask(),
     1065 | +     "token-classification": new HFInferenceTokenClassificationTask(),
     1066 | +     translation: new HFInferenceTranslationTask(),
     1067 | +     summarization: new HFInferenceSummarizationTask(),
     1068 | +     "visual-question-answering": new HFInferenceVisualQuestionAnsweringTask(),
     1069 | +     "tabular-regression": new HFInferenceTabularRegressionTask(),
     1070 | +     "text-to-audio": new HFInferenceTextToAudioTask()
     1071 | +   },
     1072 | +   "fireworks-ai": {
     1073 | +     conversational: new FireworksConversationalTask()
     1074 | +   },
     1075 | +   hyperbolic: {
     1076 | +     "text-to-image": new HyperbolicTextToImageTask(),
     1077 | +     conversational: new HyperbolicConversationalTask(),
     1078 | +     "text-generation": new HyperbolicTextGenerationTask()
     1079 | +   },
     1080 | +   nebius: {
     1081 | +     "text-to-image": new NebiusTextToImageTask(),
     1082 | +     conversational: new NebiusConversationalTask(),
     1083 | +     "text-generation": new NebiusTextGenerationTask()
     1084 | +   },
     1085 | +   novita: {
     1086 | +     conversational: new NovitaConversationalTask(),
     1087 | +     "text-generation": new NovitaTextGenerationTask()
     1088 | +   },
     1089 | +   openai: {
     1090 | +     conversational: new OpenAIConversationalTask()
     1091 | +   },
     1092 | +   replicate: {
     1093 | +     "text-to-image": new ReplicateTextToImageTask(),
     1094 | +     "text-to-speech": new ReplicateTextToSpeechTask(),
     1095 | +     "text-to-video": new ReplicateTextToVideoTask()
     1096 | +   },
     1097 | +   sambanova: {
     1098 | +     conversational: new SambanovaConversationalTask()
     1099 | +   },
     1100 | +   together: {
     1101 | +     "text-to-image": new TogetherTextToImageTask(),
     1102 | +     conversational: new TogetherConversationalTask(),
     1103 | +     "text-generation": new TogetherTextGenerationTask()
 454 1104 |     }
 455      | -   return {
 456      | -     ...params.args,
 457      | -     model: params.model
 458      | -   };
 459 1105 |   };
 460      | -
 461      | -
 462      | -
 463      | -
 464      | -
 465      | -   throw new Error("OpenAI only supports chat completions.");
     1106 | + function getProviderHelper(provider, task) {
     1107 | +   if (provider === "hf-inference") {
     1108 | +     if (!task) {
     1109 | +       return new HFInferenceTask();
     1110 | +     }
 466 1111 |     }
 467      | -
 468      | -
 469      | -
 470      | -
 471      | -
 472      | -
 473      | -
 474      | -
 475      | -
 476      | -
 477      | -
 478      | -
 479      | -
     1112 | +   if (!task) {
     1113 | +     throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'");
     1114 | +   }
     1115 | +   if (!(provider in PROVIDERS)) {
     1116 | +     throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`);
     1117 | +   }
     1118 | +   const providerTasks = PROVIDERS[provider];
     1119 | +   if (!providerTasks || !(task in providerTasks)) {
     1120 | +     throw new Error(
     1121 | +       `Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(providerTasks ?? {})}`
     1122 | +     );
     1123 | +   }
     1124 | +   return providerTasks[task];
     1125 | + }
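Note: this release replaces the per-provider `*_CONFIG` objects with task-helper classes resolved through the `PROVIDERS` registry above. A minimal TypeScript sketch of that lookup pattern, using an illustrative `MinimalHelper` interface in place of the package's actual `TaskProviderHelper` types:

    // Sketch of the registry lookup that getProviderHelper implements.
    // MinimalHelper is an illustrative stand-in, not the package's real interface.
    interface MinimalHelper {
      getResponse(response: unknown): unknown;
    }

    type HelperRegistry = Record<string, Partial<Record<string, MinimalHelper>>>;

    function lookupHelper(registry: HelperRegistry, provider: string, task?: string): MinimalHelper {
      if (!task) {
        throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'");
      }
      const providerTasks = registry[provider];
      if (!providerTasks) {
        throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(registry)}`);
      }
      const helper = providerTasks[task];
      if (!helper) {
        throw new Error(`Task '${task}' not supported for provider '${provider}'.`);
      }
      return helper;
    }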
 480 1126 |
 481 1127 |   // src/providers/consts.ts
 482 1128 |   var HARDCODED_MODEL_ID_MAPPING = {
@@ -546,28 +1192,11 @@ async function getProviderModelId(params, args, options = {}) {
 546 1192 |   }
 547 1193 |
 548 1194 |   // src/lib/makeRequestOptions.ts
 549      | - var HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
 550 1195 |   var tasks = null;
 551      | - var providerConfigs = {
 552      | -   "black-forest-labs": BLACK_FOREST_LABS_CONFIG,
 553      | -   cerebras: CEREBRAS_CONFIG,
 554      | -   cohere: COHERE_CONFIG,
 555      | -   "fal-ai": FAL_AI_CONFIG,
 556      | -   "fireworks-ai": FIREWORKS_AI_CONFIG,
 557      | -   "hf-inference": HF_INFERENCE_CONFIG,
 558      | -   hyperbolic: HYPERBOLIC_CONFIG,
 559      | -   openai: OPENAI_CONFIG,
 560      | -   nebius: NEBIUS_CONFIG,
 561      | -   novita: NOVITA_CONFIG,
 562      | -   replicate: REPLICATE_CONFIG,
 563      | -   sambanova: SAMBANOVA_CONFIG,
 564      | -   together: TOGETHER_CONFIG
 565      | - };
 566 1196 |   async function makeRequestOptions(args, options) {
 567 1197 |     const { provider: maybeProvider, model: maybeModel } = args;
 568 1198 |     const provider = maybeProvider ?? "hf-inference";
 569      | -   const
 570      | -   const { task, chatCompletion: chatCompletion2 } = options ?? {};
     1199 | +   const { task } = options ?? {};
 571 1200 |     if (args.endpointUrl && provider !== "hf-inference") {
 572 1201 |       throw new Error(`Cannot use endpointUrl with a third-party provider.`);
 573 1202 |     }
@@ -577,19 +1206,16 @@ async function makeRequestOptions(args, options) {
 577 1206 |     if (!maybeModel && !task) {
 578 1207 |       throw new Error("No model provided, and no task has been specified.");
 579 1208 |     }
 580      | -
 581      | -
 582      | -
 583      | -   if (providerConfig.clientSideRoutingOnly && !maybeModel) {
     1209 | +   const hfModel = maybeModel ?? await loadDefaultModel(task);
     1210 | +   const providerHelper = getProviderHelper(provider, task);
     1211 | +   if (providerHelper.clientSideRoutingOnly && !maybeModel) {
 584 1212 |       throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
 585 1213 |     }
 586      | -   const
 587      | -   const resolvedModel = providerConfig.clientSideRoutingOnly ? (
     1214 | +   const resolvedModel = providerHelper.clientSideRoutingOnly ? (
 588 1215 |       // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
 589 1216 |       removeProviderPrefix(maybeModel, provider)
 590 1217 |     ) : await getProviderModelId({ model: hfModel, provider }, args, {
 591 1218 |       task,
 592      | -     chatCompletion: chatCompletion2,
 593 1219 |       fetch: options?.fetch
 594 1220 |     });
 595 1221 |     return makeRequestOptionsFromResolvedModel(resolvedModel, args, options);
@@ -597,10 +1223,10 @@ async function makeRequestOptions(args, options) {
 597 1223 |   function makeRequestOptionsFromResolvedModel(resolvedModel, args, options) {
 598 1224 |     const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
 599 1225 |     const provider = maybeProvider ?? "hf-inference";
 600      | -   const
 601      | -   const
     1226 | +   const { includeCredentials, task, signal, billTo } = options ?? {};
     1227 | +   const providerHelper = getProviderHelper(provider, task);
 602 1228 |     const authMethod = (() => {
 603      | -     if (
     1229 | +     if (providerHelper.clientSideRoutingOnly) {
 604 1230 |         if (accessToken && accessToken.startsWith("hf_")) {
 605 1231 |           throw new Error(`Provider ${provider} is closed-source and does not support HF tokens.`);
 606 1232 |         }
@@ -614,35 +1240,30 @@ function makeRequestOptionsFromResolvedModel(resolvedModel, args, options) {
 614 1240 |       }
 615 1241 |       return "none";
 616 1242 |     })();
 617      | -   const
     1243 | +   const modelId = endpointUrl ?? resolvedModel;
     1244 | +   const url = providerHelper.makeUrl({
 618 1245 |       authMethod,
 619      | -
 620      | -     model: resolvedModel,
 621      | -     chatCompletion: chatCompletion2,
     1246 | +     model: modelId,
 622 1247 |       task
 623 1248 |     });
 624      | -   const
 625      | -
 626      | -
 627      | -
 628      | -
     1249 | +   const headers = providerHelper.prepareHeaders(
     1250 | +     {
     1251 | +       accessToken,
     1252 | +       authMethod
     1253 | +     },
     1254 | +     "data" in args && !!args.data
     1255 | +   );
 629 1256 |     if (billTo) {
 630 1257 |       headers[HF_HEADER_X_BILL_TO] = billTo;
 631 1258 |     }
 632      | -   if (!binary) {
 633      | -     headers["Content-Type"] = "application/json";
 634      | -   }
 635 1259 |     const ownUserAgent = `${name}/${version}`;
 636 1260 |     const userAgent = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : void 0].filter((x) => x !== void 0).join(" ");
 637 1261 |     headers["User-Agent"] = userAgent;
 638      | -   const body =
 639      | -
 640      | -
 641      | -
 642      | -
 643      | -     chatCompletion: chatCompletion2
 644      | -   })
 645      | - );
     1262 | +   const body = providerHelper.makeBody({
     1263 | +     args: remainingArgs,
     1264 | +     model: resolvedModel,
     1265 | +     task
     1266 | +   });
 646 1267 |     let credentials;
 647 1268 |     if (typeof includeCredentials === "string") {
 648 1269 |       credentials = includeCredentials;
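Note: with the helper objects in place, `makeRequestOptions` no longer carries a `chatCompletion` flag; URL, header, and body construction are all delegated to the resolved helper. A condensed sketch of that delegation, with illustrative types rather than the package's exact signatures:

    // Condensed view of the delegation in makeRequestOptionsFromResolvedModel.
    interface ProviderHelperLike {
      makeUrl(params: { authMethod: string; model: string; task?: string }): string;
      prepareHeaders(params: { accessToken?: string; authMethod: string }, binary: boolean): Record<string, string>;
      makeBody(params: { args: Record<string, unknown>; model: string; task?: string }): unknown;
    }

    function buildRequest(
      helper: ProviderHelperLike,
      accessToken: string | undefined,
      model: string,
      task: string,
      args: Record<string, unknown>
    ) {
      // authMethod selection is simplified here; the real code also handles
      // client-side routing and includeCredentials.
      const authMethod = accessToken ? (accessToken.startsWith("hf_") ? "hf-token" : "provider-key") : "none";
      const url = helper.makeUrl({ authMethod, model, task });
      const headers = helper.prepareHeaders({ accessToken, authMethod }, "data" in args && !!args.data);
      const body = helper.makeBody({ args, model, task });
      return { url, headers, body };
    }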
@@ -904,30 +1525,6 @@ async function* streamingRequest(args, options) {
 904 1525 |     yield* innerStreamingRequest(args, options);
 905 1526 |   }
 906 1527 |
 907      | - // src/utils/pick.ts
 908      | - function pick(o, props) {
 909      | -   return Object.assign(
 910      | -     {},
 911      | -     ...props.map((prop) => {
 912      | -       if (o[prop] !== void 0) {
 913      | -         return { [prop]: o[prop] };
 914      | -       }
 915      | -     })
 916      | -   );
 917      | - }
 918      | -
 919      | - // src/utils/typedInclude.ts
 920      | - function typedInclude(arr, v) {
 921      | -   return arr.includes(v);
 922      | - }
 923      | -
 924      | - // src/utils/omit.ts
 925      | - function omit(o, props) {
 926      | -   const propsArr = Array.isArray(props) ? props : [props];
 927      | -   const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
 928      | -   return pick(o, letsKeep);
 929      | - }
 930      | -
 931 1528 |   // src/tasks/audio/utils.ts
 932 1529 |   function preparePayload(args) {
 933 1530 |     return "data" in args ? args : {
@@ -938,16 +1535,24 @@ function preparePayload(args) {
 938 1535 |
 939 1536 |   // src/tasks/audio/audioClassification.ts
 940 1537 |   async function audioClassification(args, options) {
     1538 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
 941 1539 |     const payload = preparePayload(args);
 942 1540 |     const { data: res } = await innerRequest(payload, {
 943 1541 |       ...options,
 944 1542 |       task: "audio-classification"
 945 1543 |     });
 946      | -
 947      | -
 948      | -
 949      | -
 950      | -
     1544 | +   return providerHelper.getResponse(res);
     1545 | + }
     1546 | +
     1547 | + // src/tasks/audio/audioToAudio.ts
     1548 | + async function audioToAudio(args, options) {
     1549 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
     1550 | +   const payload = preparePayload(args);
     1551 | +   const { data: res } = await innerRequest(payload, {
     1552 | +     ...options,
     1553 | +     task: "audio-to-audio"
     1554 | +   });
     1555 | +   return providerHelper.getResponse(res);
 951 1556 |   }
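Note: every task entry point now follows the same three-step shape visible above: resolve the helper for `(provider, task)`, run the request, and let the helper validate and shape the raw response. A usage sketch (the token and model ID are placeholders):

    import { audioClassification } from "@huggingface/inference";

    // Placeholder token and model; any audio-classification model should work here.
    const labels = await audioClassification({
      accessToken: "hf_...",
      model: "speechbrain/google_speech_command_xvector",
      data: new Blob([/* raw audio bytes */]),
    });
    console.log(labels); // e.g. [{ label: "yes", score: 0.98 }, ...]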
 952 1557 |
 953 1558 |   // src/utils/base64FromBytes.ts
@@ -965,6 +1570,7 @@ function base64FromBytes(arr) {
 965 1570 |
 966 1571 |   // src/tasks/audio/automaticSpeechRecognition.ts
 967 1572 |   async function automaticSpeechRecognition(args, options) {
     1573 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
 968 1574 |     const payload = await buildPayload(args);
 969 1575 |     const { data: res } = await innerRequest(payload, {
 970 1576 |       ...options,
@@ -974,9 +1580,8 @@ async function automaticSpeechRecognition(args, options) {
 974 1580 |     if (!isValidOutput) {
 975 1581 |       throw new InferenceOutputError("Expected {text: string}");
 976 1582 |     }
 977      | -   return res;
     1583 | +   return providerHelper.getResponse(res);
 978 1584 |   }
 979      | - var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
 980 1585 |   async function buildPayload(args) {
 981 1586 |     if (args.provider === "fal-ai") {
 982 1587 |       const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : void 0;
@@ -995,63 +1600,23 @@ async function buildPayload(args) {
 995 1600 |       }
 996 1601 |       const base64audio = base64FromBytes(new Uint8Array(await blob.arrayBuffer()));
 997 1602 |       return {
 998      | -       ..."data" in args ? omit(args, "data") : omit(args, "inputs"),
 999      | -       audio_url: `data:${contentType};base64,${base64audio}`
1000      | -     };
1001      | -   } else {
1002      | -     return preparePayload(args);
1003      | -   }
1004      | - }
1005      | -
1006      | - // src/tasks/audio/textToSpeech.ts
1007      | - async function textToSpeech(args, options) {
1008      | -   const payload = args.provider === "replicate" ? {
1009      | -     ...omit(args, ["inputs", "parameters"]),
1010      | -     ...args.parameters,
1011      | -     text: args.inputs
1012      | -   } : args;
1013      | -   const { data: res } = await innerRequest(payload, {
1014      | -     ...options,
1015      | -     task: "text-to-speech"
1016      | -   });
1017      | -   if (res instanceof Blob) {
1018      | -     return res;
1019      | -   }
1020      | -   if (res && typeof res === "object") {
1021      | -     if ("output" in res) {
1022      | -       if (typeof res.output === "string") {
1023      | -         const urlResponse = await fetch(res.output);
1024      | -         const blob = await urlResponse.blob();
1025      | -         return blob;
1026      | -       } else if (Array.isArray(res.output)) {
1027      | -         const urlResponse = await fetch(res.output[0]);
1028      | -         const blob = await urlResponse.blob();
1029      | -         return blob;
1030      | -       }
1031      | -     }
     1603 | +       ..."data" in args ? omit(args, "data") : omit(args, "inputs"),
     1604 | +       audio_url: `data:${contentType};base64,${base64audio}`
     1605 | +     };
     1606 | +   } else {
     1607 | +     return preparePayload(args);
1032 1608 |     }
1033      | -   throw new InferenceOutputError("Expected Blob or object with output");
1034 1609 |   }
1035 1610 |
1036      | - // src/tasks/audio/
1037      | - async function
1038      | -   const
1039      | -   const
     1611 | + // src/tasks/audio/textToSpeech.ts
     1612 | + async function textToSpeech(args, options) {
     1613 | +   const provider = args.provider ?? "hf-inference";
     1614 | +   const providerHelper = getProviderHelper(provider, "text-to-speech");
     1615 | +   const { data: res } = await innerRequest(args, {
1040 1616 |       ...options,
1041      | -     task: "
     1617 | +     task: "text-to-speech"
1042 1618 |     });
1043      | -   return
1044      | - }
1045      | - function validateOutput(output) {
1046      | -   if (!Array.isArray(output)) {
1047      | -     throw new InferenceOutputError("Expected Array");
1048      | -   }
1049      | -   if (!output.every((elem) => {
1050      | -     return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
1051      | -   })) {
1052      | -     throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
1053      | -   }
1054      | -   return output;
     1619 | +   return providerHelper.getResponse(res);
1055 1620 |   }
1056 1621 |
1057 1622 |   // src/tasks/cv/utils.ts
@@ -1061,183 +1626,95 @@ function preparePayload2(args) {
1061 1626 |
1062 1627 |   // src/tasks/cv/imageClassification.ts
1063 1628 |   async function imageClassification(args, options) {
     1629 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
1064 1630 |     const payload = preparePayload2(args);
1065 1631 |     const { data: res } = await innerRequest(payload, {
1066 1632 |       ...options,
1067 1633 |       task: "image-classification"
1068 1634 |     });
1069      | -
1070      | -   if (!isValidOutput) {
1071      | -     throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
1072      | -   }
1073      | -   return res;
     1635 | +   return providerHelper.getResponse(res);
1074 1636 |   }
1075 1637 |
1076 1638 |   // src/tasks/cv/imageSegmentation.ts
1077 1639 |   async function imageSegmentation(args, options) {
     1640 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
1078 1641 |     const payload = preparePayload2(args);
1079 1642 |     const { data: res } = await innerRequest(payload, {
1080 1643 |       ...options,
1081 1644 |       task: "image-segmentation"
1082 1645 |     });
1083      | -
1084      | -
1085      | -
     1646 | +   return providerHelper.getResponse(res);
     1647 | + }
     1648 | +
     1649 | + // src/tasks/cv/imageToImage.ts
     1650 | + async function imageToImage(args, options) {
     1651 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
     1652 | +   let reqArgs;
     1653 | +   if (!args.parameters) {
     1654 | +     reqArgs = {
     1655 | +       accessToken: args.accessToken,
     1656 | +       model: args.model,
     1657 | +       data: args.inputs
     1658 | +     };
     1659 | +   } else {
     1660 | +     reqArgs = {
     1661 | +       ...args,
     1662 | +       inputs: base64FromBytes(
     1663 | +         new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await args.inputs.arrayBuffer())
     1664 | +       )
     1665 | +     };
1086 1666 |     }
1087      | -
     1667 | +   const { data: res } = await innerRequest(reqArgs, {
     1668 | +     ...options,
     1669 | +     task: "image-to-image"
     1670 | +   });
     1671 | +   return providerHelper.getResponse(res);
1088 1672 |   }
1089 1673 |
1090 1674 |   // src/tasks/cv/imageToText.ts
1091 1675 |   async function imageToText(args, options) {
     1676 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
1092 1677 |     const payload = preparePayload2(args);
1093 1678 |     const { data: res } = await innerRequest(payload, {
1094 1679 |       ...options,
1095 1680 |       task: "image-to-text"
1096 1681 |     });
1097      | -
1098      | -     throw new InferenceOutputError("Expected {generated_text: string}");
1099      | -   }
1100      | -   return res?.[0];
     1682 | +   return providerHelper.getResponse(res[0]);
1101 1683 |   }
1102 1684 |
1103 1685 |   // src/tasks/cv/objectDetection.ts
1104 1686 |   async function objectDetection(args, options) {
     1687 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
1105 1688 |     const payload = preparePayload2(args);
1106 1689 |     const { data: res } = await innerRequest(payload, {
1107 1690 |       ...options,
1108 1691 |       task: "object-detection"
1109 1692 |     });
1110      | -
1111      | -     (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
1112      | -   );
1113      | -   if (!isValidOutput) {
1114      | -     throw new InferenceOutputError(
1115      | -       "Expected Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>"
1116      | -     );
1117      | -   }
1118      | -   return res;
     1693 | +   return providerHelper.getResponse(res);
1119 1694 |   }
1120 1695 |
1121 1696 |   // src/tasks/cv/textToImage.ts
1122      | - function getResponseFormatArg(provider) {
1123      | -   switch (provider) {
1124      | -     case "fal-ai":
1125      | -       return { sync_mode: true };
1126      | -     case "nebius":
1127      | -       return { response_format: "b64_json" };
1128      | -     case "replicate":
1129      | -       return void 0;
1130      | -     case "together":
1131      | -       return { response_format: "base64" };
1132      | -     default:
1133      | -       return void 0;
1134      | -   }
1135      | - }
1136 1697 |   async function textToImage(args, options) {
1137      | -   const
1138      | -
1139      | -
1140      | -     ...getResponseFormatArg(args.provider),
1141      | -     prompt: args.inputs
1142      | -   };
1143      | -   const { data: res } = await innerRequest(payload, {
     1698 | +   const provider = args.provider ?? "hf-inference";
     1699 | +   const providerHelper = getProviderHelper(provider, "text-to-image");
     1700 | +   const { data: res } = await innerRequest(args, {
1144 1701 |       ...options,
1145 1702 |       task: "text-to-image"
1146 1703 |     });
1147      | -
1148      | -
1149      | -     return await pollBflResponse(res.polling_url, options?.outputType);
1150      | -   }
1151      | -   if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
1152      | -     if (options?.outputType === "url") {
1153      | -       return res.images[0].url;
1154      | -     } else {
1155      | -       const image = await fetch(res.images[0].url);
1156      | -       return await image.blob();
1157      | -     }
1158      | -   }
1159      | -   if (args.provider === "hyperbolic" && "images" in res && Array.isArray(res.images) && res.images[0] && typeof res.images[0].image === "string") {
1160      | -     if (options?.outputType === "url") {
1161      | -       return `data:image/jpeg;base64,${res.images[0].image}`;
1162      | -     }
1163      | -     const base64Response = await fetch(`data:image/jpeg;base64,${res.images[0].image}`);
1164      | -     return await base64Response.blob();
1165      | -   }
1166      | -   if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
1167      | -     const base64Data = res.data[0].b64_json;
1168      | -     if (options?.outputType === "url") {
1169      | -       return `data:image/jpeg;base64,${base64Data}`;
1170      | -     }
1171      | -     const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
1172      | -     return await base64Response.blob();
1173      | -   }
1174      | -   if ("output" in res && Array.isArray(res.output)) {
1175      | -     if (options?.outputType === "url") {
1176      | -       return res.output[0];
1177      | -     }
1178      | -     const urlResponse = await fetch(res.output[0]);
1179      | -     const blob = await urlResponse.blob();
1180      | -     return blob;
1181      | -   }
1182      | - }
1183      | -   const isValidOutput = res && res instanceof Blob;
1184      | -   if (!isValidOutput) {
1185      | -     throw new InferenceOutputError("Expected Blob");
1186      | -   }
1187      | -   if (options?.outputType === "url") {
1188      | -     const b64 = await res.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
1189      | -     return `data:image/jpeg;base64,${b64}`;
1190      | -   }
1191      | -   return res;
1192      | - }
1193      | - async function pollBflResponse(url, outputType) {
1194      | -   const urlObj = new URL(url);
1195      | -   for (let step = 0; step < 5; step++) {
1196      | -     await delay(1e3);
1197      | -     console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
1198      | -     urlObj.searchParams.set("attempt", step.toString(10));
1199      | -     const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
1200      | -     if (!resp.ok) {
1201      | -       throw new InferenceOutputError("Failed to fetch result from black forest labs API");
1202      | -     }
1203      | -     const payload = await resp.json();
1204      | -     if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
1205      | -       if (outputType === "url") {
1206      | -         return payload.result.sample;
1207      | -       }
1208      | -       const image = await fetch(payload.result.sample);
1209      | -       return await image.blob();
1210      | -     }
1211      | -   }
1212      | -   throw new InferenceOutputError("Failed to fetch result from black forest labs API");
     1704 | +   const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-image" });
     1705 | +   return providerHelper.getResponse(res, url, info.headers, options?.outputType);
1213 1706 |   }
1214 1707 |
1215      | - // src/tasks/cv/
1216      | - async function
1217      | -
1218      | -
1219      | -
1220      | -       accessToken: args.accessToken,
1221      | -       model: args.model,
1222      | -       data: args.inputs
1223      | -     };
1224      | -   } else {
1225      | -     reqArgs = {
1226      | -       ...args,
1227      | -       inputs: base64FromBytes(
1228      | -         new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await args.inputs.arrayBuffer())
1229      | -       )
1230      | -     };
1231      | -   }
1232      | -   const { data: res } = await innerRequest(reqArgs, {
     1708 | + // src/tasks/cv/textToVideo.ts
     1709 | + async function textToVideo(args, options) {
     1710 | +   const provider = args.provider ?? "hf-inference";
     1711 | +   const providerHelper = getProviderHelper(provider, "text-to-video");
     1712 | +   const { data: response } = await innerRequest(args, {
1233 1713 |       ...options,
1234      | -     task: "
     1714 | +     task: "text-to-video"
1235 1715 |     });
1236      | -   const
1237      | -
1238      | -     throw new InferenceOutputError("Expected Blob");
1239      | -   }
1240      | -   return res;
     1716 | +   const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-video" });
     1717 | +   return providerHelper.getResponse(response, url, info.headers);
1241 1718 |   }
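Note: the provider-specific branches that 3.7.0 kept inline in `textToImage` (Black Forest Labs polling, fal.ai image URLs, base64 payload handling) now live behind each helper's `getResponse`, so the call site is uniform across providers. A usage sketch (token, provider, and model are placeholders):

    import { textToImage } from "@huggingface/inference";

    // outputType "url" asks the helper for a URL (or data: URL) instead of a Blob.
    const image = await textToImage(
      {
        accessToken: "hf_...", // placeholder token
        provider: "together",
        model: "black-forest-labs/FLUX.1-schnell", // placeholder model
        inputs: "An astronaut riding a horse",
      },
      { outputType: "url" }
    );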
1242 1719 |
1243 1720 |   // src/tasks/cv/zeroShotImageClassification.ts
@@ -1263,226 +1740,112 @@ async function preparePayload3(args) {
1263 1740 |     }
1264 1741 |   }
1265 1742 |   async function zeroShotImageClassification(args, options) {
     1743 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
1266 1744 |     const payload = await preparePayload3(args);
1267 1745 |     const { data: res } = await innerRequest(payload, {
1268 1746 |       ...options,
1269 1747 |       task: "zero-shot-image-classification"
1270 1748 |     });
1271      | -
1272      | -   if (!isValidOutput) {
1273      | -     throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
1274      | -   }
1275      | -   return res;
     1749 | +   return providerHelper.getResponse(res);
1276 1750 |   }
1277 1751 |
1278      | - // src/tasks/
1279      | -
1280      | -
1281      | -
1282      | -     throw new Error(
1283      | -       `textToVideo inference is only supported for the following providers: ${SUPPORTED_PROVIDERS.join(", ")}`
1284      | -     );
1285      | -   }
1286      | -   const payload = args.provider === "fal-ai" || args.provider === "replicate" || args.provider === "novita" ? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs } : args;
1287      | -   const { data, requestContext } = await innerRequest(payload, {
     1752 | + // src/tasks/nlp/chatCompletion.ts
     1753 | + async function chatCompletion(args, options) {
     1754 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
     1755 | +   const { data: response } = await innerRequest(args, {
1288 1756 |       ...options,
1289      | -     task: "
     1757 | +     task: "conversational"
     1758 | +   });
     1759 | +   return providerHelper.getResponse(response);
     1760 | + }
     1761 | +
     1762 | + // src/tasks/nlp/chatCompletionStream.ts
     1763 | + async function* chatCompletionStream(args, options) {
     1764 | +   yield* innerStreamingRequest(args, {
     1765 | +     ...options,
     1766 | +     task: "conversational"
1290 1767 |     });
1291      | -   if (args.provider === "fal-ai") {
1292      | -     return await pollFalResponse(
1293      | -       data,
1294      | -       requestContext.url,
1295      | -       requestContext.info.headers
1296      | -     );
1297      | -   } else if (args.provider === "novita") {
1298      | -     const isValidOutput = typeof data === "object" && !!data && "video" in data && typeof data.video === "object" && !!data.video && "video_url" in data.video && typeof data.video.video_url === "string" && isUrl(data.video.video_url);
1299      | -     if (!isValidOutput) {
1300      | -       throw new InferenceOutputError("Expected { video: { video_url: string } }");
1301      | -     }
1302      | -     const urlResponse = await fetch(data.video.video_url);
1303      | -     return await urlResponse.blob();
1304      | -   } else {
1305      | -     const isValidOutput = typeof data === "object" && !!data && "output" in data && typeof data.output === "string" && isUrl(data.output);
1306      | -     if (!isValidOutput) {
1307      | -       throw new InferenceOutputError("Expected { output: string }");
1308      | -     }
1309      | -     const urlResponse = await fetch(data.output);
1310      | -     return await urlResponse.blob();
1311      | -   }
1312 1768 |   }
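Note: chat completion is now routed as its own "conversational" task rather than "text-generation" with a `chatCompletion: true` flag (compare the removed lines further below). A usage sketch (token, provider, and model are placeholders):

    import { chatCompletion } from "@huggingface/inference";

    const out = await chatCompletion({
      accessToken: "hf_...", // placeholder token
      provider: "sambanova",
      model: "meta-llama/Llama-3.1-8B-Instruct", // placeholder model
      messages: [{ role: "user", content: "What is the capital of France?" }],
    });
    console.log(out.choices[0].message);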
1313 1769 |
1314 1770 |   // src/tasks/nlp/featureExtraction.ts
1315 1771 |   async function featureExtraction(args, options) {
     1772 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "feature-extraction");
1316 1773 |     const { data: res } = await innerRequest(args, {
1317 1774 |       ...options,
1318 1775 |       task: "feature-extraction"
1319 1776 |     });
1320      | -
1321      | -   const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
1322      | -     if (curDepth > maxDepth)
1323      | -       return false;
1324      | -     if (arr.every((x) => Array.isArray(x))) {
1325      | -       return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
1326      | -     } else {
1327      | -       return arr.every((x) => typeof x === "number");
1328      | -     }
1329      | -   };
1330      | -   isValidOutput = Array.isArray(res) && isNumArrayRec(res, 3, 0);
1331      | -   if (!isValidOutput) {
1332      | -     throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
1333      | -   }
1334      | -   return res;
     1777 | +   return providerHelper.getResponse(res);
1335 1778 |   }
1336 1779 |
1337 1780 |   // src/tasks/nlp/fillMask.ts
1338 1781 |   async function fillMask(args, options) {
     1782 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "fill-mask");
1339 1783 |     const { data: res } = await innerRequest(args, {
1340 1784 |       ...options,
1341 1785 |       task: "fill-mask"
1342 1786 |     });
1343      | -
1344      | -     (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
1345      | -   );
1346      | -   if (!isValidOutput) {
1347      | -     throw new InferenceOutputError(
1348      | -       "Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
1349      | -     );
1350      | -   }
1351      | -   return res;
     1787 | +   return providerHelper.getResponse(res);
1352 1788 |   }
1353 1789 |
1354 1790 |   // src/tasks/nlp/questionAnswering.ts
1355 1791 |   async function questionAnswering(args, options) {
     1792 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering");
1356 1793 |     const { data: res } = await innerRequest(args, {
1357 1794 |       ...options,
1358 1795 |       task: "question-answering"
1359 1796 |     });
1360      | -
1361      | -     (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
1362      | -   ) : typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";
1363      | -   if (!isValidOutput) {
1364      | -     throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
1365      | -   }
1366      | -   return Array.isArray(res) ? res[0] : res;
     1797 | +   return providerHelper.getResponse(res);
1367 1798 |   }
1368 1799 |
1369 1800 |   // src/tasks/nlp/sentenceSimilarity.ts
1370 1801 |   async function sentenceSimilarity(args, options) {
     1802 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "sentence-similarity");
1371 1803 |     const { data: res } = await innerRequest(args, {
1372 1804 |       ...options,
1373 1805 |       task: "sentence-similarity"
1374 1806 |     });
1375      | -
1376      | -   if (!isValidOutput) {
1377      | -     throw new InferenceOutputError("Expected number[]");
1378      | -   }
1379      | -   return res;
     1807 | +   return providerHelper.getResponse(res);
1380 1808 |   }
1381 1809 |
1382 1810 |   // src/tasks/nlp/summarization.ts
1383 1811 |   async function summarization(args, options) {
     1812 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization");
1384 1813 |     const { data: res } = await innerRequest(args, {
1385 1814 |       ...options,
1386 1815 |       task: "summarization"
1387 1816 |     });
1388      | -
1389      | -   if (!isValidOutput) {
1390      | -     throw new InferenceOutputError("Expected Array<{summary_text: string}>");
1391      | -   }
1392      | -   return res?.[0];
     1817 | +   return providerHelper.getResponse(res);
1393 1818 |   }
1394 1819 |
1395 1820 |   // src/tasks/nlp/tableQuestionAnswering.ts
1396 1821 |   async function tableQuestionAnswering(args, options) {
     1822 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
1397 1823 |     const { data: res } = await innerRequest(args, {
1398 1824 |       ...options,
1399 1825 |       task: "table-question-answering"
1400 1826 |     });
1401      | -
1402      | -   if (!isValidOutput) {
1403      | -     throw new InferenceOutputError(
1404      | -       "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
1405      | -     );
1406      | -   }
1407      | -   return Array.isArray(res) ? res[0] : res;
1408      | - }
1409      | - function validate(elem) {
1410      | -   return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
1411      | -     (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
1412      | -   );
     1827 | +   return providerHelper.getResponse(res);
1413 1828 |   }
1414 1829 |
1415 1830 |   // src/tasks/nlp/textClassification.ts
1416 1831 |   async function textClassification(args, options) {
     1832 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
1417 1833 |     const { data: res } = await innerRequest(args, {
1418 1834 |       ...options,
1419 1835 |       task: "text-classification"
1420 1836 |     });
1421      | -
1422      | -   const isValidOutput = Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number");
1423      | -   if (!isValidOutput) {
1424      | -     throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
1425      | -   }
1426      | -   return output;
1427      | - }
1428      | -
1429      | - // src/utils/toArray.ts
1430      | - function toArray(obj) {
1431      | -   if (Array.isArray(obj)) {
1432      | -     return obj;
1433      | -   }
1434      | -   return [obj];
     1837 | +   return providerHelper.getResponse(res);
1435 1838 |   }
1436 1839 |
1437 1840 |   // src/tasks/nlp/textGeneration.ts
1438 1841 |   async function textGeneration(args, options) {
1439      | -
1440      | -
1441      | -
1442      | -
1443      | -
1444      | -
1445      | -
1446      | -     if (!isValidOutput) {
1447      | -       throw new InferenceOutputError("Expected ChatCompletionOutput");
1448      | -     }
1449      | -     const completion = raw.choices[0];
1450      | -     return {
1451      | -       generated_text: completion.text
1452      | -     };
1453      | -   } else if (args.provider === "hyperbolic") {
1454      | -     const payload = {
1455      | -       messages: [{ content: args.inputs, role: "user" }],
1456      | -       ...args.parameters ? {
1457      | -         max_tokens: args.parameters.max_new_tokens,
1458      | -         ...omit(args.parameters, "max_new_tokens")
1459      | -       } : void 0,
1460      | -       ...omit(args, ["inputs", "parameters"])
1461      | -     };
1462      | -     const raw = (await innerRequest(payload, {
1463      | -       ...options,
1464      | -       task: "text-generation"
1465      | -     })).data;
1466      | -     const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
1467      | -     if (!isValidOutput) {
1468      | -       throw new InferenceOutputError("Expected ChatCompletionOutput");
1469      | -     }
1470      | -     const completion = raw.choices[0];
1471      | -     return {
1472      | -       generated_text: completion.message.content
1473      | -     };
1474      | -   } else {
1475      | -     const { data: res } = await innerRequest(args, {
1476      | -       ...options,
1477      | -       task: "text-generation"
1478      | -     });
1479      | -     const output = toArray(res);
1480      | -     const isValidOutput = Array.isArray(output) && output.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
1481      | -     if (!isValidOutput) {
1482      | -       throw new InferenceOutputError("Expected Array<{generated_text: string}>");
1483      | -     }
1484      | -     return output?.[0];
1485      | -   }
     1842 | +   const provider = args.provider ?? "hf-inference";
     1843 | +   const providerHelper = getProviderHelper(provider, "text-generation");
     1844 | +   const { data: response } = await innerRequest(args, {
     1845 | +     ...options,
     1846 | +     task: "text-generation"
     1847 | +   });
     1848 | +   return providerHelper.getResponse(response);
1486 1849 |   }
1487 1850 |
1488 1851 |   // src/tasks/nlp/textGenerationStream.ts
@@ -1495,77 +1858,37 @@ async function* textGenerationStream(args, options) {
1495 1858 |
1496 1859 |   // src/tasks/nlp/tokenClassification.ts
1497 1860 |   async function tokenClassification(args, options) {
     1861 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "token-classification");
1498 1862 |     const { data: res } = await innerRequest(args, {
1499 1863 |       ...options,
1500 1864 |       task: "token-classification"
1501 1865 |     });
1502      | -
1503      | -   const isValidOutput = Array.isArray(output) && output.every(
1504      | -     (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
1505      | -   );
1506      | -   if (!isValidOutput) {
1507      | -     throw new InferenceOutputError(
1508      | -       "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
1509      | -     );
1510      | -   }
1511      | -   return output;
     1866 | +   return providerHelper.getResponse(res);
1512 1867 |   }
1513 1868 |
1514 1869 |   // src/tasks/nlp/translation.ts
1515 1870 |   async function translation(args, options) {
     1871 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation");
1516 1872 |     const { data: res } = await innerRequest(args, {
1517 1873 |       ...options,
1518 1874 |       task: "translation"
1519 1875 |     });
1520      | -
1521      | -   if (!isValidOutput) {
1522      | -     throw new InferenceOutputError("Expected type Array<{translation_text: string}>");
1523      | -   }
1524      | -   return res?.length === 1 ? res?.[0] : res;
     1876 | +   return providerHelper.getResponse(res);
1525 1877 |   }
1526 1878 |
1527 1879 |   // src/tasks/nlp/zeroShotClassification.ts
1528 1880 |   async function zeroShotClassification(args, options) {
     1881 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification");
1529 1882 |     const { data: res } = await innerRequest(args, {
1530 1883 |       ...options,
1531 1884 |       task: "zero-shot-classification"
1532 1885 |     });
1533      | -
1534      | -   const isValidOutput = Array.isArray(output) && output.every(
1535      | -     (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
1536      | -   );
1537      | -   if (!isValidOutput) {
1538      | -     throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
1539      | -   }
1540      | -   return output;
1541      | - }
1542      | -
1543      | - // src/tasks/nlp/chatCompletion.ts
1544      | - async function chatCompletion(args, options) {
1545      | -   const { data: res } = await innerRequest(args, {
1546      | -     ...options,
1547      | -     task: "text-generation",
1548      | -     chatCompletion: true
1549      | -   });
1550      | -   const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
1551      | -   (res.system_fingerprint === void 0 || res.system_fingerprint === null || typeof res.system_fingerprint === "string") && typeof res?.usage === "object";
1552      | -   if (!isValidOutput) {
1553      | -     throw new InferenceOutputError("Expected ChatCompletionOutput");
1554      | -   }
1555      | -   return res;
1556      | - }
1557      | -
1558      | - // src/tasks/nlp/chatCompletionStream.ts
1559      | - async function* chatCompletionStream(args, options) {
1560      | -   yield* innerStreamingRequest(args, {
1561      | -     ...options,
1562      | -     task: "text-generation",
1563      | -     chatCompletion: true
1564      | -   });
     1886 | +   return providerHelper.getResponse(res);
1565 1887 |   }
1566 1888 |
1567 1889 |   // src/tasks/multimodal/documentQuestionAnswering.ts
1568 1890 |   async function documentQuestionAnswering(args, options) {
     1891 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "document-question-answering");
1569 1892 |     const reqArgs = {
1570 1893 |       ...args,
1571 1894 |       inputs: {
@@ -1581,18 +1904,12 @@ async function documentQuestionAnswering(args, options) {
1581 1904 |       task: "document-question-answering"
1582 1905 |       }
1583 1906 |     );
1584      | -
1585      | -   const isValidOutput = Array.isArray(output) && output.every(
1586      | -     (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
1587      | -   );
1588      | -   if (!isValidOutput) {
1589      | -     throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>");
1590      | -   }
1591      | -   return output[0];
     1907 | +   return providerHelper.getResponse(res);
1592 1908 |   }
1593 1909 |
1594 1910 |   // src/tasks/multimodal/visualQuestionAnswering.ts
1595 1911 |   async function visualQuestionAnswering(args, options) {
     1912 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "visual-question-answering");
1596 1913 |     const reqArgs = {
1597 1914 |       ...args,
1598 1915 |       inputs: {
@@ -1605,39 +1922,27 @@ async function visualQuestionAnswering(args, options) {
1605 1922 |       ...options,
1606 1923 |       task: "visual-question-answering"
1607 1924 |     });
1608      | -
1609      | -     (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
1610      | -   );
1611      | -   if (!isValidOutput) {
1612      | -     throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
1613      | -   }
1614      | -   return res[0];
     1925 | +   return providerHelper.getResponse(res);
1615 1926 |   }
1616 1927 |
1617      | - // src/tasks/tabular/
1618      | - async function
     1928 | + // src/tasks/tabular/tabularClassification.ts
     1929 | + async function tabularClassification(args, options) {
     1930 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification");
1619 1931 |     const { data: res } = await innerRequest(args, {
1620 1932 |       ...options,
1621      | -     task: "tabular-
     1933 | +     task: "tabular-classification"
1622 1934 |     });
1623      | -
1624      | -   if (!isValidOutput) {
1625      | -     throw new InferenceOutputError("Expected number[]");
1626      | -   }
1627      | -   return res;
     1935 | +   return providerHelper.getResponse(res);
1628 1936 |   }
1629 1937 |
1630      | - // src/tasks/tabular/
1631      | - async function
     1938 | + // src/tasks/tabular/tabularRegression.ts
     1939 | + async function tabularRegression(args, options) {
     1940 | +   const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression");
1632 1941 |     const { data: res } = await innerRequest(args, {
1633 1942 |       ...options,
1634      | -     task: "tabular-
     1943 | +     task: "tabular-regression"
1635 1944 |     });
1636      | -
1637      | -   if (!isValidOutput) {
1638      | -     throw new InferenceOutputError("Expected number[]");
1639      | -   }
1640      | -   return res;
     1945 | +   return providerHelper.getResponse(res);
1641 1946 |   }
1642 1947 |
1643 1948 |   // src/InferenceClient.ts
@@ -1706,11 +2011,11 @@ __export(snippets_exports, {
1706 2011 |   });
1707 2012 |
1708 2013 |   // src/snippets/getInferenceSnippets.ts
     2014 | + import { Template } from "@huggingface/jinja";
1709 2015 |   import {
1710      | -
1711      | -
     2016 | +   getModelInputSnippet,
     2017 | +   inferenceSnippetLanguages
1712 2018 |   } from "@huggingface/tasks";
1713      | - import { Template } from "@huggingface/jinja";
1714 2019 |
1715 2020 |   // src/snippets/templates.exported.ts
1716 2021 |   var templates = {
@@ -1753,7 +2058,7 @@ const image = await client.textToVideo({
1753 2058 |     },
1754 2059 |     "openai": {
1755 2060 |       "conversational": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\nconst chatCompletion = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n});\n\nconsole.log(chatCompletion.choices[0].message);',
1756      | -     "conversationalStream": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\
     2061 | +     "conversationalStream": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\nconst stream = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n stream: true,\n});\n\nfor await (const chunk of stream) {\n process.stdout.write(chunk.choices[0]?.delta?.content || "");\n}'
1757 2062 |     }
1758 2063 |   },
1759 2064 |   "python": {
@@ -1885,15 +2190,23 @@ var HF_JS_METHODS = {
1885 2190 |   };
1886 2191 |   var snippetGenerator = (templateName, inputPreparationFn) => {
1887 2192 |     return (model, accessToken, provider, providerModelId, opts) => {
     2193 | +     let task = model.pipeline_tag;
1888 2194 |       if (model.pipeline_tag && ["text-generation", "image-text-to-text"].includes(model.pipeline_tag) && model.tags.includes("conversational")) {
1889 2195 |         templateName = opts?.streaming ? "conversationalStream" : "conversational";
1890 2196 |         inputPreparationFn = prepareConversationalInput;
     2197 | +       task = "conversational";
1891 2198 |       }
1892 2199 |       const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) };
1893 2200 |       const request2 = makeRequestOptionsFromResolvedModel(
1894 2201 |         providerModelId ?? model.id,
1895      | -       {
1896      | -
     2202 | +       {
     2203 | +         accessToken,
     2204 | +         provider,
     2205 | +         ...inputs
     2206 | +       },
     2207 | +       {
     2208 | +         task
     2209 | +       }
1897 2210 |       );
1898 2211 |       let providerInputs = inputs;
1899 2212 |       const bodyAsObj = request2.info.body;
@@ -1980,7 +2293,7 @@ var prepareConversationalInput = (model, opts) => {
1980 2293 |     return {
1981 2294 |       messages: opts?.messages ?? getModelInputSnippet(model),
1982 2295 |       ...opts?.temperature ? { temperature: opts?.temperature } : void 0,
1983      | -     max_tokens: opts?.max_tokens ??
     2296 | +     max_tokens: opts?.max_tokens ?? 512,
1984 2297 |       ...opts?.top_p ? { top_p: opts?.top_p } : void 0
1985 2298 |     };
1986 2299 |   };
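Note: the snippet generator now derives an explicit task and forwards it to `makeRequestOptionsFromResolvedModel`, and conversational snippets default to `max_tokens: 512` when the caller passes none. A sketch of the task derivation (the model object is illustrative):

    // How the snippet generator picks the task for a conversational model.
    const model = {
      id: "meta-llama/Llama-3.1-8B-Instruct", // illustrative model
      pipeline_tag: "text-generation",
      tags: ["conversational"],
    };
    let task: string = model.pipeline_tag;
    if (["text-generation", "image-text-to-text"].includes(model.pipeline_tag) && model.tags.includes("conversational")) {
      task = "conversational";
    }
    // task === "conversational" is then passed as { task } to makeRequestOptionsFromResolvedModel.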