@huggingface/inference 3.6.2 → 3.7.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +0 -25
- package/dist/index.cjs +1232 -898
- package/dist/index.js +1234 -900
- package/dist/src/config.d.ts +1 -0
- package/dist/src/config.d.ts.map +1 -1
- package/dist/src/lib/getProviderHelper.d.ts +37 -0
- package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
- package/dist/src/lib/makeRequestOptions.d.ts +0 -2
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/black-forest-labs.d.ts +14 -18
- package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
- package/dist/src/providers/cerebras.d.ts +4 -2
- package/dist/src/providers/cerebras.d.ts.map +1 -1
- package/dist/src/providers/cohere.d.ts +5 -2
- package/dist/src/providers/cohere.d.ts.map +1 -1
- package/dist/src/providers/fal-ai.d.ts +50 -3
- package/dist/src/providers/fal-ai.d.ts.map +1 -1
- package/dist/src/providers/fireworks-ai.d.ts +5 -2
- package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
- package/dist/src/providers/hf-inference.d.ts +125 -2
- package/dist/src/providers/hf-inference.d.ts.map +1 -1
- package/dist/src/providers/hyperbolic.d.ts +31 -2
- package/dist/src/providers/hyperbolic.d.ts.map +1 -1
- package/dist/src/providers/nebius.d.ts +20 -18
- package/dist/src/providers/nebius.d.ts.map +1 -1
- package/dist/src/providers/novita.d.ts +21 -18
- package/dist/src/providers/novita.d.ts.map +1 -1
- package/dist/src/providers/openai.d.ts +4 -2
- package/dist/src/providers/openai.d.ts.map +1 -1
- package/dist/src/providers/providerHelper.d.ts +182 -0
- package/dist/src/providers/providerHelper.d.ts.map +1 -0
- package/dist/src/providers/replicate.d.ts +23 -19
- package/dist/src/providers/replicate.d.ts.map +1 -1
- package/dist/src/providers/sambanova.d.ts +4 -2
- package/dist/src/providers/sambanova.d.ts.map +1 -1
- package/dist/src/providers/together.d.ts +32 -2
- package/dist/src/providers/together.d.ts.map +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/audio/utils.d.ts +2 -1
- package/dist/src/tasks/audio/utils.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts +1 -2
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts +1 -2
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/index.d.ts +6 -6
- package/dist/src/tasks/index.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
- package/dist/src/types.d.ts +10 -13
- package/dist/src/types.d.ts.map +1 -1
- package/dist/src/utils/request.d.ts +27 -0
- package/dist/src/utils/request.d.ts.map +1 -0
- package/package.json +3 -3
- package/src/config.ts +1 -0
- package/src/lib/getProviderHelper.ts +270 -0
- package/src/lib/makeRequestOptions.ts +36 -90
- package/src/providers/black-forest-labs.ts +73 -22
- package/src/providers/cerebras.ts +6 -27
- package/src/providers/cohere.ts +9 -28
- package/src/providers/fal-ai.ts +195 -77
- package/src/providers/fireworks-ai.ts +8 -29
- package/src/providers/hf-inference.ts +555 -34
- package/src/providers/hyperbolic.ts +107 -29
- package/src/providers/nebius.ts +65 -29
- package/src/providers/novita.ts +68 -32
- package/src/providers/openai.ts +6 -32
- package/src/providers/providerHelper.ts +354 -0
- package/src/providers/replicate.ts +124 -34
- package/src/providers/sambanova.ts +5 -30
- package/src/providers/together.ts +92 -28
- package/src/snippets/getInferenceSnippets.ts +16 -9
- package/src/snippets/templates.exported.ts +2 -2
- package/src/tasks/audio/audioClassification.ts +6 -9
- package/src/tasks/audio/audioToAudio.ts +5 -28
- package/src/tasks/audio/automaticSpeechRecognition.ts +7 -6
- package/src/tasks/audio/textToSpeech.ts +6 -30
- package/src/tasks/audio/utils.ts +2 -1
- package/src/tasks/custom/request.ts +7 -34
- package/src/tasks/custom/streamingRequest.ts +5 -87
- package/src/tasks/cv/imageClassification.ts +5 -9
- package/src/tasks/cv/imageSegmentation.ts +5 -10
- package/src/tasks/cv/imageToImage.ts +5 -8
- package/src/tasks/cv/imageToText.ts +8 -13
- package/src/tasks/cv/objectDetection.ts +6 -21
- package/src/tasks/cv/textToImage.ts +10 -138
- package/src/tasks/cv/textToVideo.ts +11 -59
- package/src/tasks/cv/zeroShotImageClassification.ts +7 -12
- package/src/tasks/index.ts +6 -6
- package/src/tasks/multimodal/documentQuestionAnswering.ts +10 -26
- package/src/tasks/multimodal/visualQuestionAnswering.ts +6 -12
- package/src/tasks/nlp/chatCompletion.ts +7 -23
- package/src/tasks/nlp/chatCompletionStream.ts +4 -5
- package/src/tasks/nlp/featureExtraction.ts +5 -20
- package/src/tasks/nlp/fillMask.ts +5 -18
- package/src/tasks/nlp/questionAnswering.ts +5 -23
- package/src/tasks/nlp/sentenceSimilarity.ts +5 -18
- package/src/tasks/nlp/summarization.ts +5 -8
- package/src/tasks/nlp/tableQuestionAnswering.ts +5 -29
- package/src/tasks/nlp/textClassification.ts +8 -14
- package/src/tasks/nlp/textGeneration.ts +13 -80
- package/src/tasks/nlp/textGenerationStream.ts +2 -2
- package/src/tasks/nlp/tokenClassification.ts +8 -24
- package/src/tasks/nlp/translation.ts +5 -8
- package/src/tasks/nlp/zeroShotClassification.ts +8 -22
- package/src/tasks/tabular/tabularClassification.ts +5 -8
- package/src/tasks/tabular/tabularRegression.ts +5 -8
- package/src/types.ts +11 -14
- package/src/utils/request.ts +161 -0
package/dist/index.js
CHANGED
|
@@ -41,90 +41,215 @@ __export(tasks_exports, {
|
|
|
41
41
|
zeroShotImageClassification: () => zeroShotImageClassification
|
|
42
42
|
});
|
|
43
43
|
|
|
44
|
+
// package.json
|
|
45
|
+
var name = "@huggingface/inference";
|
|
46
|
+
var version = "3.7.1";
|
|
47
|
+
|
|
44
48
|
// src/config.ts
|
|
45
49
|
var HF_HUB_URL = "https://huggingface.co";
|
|
46
50
|
var HF_ROUTER_URL = "https://router.huggingface.co";
|
|
51
|
+
var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
|
|
47
52
|
|
|
48
|
-
// src/
|
|
49
|
-
var
|
|
50
|
-
|
|
51
|
-
|
|
53
|
+
// src/lib/InferenceOutputError.ts
|
|
54
|
+
var InferenceOutputError = class extends TypeError {
|
|
55
|
+
constructor(message) {
|
|
56
|
+
super(
|
|
57
|
+
`Invalid inference output: ${message}. Use the 'request' method with the same parameters to do a custom call with no type checking.`
|
|
58
|
+
);
|
|
59
|
+
this.name = "InferenceOutputError";
|
|
60
|
+
}
|
|
52
61
|
};
|
|
53
|
-
|
|
54
|
-
|
|
62
|
+
|
|
63
|
+
// src/utils/delay.ts
|
|
64
|
+
function delay(ms) {
|
|
65
|
+
return new Promise((resolve) => {
|
|
66
|
+
setTimeout(() => resolve(), ms);
|
|
67
|
+
});
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
// src/utils/pick.ts
|
|
71
|
+
function pick(o, props) {
|
|
72
|
+
return Object.assign(
|
|
73
|
+
{},
|
|
74
|
+
...props.map((prop) => {
|
|
75
|
+
if (o[prop] !== void 0) {
|
|
76
|
+
return { [prop]: o[prop] };
|
|
77
|
+
}
|
|
78
|
+
})
|
|
79
|
+
);
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
// src/utils/typedInclude.ts
|
|
83
|
+
function typedInclude(arr, v) {
|
|
84
|
+
return arr.includes(v);
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
// src/utils/omit.ts
|
|
88
|
+
function omit(o, props) {
|
|
89
|
+
const propsArr = Array.isArray(props) ? props : [props];
|
|
90
|
+
const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
|
|
91
|
+
return pick(o, letsKeep);
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
// src/utils/toArray.ts
|
|
95
|
+
function toArray(obj) {
|
|
96
|
+
if (Array.isArray(obj)) {
|
|
97
|
+
return obj;
|
|
98
|
+
}
|
|
99
|
+
return [obj];
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
// src/providers/providerHelper.ts
|
|
103
|
+
var TaskProviderHelper = class {
|
|
104
|
+
constructor(provider, baseUrl, clientSideRoutingOnly = false) {
|
|
105
|
+
this.provider = provider;
|
|
106
|
+
this.baseUrl = baseUrl;
|
|
107
|
+
this.clientSideRoutingOnly = clientSideRoutingOnly;
|
|
108
|
+
}
|
|
109
|
+
/**
|
|
110
|
+
* Prepare the base URL for the request
|
|
111
|
+
*/
|
|
112
|
+
makeBaseUrl(params) {
|
|
113
|
+
return params.authMethod !== "provider-key" ? `${HF_ROUTER_URL}/${this.provider}` : this.baseUrl;
|
|
114
|
+
}
|
|
115
|
+
/**
|
|
116
|
+
* Prepare the body for the request
|
|
117
|
+
*/
|
|
118
|
+
makeBody(params) {
|
|
119
|
+
if ("data" in params.args && !!params.args.data) {
|
|
120
|
+
return params.args.data;
|
|
121
|
+
}
|
|
122
|
+
return JSON.stringify(this.preparePayload(params));
|
|
123
|
+
}
|
|
124
|
+
/**
|
|
125
|
+
* Prepare the URL for the request
|
|
126
|
+
*/
|
|
127
|
+
makeUrl(params) {
|
|
128
|
+
const baseUrl = this.makeBaseUrl(params);
|
|
129
|
+
const route = this.makeRoute(params).replace(/^\/+/, "");
|
|
130
|
+
return `${baseUrl}/${route}`;
|
|
131
|
+
}
|
|
132
|
+
/**
|
|
133
|
+
* Prepare the headers for the request
|
|
134
|
+
*/
|
|
135
|
+
prepareHeaders(params, isBinary) {
|
|
136
|
+
const headers = { Authorization: `Bearer ${params.accessToken}` };
|
|
137
|
+
if (!isBinary) {
|
|
138
|
+
headers["Content-Type"] = "application/json";
|
|
139
|
+
}
|
|
140
|
+
return headers;
|
|
141
|
+
}
|
|
55
142
|
};
|
|
56
|
-
var
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
}
|
|
60
|
-
|
|
143
|
+
var BaseConversationalTask = class extends TaskProviderHelper {
|
|
144
|
+
constructor(provider, baseUrl, clientSideRoutingOnly = false) {
|
|
145
|
+
super(provider, baseUrl, clientSideRoutingOnly);
|
|
146
|
+
}
|
|
147
|
+
makeRoute() {
|
|
148
|
+
return "v1/chat/completions";
|
|
149
|
+
}
|
|
150
|
+
preparePayload(params) {
|
|
151
|
+
return {
|
|
152
|
+
...params.args,
|
|
153
|
+
model: params.model
|
|
154
|
+
};
|
|
155
|
+
}
|
|
156
|
+
async getResponse(response) {
|
|
157
|
+
if (typeof response === "object" && Array.isArray(response?.choices) && typeof response?.created === "number" && typeof response?.id === "string" && typeof response?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
|
|
158
|
+
(response.system_fingerprint === void 0 || response.system_fingerprint === null || typeof response.system_fingerprint === "string") && typeof response?.usage === "object") {
|
|
159
|
+
return response;
|
|
160
|
+
}
|
|
161
|
+
throw new InferenceOutputError("Expected ChatCompletionOutput");
|
|
61
162
|
}
|
|
62
163
|
};
|
|
63
|
-
var
|
|
64
|
-
|
|
164
|
+
var BaseTextGenerationTask = class extends TaskProviderHelper {
|
|
165
|
+
constructor(provider, baseUrl, clientSideRoutingOnly = false) {
|
|
166
|
+
super(provider, baseUrl, clientSideRoutingOnly);
|
|
167
|
+
}
|
|
168
|
+
preparePayload(params) {
|
|
169
|
+
return {
|
|
170
|
+
...params.args,
|
|
171
|
+
model: params.model
|
|
172
|
+
};
|
|
173
|
+
}
|
|
174
|
+
makeRoute() {
|
|
175
|
+
return "v1/completions";
|
|
176
|
+
}
|
|
177
|
+
async getResponse(response) {
|
|
178
|
+
const res = toArray(response);
|
|
179
|
+
if (Array.isArray(res) && res.length > 0 && res.every(
|
|
180
|
+
(x) => typeof x === "object" && !!x && "generated_text" in x && typeof x.generated_text === "string"
|
|
181
|
+
)) {
|
|
182
|
+
return res[0];
|
|
183
|
+
}
|
|
184
|
+
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
|
|
185
|
+
}
|
|
65
186
|
};
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
187
|
+
|
|
188
|
+
// src/providers/black-forest-labs.ts
|
|
189
|
+
var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
|
|
190
|
+
var BlackForestLabsTextToImageTask = class extends TaskProviderHelper {
|
|
191
|
+
constructor() {
|
|
192
|
+
super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
|
|
193
|
+
}
|
|
194
|
+
preparePayload(params) {
|
|
195
|
+
return {
|
|
196
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
197
|
+
...params.args.parameters,
|
|
198
|
+
prompt: params.args.inputs
|
|
199
|
+
};
|
|
200
|
+
}
|
|
201
|
+
prepareHeaders(params, binary) {
|
|
202
|
+
const headers = {
|
|
203
|
+
Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`
|
|
204
|
+
};
|
|
205
|
+
if (!binary) {
|
|
206
|
+
headers["Content-Type"] = "application/json";
|
|
207
|
+
}
|
|
208
|
+
return headers;
|
|
209
|
+
}
|
|
210
|
+
makeRoute(params) {
|
|
211
|
+
if (!params) {
|
|
212
|
+
throw new Error("Params are required");
|
|
213
|
+
}
|
|
214
|
+
return `/v1/${params.model}`;
|
|
215
|
+
}
|
|
216
|
+
async getResponse(response, url, headers, outputType) {
|
|
217
|
+
const urlObj = new URL(response.polling_url);
|
|
218
|
+
for (let step = 0; step < 5; step++) {
|
|
219
|
+
await delay(1e3);
|
|
220
|
+
console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
|
|
221
|
+
urlObj.searchParams.set("attempt", step.toString(10));
|
|
222
|
+
const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
|
|
223
|
+
if (!resp.ok) {
|
|
224
|
+
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
|
|
225
|
+
}
|
|
226
|
+
const payload = await resp.json();
|
|
227
|
+
if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
|
|
228
|
+
if (outputType === "url") {
|
|
229
|
+
return payload.result.sample;
|
|
230
|
+
}
|
|
231
|
+
const image = await fetch(payload.result.sample);
|
|
232
|
+
return await image.blob();
|
|
233
|
+
}
|
|
234
|
+
}
|
|
235
|
+
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
|
|
236
|
+
}
|
|
71
237
|
};
|
|
72
238
|
|
|
73
239
|
// src/providers/cerebras.ts
|
|
74
|
-
var
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
}
|
|
78
|
-
var makeBody2 = (params) => {
|
|
79
|
-
return {
|
|
80
|
-
...params.args,
|
|
81
|
-
model: params.model
|
|
82
|
-
};
|
|
83
|
-
};
|
|
84
|
-
var makeHeaders2 = (params) => {
|
|
85
|
-
return { Authorization: `Bearer ${params.accessToken}` };
|
|
86
|
-
};
|
|
87
|
-
var makeUrl2 = (params) => {
|
|
88
|
-
return `${params.baseUrl}/v1/chat/completions`;
|
|
89
|
-
};
|
|
90
|
-
var CEREBRAS_CONFIG = {
|
|
91
|
-
makeBaseUrl: makeBaseUrl2,
|
|
92
|
-
makeBody: makeBody2,
|
|
93
|
-
makeHeaders: makeHeaders2,
|
|
94
|
-
makeUrl: makeUrl2
|
|
240
|
+
var CerebrasConversationalTask = class extends BaseConversationalTask {
|
|
241
|
+
constructor() {
|
|
242
|
+
super("cerebras", "https://api.cerebras.ai");
|
|
243
|
+
}
|
|
95
244
|
};
|
|
96
245
|
|
|
97
246
|
// src/providers/cohere.ts
|
|
98
|
-
var
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
}
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
...params.args,
|
|
105
|
-
model: params.model
|
|
106
|
-
};
|
|
107
|
-
};
|
|
108
|
-
var makeHeaders3 = (params) => {
|
|
109
|
-
return { Authorization: `Bearer ${params.accessToken}` };
|
|
110
|
-
};
|
|
111
|
-
var makeUrl3 = (params) => {
|
|
112
|
-
return `${params.baseUrl}/compatibility/v1/chat/completions`;
|
|
113
|
-
};
|
|
114
|
-
var COHERE_CONFIG = {
|
|
115
|
-
makeBaseUrl: makeBaseUrl3,
|
|
116
|
-
makeBody: makeBody3,
|
|
117
|
-
makeHeaders: makeHeaders3,
|
|
118
|
-
makeUrl: makeUrl3
|
|
119
|
-
};
|
|
120
|
-
|
|
121
|
-
// src/lib/InferenceOutputError.ts
|
|
122
|
-
var InferenceOutputError = class extends TypeError {
|
|
123
|
-
constructor(message) {
|
|
124
|
-
super(
|
|
125
|
-
`Invalid inference output: ${message}. Use the 'request' method with the same parameters to do a custom call with no type checking.`
|
|
126
|
-
);
|
|
127
|
-
this.name = "InferenceOutputError";
|
|
247
|
+
var CohereConversationalTask = class extends BaseConversationalTask {
|
|
248
|
+
constructor() {
|
|
249
|
+
super("cohere", "https://api.cohere.com");
|
|
250
|
+
}
|
|
251
|
+
makeRoute() {
|
|
252
|
+
return "/compatibility/v1/chat/completions";
|
|
128
253
|
}
|
|
129
254
|
};
|
|
130
255
|
|
|
@@ -133,349 +258,871 @@ function isUrl(modelOrUrl) {
|
|
|
133
258
|
return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
|
|
134
259
|
}
|
|
135
260
|
|
|
136
|
-
// src/utils/delay.ts
|
|
137
|
-
function delay(ms) {
|
|
138
|
-
return new Promise((resolve) => {
|
|
139
|
-
setTimeout(() => resolve(), ms);
|
|
140
|
-
});
|
|
141
|
-
}
|
|
142
|
-
|
|
143
261
|
// src/providers/fal-ai.ts
|
|
144
|
-
var
|
|
145
|
-
var
|
|
146
|
-
|
|
147
|
-
|
|
262
|
+
var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
|
|
263
|
+
var FalAITask = class extends TaskProviderHelper {
|
|
264
|
+
constructor(url) {
|
|
265
|
+
super("fal-ai", url || "https://fal.run");
|
|
266
|
+
}
|
|
267
|
+
preparePayload(params) {
|
|
268
|
+
return params.args;
|
|
269
|
+
}
|
|
270
|
+
makeRoute(params) {
|
|
271
|
+
return `/${params.model}`;
|
|
272
|
+
}
|
|
273
|
+
prepareHeaders(params, binary) {
|
|
274
|
+
const headers = {
|
|
275
|
+
Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}`
|
|
276
|
+
};
|
|
277
|
+
if (!binary) {
|
|
278
|
+
headers["Content-Type"] = "application/json";
|
|
279
|
+
}
|
|
280
|
+
return headers;
|
|
281
|
+
}
|
|
148
282
|
};
|
|
149
|
-
var
|
|
150
|
-
|
|
283
|
+
var FalAITextToImageTask = class extends FalAITask {
|
|
284
|
+
preparePayload(params) {
|
|
285
|
+
return {
|
|
286
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
287
|
+
...params.args.parameters,
|
|
288
|
+
sync_mode: true,
|
|
289
|
+
prompt: params.args.inputs
|
|
290
|
+
};
|
|
291
|
+
}
|
|
292
|
+
async getResponse(response, outputType) {
|
|
293
|
+
if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images.length > 0 && "url" in response.images[0] && typeof response.images[0].url === "string") {
|
|
294
|
+
if (outputType === "url") {
|
|
295
|
+
return response.images[0].url;
|
|
296
|
+
}
|
|
297
|
+
const urlResponse = await fetch(response.images[0].url);
|
|
298
|
+
return await urlResponse.blob();
|
|
299
|
+
}
|
|
300
|
+
throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
|
|
301
|
+
}
|
|
151
302
|
};
|
|
152
|
-
var
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
}
|
|
303
|
+
var FalAITextToVideoTask = class extends FalAITask {
|
|
304
|
+
constructor() {
|
|
305
|
+
super("https://queue.fal.run");
|
|
306
|
+
}
|
|
307
|
+
makeRoute(params) {
|
|
308
|
+
if (params.authMethod !== "provider-key") {
|
|
309
|
+
return `/${params.model}?_subdomain=queue`;
|
|
310
|
+
}
|
|
311
|
+
return `/${params.model}`;
|
|
312
|
+
}
|
|
313
|
+
preparePayload(params) {
|
|
314
|
+
return {
|
|
315
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
316
|
+
...params.args.parameters,
|
|
317
|
+
prompt: params.args.inputs
|
|
318
|
+
};
|
|
319
|
+
}
|
|
320
|
+
async getResponse(response, url, headers) {
|
|
321
|
+
if (!url || !headers) {
|
|
322
|
+
throw new InferenceOutputError("URL and headers are required for text-to-video task");
|
|
323
|
+
}
|
|
324
|
+
const requestId = response.request_id;
|
|
325
|
+
if (!requestId) {
|
|
326
|
+
throw new InferenceOutputError("No request ID found in the response");
|
|
327
|
+
}
|
|
328
|
+
let status = response.status;
|
|
329
|
+
const parsedUrl = new URL(url);
|
|
330
|
+
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
|
|
331
|
+
const modelId = new URL(response.response_url).pathname;
|
|
332
|
+
const queryParams = parsedUrl.search;
|
|
333
|
+
const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
|
|
334
|
+
const resultUrl = `${baseUrl}${modelId}${queryParams}`;
|
|
335
|
+
while (status !== "COMPLETED") {
|
|
336
|
+
await delay(500);
|
|
337
|
+
const statusResponse = await fetch(statusUrl, { headers });
|
|
338
|
+
if (!statusResponse.ok) {
|
|
339
|
+
throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
|
|
340
|
+
}
|
|
341
|
+
try {
|
|
342
|
+
status = (await statusResponse.json()).status;
|
|
343
|
+
} catch (error) {
|
|
344
|
+
throw new InferenceOutputError("Failed to parse status response from fal-ai API");
|
|
345
|
+
}
|
|
346
|
+
}
|
|
347
|
+
const resultResponse = await fetch(resultUrl, { headers });
|
|
348
|
+
let result;
|
|
349
|
+
try {
|
|
350
|
+
result = await resultResponse.json();
|
|
351
|
+
} catch (error) {
|
|
352
|
+
throw new InferenceOutputError("Failed to parse result response from fal-ai API");
|
|
353
|
+
}
|
|
354
|
+
if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
|
|
355
|
+
const urlResponse = await fetch(result.video.url);
|
|
356
|
+
return await urlResponse.blob();
|
|
357
|
+
} else {
|
|
358
|
+
throw new InferenceOutputError(
|
|
359
|
+
"Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
|
|
360
|
+
);
|
|
361
|
+
}
|
|
362
|
+
}
|
|
363
|
+
};
|
|
364
|
+
var FalAIAutomaticSpeechRecognitionTask = class extends FalAITask {
|
|
365
|
+
prepareHeaders(params, binary) {
|
|
366
|
+
const headers = super.prepareHeaders(params, binary);
|
|
367
|
+
headers["Content-Type"] = "application/json";
|
|
368
|
+
return headers;
|
|
369
|
+
}
|
|
370
|
+
async getResponse(response) {
|
|
371
|
+
const res = response;
|
|
372
|
+
if (typeof res?.text !== "string") {
|
|
373
|
+
throw new InferenceOutputError(
|
|
374
|
+
`Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
|
|
375
|
+
);
|
|
376
|
+
}
|
|
377
|
+
return { text: res.text };
|
|
378
|
+
}
|
|
156
379
|
};
|
|
157
|
-
var
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
};
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
}
|
|
170
|
-
|
|
171
|
-
const requestId = res.request_id;
|
|
172
|
-
if (!requestId) {
|
|
173
|
-
throw new InferenceOutputError("No request ID found in the response");
|
|
174
|
-
}
|
|
175
|
-
let status = res.status;
|
|
176
|
-
const parsedUrl = new URL(url);
|
|
177
|
-
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
|
|
178
|
-
const modelId = new URL(res.response_url).pathname;
|
|
179
|
-
const queryParams = parsedUrl.search;
|
|
180
|
-
const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
|
|
181
|
-
const resultUrl = `${baseUrl}${modelId}${queryParams}`;
|
|
182
|
-
while (status !== "COMPLETED") {
|
|
183
|
-
await delay(500);
|
|
184
|
-
const statusResponse = await fetch(statusUrl, { headers });
|
|
185
|
-
if (!statusResponse.ok) {
|
|
186
|
-
throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
|
|
380
|
+
var FalAITextToSpeechTask = class extends FalAITask {
|
|
381
|
+
preparePayload(params) {
|
|
382
|
+
return {
|
|
383
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
384
|
+
...params.args.parameters,
|
|
385
|
+
lyrics: params.args.inputs
|
|
386
|
+
};
|
|
387
|
+
}
|
|
388
|
+
async getResponse(response) {
|
|
389
|
+
const res = response;
|
|
390
|
+
if (typeof res?.audio?.url !== "string") {
|
|
391
|
+
throw new InferenceOutputError(
|
|
392
|
+
`Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
|
|
393
|
+
);
|
|
187
394
|
}
|
|
188
395
|
try {
|
|
189
|
-
|
|
396
|
+
const urlResponse = await fetch(res.audio.url);
|
|
397
|
+
if (!urlResponse.ok) {
|
|
398
|
+
throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`);
|
|
399
|
+
}
|
|
400
|
+
return await urlResponse.blob();
|
|
190
401
|
} catch (error) {
|
|
191
|
-
throw new InferenceOutputError(
|
|
402
|
+
throw new InferenceOutputError(
|
|
403
|
+
`Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${error instanceof Error ? error.message : String(error)}`
|
|
404
|
+
);
|
|
192
405
|
}
|
|
193
406
|
}
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
407
|
+
};
|
|
408
|
+
|
|
409
|
+
// src/providers/fireworks-ai.ts
|
|
410
|
+
var FireworksConversationalTask = class extends BaseConversationalTask {
|
|
411
|
+
constructor() {
|
|
412
|
+
super("fireworks-ai", "https://api.fireworks.ai");
|
|
200
413
|
}
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
414
|
+
makeRoute() {
|
|
415
|
+
return "/inference/v1/chat/completions";
|
|
416
|
+
}
|
|
417
|
+
};
|
|
418
|
+
|
|
419
|
+
// src/providers/hf-inference.ts
|
|
420
|
+
var HFInferenceTask = class extends TaskProviderHelper {
|
|
421
|
+
constructor() {
|
|
422
|
+
super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
|
|
423
|
+
}
|
|
424
|
+
preparePayload(params) {
|
|
425
|
+
return params.args;
|
|
426
|
+
}
|
|
427
|
+
makeUrl(params) {
|
|
428
|
+
if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
|
|
429
|
+
return params.model;
|
|
430
|
+
}
|
|
431
|
+
return super.makeUrl(params);
|
|
432
|
+
}
|
|
433
|
+
makeRoute(params) {
|
|
434
|
+
if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
|
|
435
|
+
return `pipeline/${params.task}/${params.model}`;
|
|
436
|
+
}
|
|
437
|
+
return `models/${params.model}`;
|
|
438
|
+
}
|
|
439
|
+
async getResponse(response) {
|
|
440
|
+
return response;
|
|
441
|
+
}
|
|
442
|
+
};
|
|
443
|
+
var HFInferenceTextToImageTask = class extends HFInferenceTask {
|
|
444
|
+
async getResponse(response, url, headers, outputType) {
|
|
445
|
+
if (!response) {
|
|
446
|
+
throw new InferenceOutputError("response is undefined");
|
|
447
|
+
}
|
|
448
|
+
if (typeof response == "object") {
|
|
449
|
+
if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
|
|
450
|
+
const base64Data = response.data[0].b64_json;
|
|
451
|
+
if (outputType === "url") {
|
|
452
|
+
return `data:image/jpeg;base64,${base64Data}`;
|
|
453
|
+
}
|
|
454
|
+
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
|
|
455
|
+
return await base64Response.blob();
|
|
456
|
+
}
|
|
457
|
+
if ("output" in response && Array.isArray(response.output)) {
|
|
458
|
+
if (outputType === "url") {
|
|
459
|
+
return response.output[0];
|
|
460
|
+
}
|
|
461
|
+
const urlResponse = await fetch(response.output[0]);
|
|
462
|
+
const blob = await urlResponse.blob();
|
|
463
|
+
return blob;
|
|
464
|
+
}
|
|
465
|
+
}
|
|
466
|
+
if (response instanceof Blob) {
|
|
467
|
+
if (outputType === "url") {
|
|
468
|
+
const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
|
|
469
|
+
return `data:image/jpeg;base64,${b64}`;
|
|
470
|
+
}
|
|
471
|
+
return response;
|
|
472
|
+
}
|
|
473
|
+
throw new InferenceOutputError("Expected a Blob ");
|
|
474
|
+
}
|
|
475
|
+
};
|
|
476
|
+
var HFInferenceConversationalTask = class extends HFInferenceTask {
|
|
477
|
+
makeUrl(params) {
|
|
478
|
+
let url;
|
|
479
|
+
if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
|
|
480
|
+
url = params.model.trim();
|
|
481
|
+
} else {
|
|
482
|
+
url = `${this.makeBaseUrl(params)}/models/${params.model}`;
|
|
483
|
+
}
|
|
484
|
+
url = url.replace(/\/+$/, "");
|
|
485
|
+
if (url.endsWith("/v1")) {
|
|
486
|
+
url += "/chat/completions";
|
|
487
|
+
} else if (!url.endsWith("/chat/completions")) {
|
|
488
|
+
url += "/v1/chat/completions";
|
|
489
|
+
}
|
|
490
|
+
return url;
|
|
491
|
+
}
|
|
492
|
+
preparePayload(params) {
|
|
493
|
+
return {
|
|
494
|
+
...params.args,
|
|
495
|
+
model: params.model
|
|
496
|
+
};
|
|
497
|
+
}
|
|
498
|
+
async getResponse(response) {
|
|
499
|
+
return response;
|
|
500
|
+
}
|
|
501
|
+
};
|
|
502
|
+
var HFInferenceTextGenerationTask = class extends HFInferenceTask {
|
|
503
|
+
async getResponse(response) {
|
|
504
|
+
const res = toArray(response);
|
|
505
|
+
if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) {
|
|
506
|
+
return res?.[0];
|
|
507
|
+
}
|
|
508
|
+
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
|
|
509
|
+
}
|
|
510
|
+
};
|
|
511
|
+
var HFInferenceAudioClassificationTask = class extends HFInferenceTask {
|
|
512
|
+
async getResponse(response) {
|
|
513
|
+
if (Array.isArray(response) && response.every(
|
|
514
|
+
(x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
|
|
515
|
+
)) {
|
|
516
|
+
return response;
|
|
517
|
+
}
|
|
518
|
+
throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
|
|
519
|
+
}
|
|
520
|
+
};
|
|
521
|
+
var HFInferenceAutomaticSpeechRecognitionTask = class extends HFInferenceTask {
|
|
522
|
+
async getResponse(response) {
|
|
523
|
+
return response;
|
|
524
|
+
}
|
|
525
|
+
};
|
|
526
|
+
var HFInferenceAudioToAudioTask = class extends HFInferenceTask {
|
|
527
|
+
async getResponse(response) {
|
|
528
|
+
if (!Array.isArray(response)) {
|
|
529
|
+
throw new InferenceOutputError("Expected Array");
|
|
530
|
+
}
|
|
531
|
+
if (!response.every((elem) => {
|
|
532
|
+
return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
|
|
533
|
+
})) {
|
|
534
|
+
throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
|
|
535
|
+
}
|
|
536
|
+
return response;
|
|
537
|
+
}
|
|
538
|
+
};
|
|
539
|
+
var HFInferenceDocumentQuestionAnsweringTask = class extends HFInferenceTask {
|
|
540
|
+
async getResponse(response) {
|
|
541
|
+
if (Array.isArray(response) && response.every(
|
|
542
|
+
(elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
|
|
543
|
+
)) {
|
|
544
|
+
return response[0];
|
|
545
|
+
}
|
|
546
|
+
throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
|
|
547
|
+
}
|
|
548
|
+
};
|
|
549
|
+
var HFInferenceFeatureExtractionTask = class extends HFInferenceTask {
|
|
550
|
+
async getResponse(response) {
|
|
551
|
+
const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
|
|
552
|
+
if (curDepth > maxDepth)
|
|
553
|
+
return false;
|
|
554
|
+
if (arr.every((x) => Array.isArray(x))) {
|
|
555
|
+
return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
|
|
556
|
+
} else {
|
|
557
|
+
return arr.every((x) => typeof x === "number");
|
|
558
|
+
}
|
|
559
|
+
};
|
|
560
|
+
if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) {
|
|
561
|
+
return response;
|
|
562
|
+
}
|
|
563
|
+
throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
|
|
564
|
+
}
|
|
565
|
+
};
|
|
566
|
+
var HFInferenceImageClassificationTask = class extends HFInferenceTask {
|
|
567
|
+
async getResponse(response) {
|
|
568
|
+
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
|
|
569
|
+
return response;
|
|
570
|
+
}
|
|
571
|
+
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
572
|
+
}
|
|
573
|
+
};
|
|
574
|
+
var HFInferenceImageSegmentationTask = class extends HFInferenceTask {
|
|
575
|
+
async getResponse(response) {
|
|
576
|
+
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number")) {
|
|
577
|
+
return response;
|
|
578
|
+
}
|
|
579
|
+
throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
|
|
580
|
+
}
|
|
581
|
+
};
|
|
582
|
+
var HFInferenceImageToTextTask = class extends HFInferenceTask {
|
|
583
|
+
async getResponse(response) {
|
|
584
|
+
if (typeof response?.generated_text !== "string") {
|
|
585
|
+
throw new InferenceOutputError("Expected {generated_text: string}");
|
|
586
|
+
}
|
|
587
|
+
return response;
|
|
588
|
+
}
|
|
589
|
+
};
|
|
590
|
+
var HFInferenceImageToImageTask = class extends HFInferenceTask {
|
|
591
|
+
async getResponse(response) {
|
|
592
|
+
if (response instanceof Blob) {
|
|
593
|
+
return response;
|
|
594
|
+
}
|
|
595
|
+
throw new InferenceOutputError("Expected Blob");
|
|
596
|
+
}
|
|
597
|
+
};
|
|
598
|
+
var HFInferenceObjectDetectionTask = class extends HFInferenceTask {
|
|
599
|
+
async getResponse(response) {
|
|
600
|
+
if (Array.isArray(response) && response.every(
|
|
601
|
+
(x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
|
|
602
|
+
)) {
|
|
603
|
+
return response;
|
|
604
|
+
}
|
|
205
605
|
throw new InferenceOutputError(
|
|
206
|
-
"Expected {
|
|
606
|
+
"Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
|
|
207
607
|
);
|
|
208
608
|
}
|
|
209
|
-
}
|
|
210
|
-
|
|
211
|
-
// src/providers/fireworks-ai.ts
|
|
212
|
-
var FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai";
|
|
213
|
-
var makeBaseUrl5 = () => {
|
|
214
|
-
return FIREWORKS_AI_API_BASE_URL;
|
|
215
609
|
};
|
|
216
|
-
var
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
610
|
+
var HFInferenceZeroShotImageClassificationTask = class extends HFInferenceTask {
|
|
611
|
+
async getResponse(response) {
|
|
612
|
+
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
|
|
613
|
+
return response;
|
|
614
|
+
}
|
|
615
|
+
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
616
|
+
}
|
|
617
|
+
};
|
|
618
|
+
var HFInferenceTextClassificationTask = class extends HFInferenceTask {
|
|
619
|
+
async getResponse(response) {
|
|
620
|
+
const output = response?.[0];
|
|
621
|
+
if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
|
|
622
|
+
return output;
|
|
623
|
+
}
|
|
624
|
+
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
625
|
+
}
|
|
626
|
+
};
|
|
627
|
+
var HFInferenceQuestionAnsweringTask = class extends HFInferenceTask {
|
|
628
|
+
async getResponse(response) {
|
|
629
|
+
if (Array.isArray(response) ? response.every(
|
|
630
|
+
(elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
|
|
631
|
+
) : typeof response === "object" && !!response && typeof response.answer === "string" && typeof response.end === "number" && typeof response.score === "number" && typeof response.start === "number") {
|
|
632
|
+
return Array.isArray(response) ? response[0] : response;
|
|
633
|
+
}
|
|
634
|
+
throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
|
|
635
|
+
}
|
|
636
|
+
};
|
|
637
|
+
var HFInferenceFillMaskTask = class extends HFInferenceTask {
|
|
638
|
+
async getResponse(response) {
|
|
639
|
+
if (Array.isArray(response) && response.every(
|
|
640
|
+
(x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
|
|
641
|
+
)) {
|
|
642
|
+
return response;
|
|
643
|
+
}
|
|
644
|
+
throw new InferenceOutputError(
|
|
645
|
+
"Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
|
|
646
|
+
);
|
|
647
|
+
}
|
|
221
648
|
};
|
|
222
|
-
var
|
|
223
|
-
|
|
649
|
+
var HFInferenceZeroShotClassificationTask = class extends HFInferenceTask {
|
|
650
|
+
async getResponse(response) {
|
|
651
|
+
if (Array.isArray(response) && response.every(
|
|
652
|
+
(x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
|
|
653
|
+
)) {
|
|
654
|
+
return response;
|
|
655
|
+
}
|
|
656
|
+
throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
|
|
657
|
+
}
|
|
224
658
|
};
|
|
225
|
-
var
|
|
226
|
-
|
|
227
|
-
|
|
659
|
+
var HFInferenceSentenceSimilarityTask = class extends HFInferenceTask {
|
|
660
|
+
async getResponse(response) {
|
|
661
|
+
if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
|
|
662
|
+
return response;
|
|
663
|
+
}
|
|
664
|
+
throw new InferenceOutputError("Expected Array<number>");
|
|
228
665
|
}
|
|
229
|
-
return `${params.baseUrl}/inference`;
|
|
230
666
|
};
|
|
231
|
-
var
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
667
|
+
var HFInferenceTableQuestionAnsweringTask = class extends HFInferenceTask {
|
|
668
|
+
static validate(elem) {
|
|
669
|
+
return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
|
|
670
|
+
(coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
|
|
671
|
+
);
|
|
672
|
+
}
|
|
673
|
+
async getResponse(response) {
|
|
674
|
+
if (Array.isArray(response) && Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response)) {
|
|
675
|
+
return Array.isArray(response) ? response[0] : response;
|
|
676
|
+
}
|
|
677
|
+
throw new InferenceOutputError(
|
|
678
|
+
"Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
|
|
679
|
+
);
|
|
680
|
+
}
|
|
236
681
|
};
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
682
|
+
var HFInferenceTokenClassificationTask = class extends HFInferenceTask {
|
|
683
|
+
async getResponse(response) {
|
|
684
|
+
if (Array.isArray(response) && response.every(
|
|
685
|
+
(x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
|
|
686
|
+
)) {
|
|
687
|
+
return response;
|
|
688
|
+
}
|
|
689
|
+
throw new InferenceOutputError(
|
|
690
|
+
"Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
|
|
691
|
+
);
|
|
692
|
+
}
|
|
241
693
|
};
|
|
242
|
-
var
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
694
|
+
var HFInferenceTranslationTask = class extends HFInferenceTask {
|
|
695
|
+
async getResponse(response) {
|
|
696
|
+
if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) {
|
|
697
|
+
return response?.length === 1 ? response?.[0] : response;
|
|
698
|
+
}
|
|
699
|
+
throw new InferenceOutputError("Expected Array<{translation_text: string}>");
|
|
700
|
+
}
|
|
247
701
|
};
|
|
248
|
-
var
|
|
249
|
-
|
|
702
|
+
var HFInferenceSummarizationTask = class extends HFInferenceTask {
|
|
703
|
+
async getResponse(response) {
|
|
704
|
+
if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) {
|
|
705
|
+
return response?.[0];
|
|
706
|
+
}
|
|
707
|
+
throw new InferenceOutputError("Expected Array<{summary_text: string}>");
|
|
708
|
+
}
|
|
250
709
|
};
|
|
251
|
-
var
|
|
252
|
-
|
|
253
|
-
return
|
|
710
|
+
var HFInferenceTextToSpeechTask = class extends HFInferenceTask {
|
|
711
|
+
async getResponse(response) {
|
|
712
|
+
return response;
|
|
254
713
|
}
|
|
255
|
-
|
|
256
|
-
|
|
714
|
+
};
|
|
715
|
+
var HFInferenceTabularClassificationTask = class extends HFInferenceTask {
|
|
716
|
+
async getResponse(response) {
|
|
717
|
+
if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
|
|
718
|
+
return response;
|
|
719
|
+
}
|
|
720
|
+
throw new InferenceOutputError("Expected Array<number>");
|
|
721
|
+
}
|
|
722
|
+
};
|
|
723
|
+
var HFInferenceVisualQuestionAnsweringTask = class extends HFInferenceTask {
|
|
724
|
+
async getResponse(response) {
|
|
725
|
+
if (Array.isArray(response) && response.every(
|
|
726
|
+
(elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
|
|
727
|
+
)) {
|
|
728
|
+
return response[0];
|
|
729
|
+
}
|
|
730
|
+
throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
|
|
257
731
|
}
|
|
258
|
-
return `${params.baseUrl}/models/${params.model}`;
|
|
259
732
|
};
|
|
260
|
-
var
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
733
|
+
var HFInferenceTabularRegressionTask = class extends HFInferenceTask {
|
|
734
|
+
async getResponse(response) {
|
|
735
|
+
if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
|
|
736
|
+
return response;
|
|
737
|
+
}
|
|
738
|
+
throw new InferenceOutputError("Expected Array<number>");
|
|
739
|
+
}
|
|
740
|
+
};
|
|
741
|
+
var HFInferenceTextToAudioTask = class extends HFInferenceTask {
|
|
742
|
+
async getResponse(response) {
|
|
743
|
+
return response;
|
|
744
|
+
}
|
|
265
745
|
};
|
|
266
746
|
|
|
267
747
|
// src/providers/hyperbolic.ts
|
|
268
748
|
var HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
|
|
269
|
-
var
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
return {
|
|
274
|
-
...params.args,
|
|
275
|
-
...params.task === "text-to-image" ? { model_name: params.model } : { model: params.model }
|
|
276
|
-
};
|
|
277
|
-
};
|
|
278
|
-
var makeHeaders7 = (params) => {
|
|
279
|
-
return { Authorization: `Bearer ${params.accessToken}` };
|
|
749
|
+
var HyperbolicConversationalTask = class extends BaseConversationalTask {
|
|
750
|
+
constructor() {
|
|
751
|
+
super("hyperbolic", HYPERBOLIC_API_BASE_URL);
|
|
752
|
+
}
|
|
280
753
|
};
|
|
281
|
-
var
|
|
282
|
-
|
|
283
|
-
|
|
754
|
+
var HyperbolicTextGenerationTask = class extends BaseTextGenerationTask {
|
|
755
|
+
constructor() {
|
|
756
|
+
super("hyperbolic", HYPERBOLIC_API_BASE_URL);
|
|
757
|
+
}
|
|
758
|
+
makeRoute() {
|
|
759
|
+
return "v1/chat/completions";
|
|
760
|
+
}
|
|
761
|
+
preparePayload(params) {
|
|
762
|
+
return {
|
|
763
|
+
messages: [{ content: params.args.inputs, role: "user" }],
|
|
764
|
+
...params.args.parameters ? {
|
|
765
|
+
max_tokens: params.args.parameters.max_new_tokens,
|
|
766
|
+
...omit(params.args.parameters, "max_new_tokens")
|
|
767
|
+
} : void 0,
|
|
768
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
769
|
+
model: params.model
|
|
770
|
+
};
|
|
771
|
+
}
|
|
772
|
+
async getResponse(response) {
|
|
773
|
+
if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
|
|
774
|
+
const completion = response.choices[0];
|
|
775
|
+
return {
|
|
776
|
+
generated_text: completion.message.content
|
|
777
|
+
};
|
|
778
|
+
}
|
|
779
|
+
throw new InferenceOutputError("Expected Hyperbolic text generation response format");
|
|
284
780
|
}
|
|
285
|
-
return `${params.baseUrl}/v1/chat/completions`;
|
|
286
781
|
};
|
|
287
|
-
var
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
782
|
+
var HyperbolicTextToImageTask = class extends TaskProviderHelper {
|
|
783
|
+
constructor() {
|
|
784
|
+
super("hyperbolic", HYPERBOLIC_API_BASE_URL);
|
|
785
|
+
}
|
|
786
|
+
makeRoute(params) {
|
|
787
|
+
return `/v1/images/generations`;
|
|
788
|
+
}
|
|
789
|
+
preparePayload(params) {
|
|
790
|
+
return {
|
|
791
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
792
|
+
...params.args.parameters,
|
|
793
|
+
prompt: params.args.inputs,
|
|
794
|
+
model_name: params.model
|
|
795
|
+
};
|
|
796
|
+
}
|
|
797
|
+
async getResponse(response, url, headers, outputType) {
|
|
798
|
+
if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images[0] && typeof response.images[0].image === "string") {
|
|
799
|
+
if (outputType === "url") {
|
|
800
|
+
return `data:image/jpeg;base64,${response.images[0].image}`;
|
|
801
|
+
}
|
|
802
|
+
return fetch(`data:image/jpeg;base64,${response.images[0].image}`).then((res) => res.blob());
|
|
803
|
+
}
|
|
804
|
+
throw new InferenceOutputError("Expected Hyperbolic text-to-image response format");
|
|
805
|
+
}
|
|
292
806
|
};
|
|
293
807
|
|
|
294
808
|
// src/providers/nebius.ts
|
|
295
809
|
var NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
|
|
296
|
-
var
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
return {
|
|
301
|
-
...params.args,
|
|
302
|
-
model: params.model
|
|
303
|
-
};
|
|
810
|
+
var NebiusConversationalTask = class extends BaseConversationalTask {
|
|
811
|
+
constructor() {
|
|
812
|
+
super("nebius", NEBIUS_API_BASE_URL);
|
|
813
|
+
}
|
|
304
814
|
};
|
|
305
|
-
var
|
|
306
|
-
|
|
815
|
+
var NebiusTextGenerationTask = class extends BaseTextGenerationTask {
|
|
816
|
+
constructor() {
|
|
817
|
+
super("nebius", NEBIUS_API_BASE_URL);
|
|
818
|
+
}
|
|
307
819
|
};
|
|
308
|
-
var
|
|
309
|
-
|
|
310
|
-
|
|
820
|
+
var NebiusTextToImageTask = class extends TaskProviderHelper {
|
|
821
|
+
constructor() {
|
|
822
|
+
super("nebius", NEBIUS_API_BASE_URL);
|
|
311
823
|
}
|
|
312
|
-
|
|
313
|
-
return
|
|
824
|
+
preparePayload(params) {
|
|
825
|
+
return {
|
|
826
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
827
|
+
...params.args.parameters,
|
|
828
|
+
response_format: "b64_json",
|
|
829
|
+
prompt: params.args.inputs,
|
|
830
|
+
model: params.model
|
|
831
|
+
};
|
|
314
832
|
}
|
|
315
|
-
|
|
316
|
-
return
|
|
833
|
+
makeRoute(params) {
|
|
834
|
+
return "v1/images/generations";
|
|
835
|
+
}
|
|
836
|
+
async getResponse(response, url, headers, outputType) {
|
|
837
|
+
if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
|
|
838
|
+
const base64Data = response.data[0].b64_json;
|
|
839
|
+
if (outputType === "url") {
|
|
840
|
+
return `data:image/jpeg;base64,${base64Data}`;
|
|
841
|
+
}
|
|
842
|
+
return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
|
|
843
|
+
}
|
|
844
|
+
throw new InferenceOutputError("Expected Nebius text-to-image response format");
|
|
317
845
|
}
|
|
318
|
-
return params.baseUrl;
|
|
319
|
-
};
|
|
320
|
-
var NEBIUS_CONFIG = {
|
|
321
|
-
makeBaseUrl: makeBaseUrl8,
|
|
322
|
-
makeBody: makeBody8,
|
|
323
|
-
makeHeaders: makeHeaders8,
|
|
324
|
-
makeUrl: makeUrl8
|
|
325
846
|
};
|
|
326
847
|
|
|
327
848
|
// src/providers/novita.ts
|
|
328
849
|
var NOVITA_API_BASE_URL = "https://api.novita.ai";
|
|
329
|
-
var
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
};
|
|
337
|
-
};
|
|
338
|
-
var makeHeaders9 = (params) => {
|
|
339
|
-
return { Authorization: `Bearer ${params.accessToken}` };
|
|
850
|
+
var NovitaTextGenerationTask = class extends BaseTextGenerationTask {
|
|
851
|
+
constructor() {
|
|
852
|
+
super("novita", NOVITA_API_BASE_URL);
|
|
853
|
+
}
|
|
854
|
+
makeRoute() {
|
|
855
|
+
return "/v3/openai/chat/completions";
|
|
856
|
+
}
|
|
340
857
|
};
|
|
341
|
-
var
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
}
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
return `${params.baseUrl}/v3/hf/${params.model}`;
|
|
858
|
+
var NovitaConversationalTask = class extends BaseConversationalTask {
|
|
859
|
+
constructor() {
|
|
860
|
+
super("novita", NOVITA_API_BASE_URL);
|
|
861
|
+
}
|
|
862
|
+
makeRoute() {
|
|
863
|
+
return "/v3/openai/chat/completions";
|
|
348
864
|
}
|
|
349
|
-
return params.baseUrl;
|
|
350
865
|
};
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
866
|
+
|
|
867
|
+
// src/providers/openai.ts
|
|
868
|
+
var OPENAI_API_BASE_URL = "https://api.openai.com";
|
|
869
|
+
var OpenAIConversationalTask = class extends BaseConversationalTask {
|
|
870
|
+
constructor() {
|
|
871
|
+
super("openai", OPENAI_API_BASE_URL, true);
|
|
872
|
+
}
|
|
356
873
|
};
|
|
357
874
|
|
|
358
875
|
// src/providers/replicate.ts
|
|
359
|
-
var
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
}
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
876
|
+
var ReplicateTask = class extends TaskProviderHelper {
|
|
877
|
+
constructor(url) {
|
|
878
|
+
super("replicate", url || "https://api.replicate.com");
|
|
879
|
+
}
|
|
880
|
+
makeRoute(params) {
|
|
881
|
+
if (params.model.includes(":")) {
|
|
882
|
+
return "v1/predictions";
|
|
883
|
+
}
|
|
884
|
+
return `v1/models/${params.model}/predictions`;
|
|
885
|
+
}
|
|
886
|
+
preparePayload(params) {
|
|
887
|
+
return {
|
|
888
|
+
input: {
|
|
889
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
890
|
+
...params.args.parameters,
|
|
891
|
+
prompt: params.args.inputs
|
|
892
|
+
},
|
|
893
|
+
version: params.model.includes(":") ? params.model.split(":")[1] : void 0
|
|
894
|
+
};
|
|
895
|
+
}
|
|
896
|
+
prepareHeaders(params, binary) {
|
|
897
|
+
const headers = { Authorization: `Bearer ${params.accessToken}`, Prefer: "wait" };
|
|
898
|
+
if (!binary) {
|
|
899
|
+
headers["Content-Type"] = "application/json";
|
|
900
|
+
}
|
|
901
|
+
return headers;
|
|
902
|
+
}
|
|
903
|
+
makeUrl(params) {
|
|
904
|
+
const baseUrl = this.makeBaseUrl(params);
|
|
905
|
+
if (params.model.includes(":")) {
|
|
906
|
+
return `${baseUrl}/v1/predictions`;
|
|
907
|
+
}
|
|
908
|
+
return `${baseUrl}/v1/models/${params.model}/predictions`;
|
|
909
|
+
}
|
|
368
910
|
};
|
|
369
|
-
var
|
|
370
|
-
|
|
911
|
+
var ReplicateTextToImageTask = class extends ReplicateTask {
|
|
912
|
+
async getResponse(res, url, headers, outputType) {
|
|
913
|
+
if (typeof res === "object" && "output" in res && Array.isArray(res.output) && res.output.length > 0 && typeof res.output[0] === "string") {
|
|
914
|
+
if (outputType === "url") {
|
|
915
|
+
return res.output[0];
|
|
916
|
+
}
|
|
917
|
+
const urlResponse = await fetch(res.output[0]);
|
|
918
|
+
return await urlResponse.blob();
|
|
919
|
+
}
|
|
920
|
+
throw new InferenceOutputError("Expected Replicate text-to-image response format");
|
|
921
|
+
}
|
|
371
922
|
};
|
|
372
|
-
var
|
|
373
|
-
|
|
374
|
-
|
|
923
|
+
var ReplicateTextToSpeechTask = class extends ReplicateTask {
|
|
924
|
+
preparePayload(params) {
|
|
925
|
+
const payload = super.preparePayload(params);
|
|
926
|
+
const input = payload["input"];
|
|
927
|
+
if (typeof input === "object" && input !== null && "prompt" in input) {
|
|
928
|
+
const inputObj = input;
|
|
929
|
+
inputObj["text"] = inputObj["prompt"];
|
|
930
|
+
delete inputObj["prompt"];
|
|
931
|
+
}
|
|
932
|
+
return payload;
|
|
933
|
+
}
|
|
934
|
+
async getResponse(response) {
|
|
935
|
+
if (response instanceof Blob) {
|
|
936
|
+
return response;
|
|
937
|
+
}
|
|
938
|
+
if (response && typeof response === "object") {
|
|
939
|
+
if ("output" in response) {
|
|
940
|
+
if (typeof response.output === "string") {
|
|
941
|
+
const urlResponse = await fetch(response.output);
|
|
942
|
+
return await urlResponse.blob();
|
|
943
|
+
} else if (Array.isArray(response.output)) {
|
|
944
|
+
const urlResponse = await fetch(response.output[0]);
|
|
945
|
+
return await urlResponse.blob();
|
|
946
|
+
}
|
|
947
|
+
}
|
|
948
|
+
}
|
|
949
|
+
throw new InferenceOutputError("Expected Blob or object with output");
|
|
375
950
|
}
|
|
376
|
-
return `${params.baseUrl}/v1/models/${params.model}/predictions`;
|
|
377
951
|
};
|
|
378
|
-
var
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
952
|
+
var ReplicateTextToVideoTask = class extends ReplicateTask {
|
|
953
|
+
async getResponse(response) {
|
|
954
|
+
if (typeof response === "object" && !!response && "output" in response && typeof response.output === "string" && isUrl(response.output)) {
|
|
955
|
+
const urlResponse = await fetch(response.output);
|
|
956
|
+
return await urlResponse.blob();
|
|
957
|
+
}
|
|
958
|
+
throw new InferenceOutputError("Expected { output: string }");
|
|
959
|
+
}
|
|
383
960
|
};
|
|
384
961
|
|
|
385
962
|
// src/providers/sambanova.ts
|
|
386
|
-
var
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
};
|
|
390
|
-
var makeBody11 = (params) => {
|
|
391
|
-
return {
|
|
392
|
-
...params.args,
|
|
393
|
-
...params.chatCompletion ? { model: params.model } : void 0
|
|
394
|
-
};
|
|
395
|
-
};
|
|
396
|
-
var makeHeaders11 = (params) => {
|
|
397
|
-
return { Authorization: `Bearer ${params.accessToken}` };
|
|
398
|
-
};
|
|
399
|
-
var makeUrl11 = (params) => {
|
|
400
|
-
if (params.chatCompletion) {
|
|
401
|
-
return `${params.baseUrl}/v1/chat/completions`;
|
|
963
|
+
var SambanovaConversationalTask = class extends BaseConversationalTask {
|
|
964
|
+
constructor() {
|
|
965
|
+
super("sambanova", "https://api.sambanova.ai");
|
|
402
966
|
}
|
|
403
|
-
return params.baseUrl;
|
|
404
|
-
};
|
|
405
|
-
var SAMBANOVA_CONFIG = {
|
|
406
|
-
makeBaseUrl: makeBaseUrl11,
|
|
407
|
-
makeBody: makeBody11,
|
|
408
|
-
makeHeaders: makeHeaders11,
|
|
409
|
-
makeUrl: makeUrl11
|
|
410
967
|
};
|
|
411
968
|
|
|
412
969
|
// src/providers/together.ts
|
|
413
970
|
var TOGETHER_API_BASE_URL = "https://api.together.xyz";
|
|
414
|
-
var
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
return {
|
|
419
|
-
...params.args,
|
|
420
|
-
model: params.model
|
|
421
|
-
};
|
|
422
|
-
};
|
|
423
|
-
var makeHeaders12 = (params) => {
|
|
424
|
-
return { Authorization: `Bearer ${params.accessToken}` };
|
|
971
|
+
var TogetherConversationalTask = class extends BaseConversationalTask {
|
|
972
|
+
constructor() {
|
|
973
|
+
super("together", TOGETHER_API_BASE_URL);
|
|
974
|
+
}
|
|
425
975
|
};
|
|
426
|
-
var
|
|
427
|
-
|
|
428
|
-
|
|
976
|
+
var TogetherTextGenerationTask = class extends BaseTextGenerationTask {
|
|
977
|
+
constructor() {
|
|
978
|
+
super("together", TOGETHER_API_BASE_URL);
|
|
429
979
|
}
|
|
430
|
-
|
|
431
|
-
return
|
|
980
|
+
preparePayload(params) {
|
|
981
|
+
return {
|
|
982
|
+
model: params.model,
|
|
983
|
+
...params.args,
|
|
984
|
+
prompt: params.args.inputs
|
|
985
|
+
};
|
|
432
986
|
}
|
|
433
|
-
|
|
434
|
-
|
|
987
|
+
async getResponse(response) {
|
|
988
|
+
if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
|
|
989
|
+
const completion = response.choices[0];
|
|
990
|
+
return {
|
|
991
|
+
generated_text: completion.text
|
|
992
|
+
};
|
|
993
|
+
}
|
|
994
|
+
throw new InferenceOutputError("Expected Together text generation response format");
|
|
435
995
|
}
|
|
436
|
-
return params.baseUrl;
|
|
437
996
|
};
|
|
438
|
-
var
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
997
|
+
var TogetherTextToImageTask = class extends TaskProviderHelper {
|
|
998
|
+
constructor() {
|
|
999
|
+
super("together", TOGETHER_API_BASE_URL);
|
|
1000
|
+
}
|
|
1001
|
+
makeRoute() {
|
|
1002
|
+
return "v1/images/generations";
|
|
1003
|
+
}
|
|
1004
|
+
preparePayload(params) {
|
|
1005
|
+
return {
|
|
1006
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
1007
|
+
...params.args.parameters,
|
|
1008
|
+
prompt: params.args.inputs,
|
|
1009
|
+
response_format: "base64",
|
|
1010
|
+
model: params.model
|
|
1011
|
+
};
|
|
1012
|
+
}
|
|
1013
|
+
async getResponse(response, outputType) {
|
|
1014
|
+
if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
|
|
1015
|
+
const base64Data = response.data[0].b64_json;
|
|
1016
|
+
if (outputType === "url") {
|
|
1017
|
+
return `data:image/jpeg;base64,${base64Data}`;
|
|
1018
|
+
}
|
|
1019
|
+
return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
|
|
1020
|
+
}
|
|
1021
|
+
throw new InferenceOutputError("Expected Together text-to-image response format");
|
|
1022
|
+
}
|
|
443
1023
|
};
|
|
444
1024
|
|
|
445
|
-
// src/
|
|
446
|
-
var
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
}
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
1025
|
+
// src/lib/getProviderHelper.ts
|
|
1026
|
+
var PROVIDERS = {
|
|
1027
|
+
"black-forest-labs": {
|
|
1028
|
+
"text-to-image": new BlackForestLabsTextToImageTask()
|
|
1029
|
+
},
|
|
1030
|
+
cerebras: {
|
|
1031
|
+
conversational: new CerebrasConversationalTask()
|
|
1032
|
+
},
|
|
1033
|
+
cohere: {
|
|
1034
|
+
conversational: new CohereConversationalTask()
|
|
1035
|
+
},
|
|
1036
|
+
"fal-ai": {
|
|
1037
|
+
"text-to-image": new FalAITextToImageTask(),
|
|
1038
|
+
"text-to-speech": new FalAITextToSpeechTask(),
|
|
1039
|
+
"text-to-video": new FalAITextToVideoTask(),
|
|
1040
|
+
"automatic-speech-recognition": new FalAIAutomaticSpeechRecognitionTask()
|
|
1041
|
+
},
|
|
1042
|
+
"hf-inference": {
|
|
1043
|
+
"text-to-image": new HFInferenceTextToImageTask(),
|
|
1044
|
+
conversational: new HFInferenceConversationalTask(),
|
|
1045
|
+
"text-generation": new HFInferenceTextGenerationTask(),
|
|
1046
|
+
"text-classification": new HFInferenceTextClassificationTask(),
|
|
1047
|
+
"question-answering": new HFInferenceQuestionAnsweringTask(),
|
|
1048
|
+
"audio-classification": new HFInferenceAudioClassificationTask(),
|
|
1049
|
+
"automatic-speech-recognition": new HFInferenceAutomaticSpeechRecognitionTask(),
|
|
1050
|
+
"fill-mask": new HFInferenceFillMaskTask(),
|
|
1051
|
+
"feature-extraction": new HFInferenceFeatureExtractionTask(),
|
|
1052
|
+
"image-classification": new HFInferenceImageClassificationTask(),
|
|
1053
|
+
"image-segmentation": new HFInferenceImageSegmentationTask(),
|
|
1054
|
+
"document-question-answering": new HFInferenceDocumentQuestionAnsweringTask(),
|
|
1055
|
+
"image-to-text": new HFInferenceImageToTextTask(),
|
|
1056
|
+
"object-detection": new HFInferenceObjectDetectionTask(),
|
|
1057
|
+
"audio-to-audio": new HFInferenceAudioToAudioTask(),
|
|
1058
|
+
"zero-shot-image-classification": new HFInferenceZeroShotImageClassificationTask(),
|
|
1059
|
+
"zero-shot-classification": new HFInferenceZeroShotClassificationTask(),
|
|
1060
|
+
"image-to-image": new HFInferenceImageToImageTask(),
|
|
1061
|
+
"sentence-similarity": new HFInferenceSentenceSimilarityTask(),
|
|
1062
|
+
"table-question-answering": new HFInferenceTableQuestionAnsweringTask(),
|
|
1063
|
+
"tabular-classification": new HFInferenceTabularClassificationTask(),
|
|
1064
|
+
"text-to-speech": new HFInferenceTextToSpeechTask(),
|
|
1065
|
+
"token-classification": new HFInferenceTokenClassificationTask(),
|
|
1066
|
+
translation: new HFInferenceTranslationTask(),
|
|
1067
|
+
summarization: new HFInferenceSummarizationTask(),
|
|
1068
|
+
"visual-question-answering": new HFInferenceVisualQuestionAnsweringTask(),
|
|
1069
|
+
"tabular-regression": new HFInferenceTabularRegressionTask(),
|
|
1070
|
+
"text-to-audio": new HFInferenceTextToAudioTask()
|
|
1071
|
+
},
|
|
1072
|
+
"fireworks-ai": {
|
|
1073
|
+
conversational: new FireworksConversationalTask()
|
|
1074
|
+
},
|
|
1075
|
+
hyperbolic: {
|
|
1076
|
+
"text-to-image": new HyperbolicTextToImageTask(),
|
|
1077
|
+
conversational: new HyperbolicConversationalTask(),
|
|
1078
|
+
"text-generation": new HyperbolicTextGenerationTask()
|
|
1079
|
+
},
|
|
1080
|
+
nebius: {
|
|
1081
|
+
"text-to-image": new NebiusTextToImageTask(),
|
|
1082
|
+
conversational: new NebiusConversationalTask(),
|
|
1083
|
+
"text-generation": new NebiusTextGenerationTask()
|
|
1084
|
+
},
|
|
1085
|
+
novita: {
|
|
1086
|
+
conversational: new NovitaConversationalTask(),
|
|
1087
|
+
"text-generation": new NovitaTextGenerationTask()
|
|
1088
|
+
},
|
|
1089
|
+
openai: {
|
|
1090
|
+
conversational: new OpenAIConversationalTask()
|
|
1091
|
+
},
|
|
1092
|
+
replicate: {
|
|
1093
|
+
"text-to-image": new ReplicateTextToImageTask(),
|
|
1094
|
+
"text-to-speech": new ReplicateTextToSpeechTask(),
|
|
1095
|
+
"text-to-video": new ReplicateTextToVideoTask()
|
|
1096
|
+
},
|
|
1097
|
+
sambanova: {
|
|
1098
|
+
conversational: new SambanovaConversationalTask()
|
|
1099
|
+
},
|
|
1100
|
+
together: {
|
|
1101
|
+
"text-to-image": new TogetherTextToImageTask(),
|
|
1102
|
+
conversational: new TogetherConversationalTask(),
|
|
1103
|
+
"text-generation": new TogetherTextGenerationTask()
|
|
453
1104
|
}
|
|
454
|
-
return {
|
|
455
|
-
...params.args,
|
|
456
|
-
model: params.model
|
|
457
|
-
};
|
|
458
1105
|
};
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
throw new Error("OpenAI only supports chat completions.");
|
|
1106
|
+
function getProviderHelper(provider, task) {
|
|
1107
|
+
if (provider === "hf-inference") {
|
|
1108
|
+
if (!task) {
|
|
1109
|
+
return new HFInferenceTask();
|
|
1110
|
+
}
|
|
465
1111
|
}
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
1112
|
+
if (!task) {
|
|
1113
|
+
throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'");
|
|
1114
|
+
}
|
|
1115
|
+
if (!(provider in PROVIDERS)) {
|
|
1116
|
+
throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`);
|
|
1117
|
+
}
|
|
1118
|
+
const providerTasks = PROVIDERS[provider];
|
|
1119
|
+
if (!providerTasks || !(task in providerTasks)) {
|
|
1120
|
+
throw new Error(
|
|
1121
|
+
`Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(providerTasks ?? {})}`
|
|
1122
|
+
);
|
|
1123
|
+
}
|
|
1124
|
+
return providerTasks[task];
|
|
1125
|
+
}
|
|
479
1126
|
|
|
480
1127
|
// src/providers/consts.ts
|
|
481
1128
|
var HARDCODED_MODEL_ID_MAPPING = {
|
|
@@ -545,28 +1192,11 @@ async function getProviderModelId(params, args, options = {}) {
|
|
|
545
1192
|
}
|
|
546
1193
|
|
|
547
1194
|
// src/lib/makeRequestOptions.ts
|
|
548
|
-
var HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
|
|
549
1195
|
var tasks = null;
|
|
550
|
-
var providerConfigs = {
|
|
551
|
-
"black-forest-labs": BLACK_FOREST_LABS_CONFIG,
|
|
552
|
-
cerebras: CEREBRAS_CONFIG,
|
|
553
|
-
cohere: COHERE_CONFIG,
|
|
554
|
-
"fal-ai": FAL_AI_CONFIG,
|
|
555
|
-
"fireworks-ai": FIREWORKS_AI_CONFIG,
|
|
556
|
-
"hf-inference": HF_INFERENCE_CONFIG,
|
|
557
|
-
hyperbolic: HYPERBOLIC_CONFIG,
|
|
558
|
-
openai: OPENAI_CONFIG,
|
|
559
|
-
nebius: NEBIUS_CONFIG,
|
|
560
|
-
novita: NOVITA_CONFIG,
|
|
561
|
-
replicate: REPLICATE_CONFIG,
|
|
562
|
-
sambanova: SAMBANOVA_CONFIG,
|
|
563
|
-
together: TOGETHER_CONFIG
|
|
564
|
-
};
|
|
565
1196
|
async function makeRequestOptions(args, options) {
|
|
566
1197
|
const { provider: maybeProvider, model: maybeModel } = args;
|
|
567
1198
|
const provider = maybeProvider ?? "hf-inference";
|
|
568
|
-
const
|
|
569
|
-
const { task, chatCompletion: chatCompletion2 } = options ?? {};
|
|
1199
|
+
const { task } = options ?? {};
|
|
570
1200
|
if (args.endpointUrl && provider !== "hf-inference") {
|
|
571
1201
|
throw new Error(`Cannot use endpointUrl with a third-party provider.`);
|
|
572
1202
|
}
|
|
@@ -576,19 +1206,16 @@ async function makeRequestOptions(args, options) {
|
|
|
576
1206
|
if (!maybeModel && !task) {
|
|
577
1207
|
throw new Error("No model provided, and no task has been specified.");
|
|
578
1208
|
}
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
if (providerConfig.clientSideRoutingOnly && !maybeModel) {
|
|
1209
|
+
const hfModel = maybeModel ?? await loadDefaultModel(task);
|
|
1210
|
+
const providerHelper = getProviderHelper(provider, task);
|
|
1211
|
+
if (providerHelper.clientSideRoutingOnly && !maybeModel) {
|
|
583
1212
|
throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
|
|
584
1213
|
}
|
|
585
|
-
const
|
|
586
|
-
const resolvedModel = providerConfig.clientSideRoutingOnly ? (
|
|
1214
|
+
const resolvedModel = providerHelper.clientSideRoutingOnly ? (
|
|
587
1215
|
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
|
588
1216
|
removeProviderPrefix(maybeModel, provider)
|
|
589
1217
|
) : await getProviderModelId({ model: hfModel, provider }, args, {
|
|
590
1218
|
task,
|
|
591
|
-
chatCompletion: chatCompletion2,
|
|
592
1219
|
fetch: options?.fetch
|
|
593
1220
|
});
|
|
594
1221
|
return makeRequestOptionsFromResolvedModel(resolvedModel, args, options);
|
|
@@ -596,10 +1223,10 @@ async function makeRequestOptions(args, options) {
|
|
|
596
1223
|
function makeRequestOptionsFromResolvedModel(resolvedModel, args, options) {
|
|
597
1224
|
const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
|
|
598
1225
|
const provider = maybeProvider ?? "hf-inference";
|
|
599
|
-
const
|
|
600
|
-
const
|
|
1226
|
+
const { includeCredentials, task, signal, billTo } = options ?? {};
|
|
1227
|
+
const providerHelper = getProviderHelper(provider, task);
|
|
601
1228
|
const authMethod = (() => {
|
|
602
|
-
if (
|
|
1229
|
+
if (providerHelper.clientSideRoutingOnly) {
|
|
603
1230
|
if (accessToken && accessToken.startsWith("hf_")) {
|
|
604
1231
|
throw new Error(`Provider ${provider} is closed-source and does not support HF tokens.`);
|
|
605
1232
|
}
|
|
@@ -613,32 +1240,30 @@ function makeRequestOptionsFromResolvedModel(resolvedModel, args, options) {
|
|
|
613
1240
|
}
|
|
614
1241
|
return "none";
|
|
615
1242
|
})();
|
|
616
|
-
const
|
|
1243
|
+
const modelId = endpointUrl ?? resolvedModel;
|
|
1244
|
+
const url = providerHelper.makeUrl({
|
|
617
1245
|
authMethod,
|
|
618
|
-
|
|
619
|
-
model: resolvedModel,
|
|
620
|
-
chatCompletion: chatCompletion2,
|
|
1246
|
+
model: modelId,
|
|
621
1247
|
task
|
|
622
1248
|
});
|
|
623
|
-
const
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
1249
|
+
const headers = providerHelper.prepareHeaders(
|
|
1250
|
+
{
|
|
1251
|
+
accessToken,
|
|
1252
|
+
authMethod
|
|
1253
|
+
},
|
|
1254
|
+
"data" in args && !!args.data
|
|
1255
|
+
);
|
|
1256
|
+
if (billTo) {
|
|
1257
|
+
headers[HF_HEADER_X_BILL_TO] = billTo;
|
|
630
1258
|
}
|
|
631
1259
|
const ownUserAgent = `${name}/${version}`;
|
|
632
1260
|
const userAgent = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : void 0].filter((x) => x !== void 0).join(" ");
|
|
633
1261
|
headers["User-Agent"] = userAgent;
|
|
634
|
-
const body =
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
chatCompletion: chatCompletion2
|
|
640
|
-
})
|
|
641
|
-
);
|
|
1262
|
+
const body = providerHelper.makeBody({
|
|
1263
|
+
args: remainingArgs,
|
|
1264
|
+
model: resolvedModel,
|
|
1265
|
+
task
|
|
1266
|
+
});
|
|
642
1267
|
let credentials;
|
|
643
1268
|
if (typeof includeCredentials === "string") {
|
|
644
1269
|
credentials = includeCredentials;
|
|
@@ -678,37 +1303,6 @@ function removeProviderPrefix(model, provider) {
|
|
|
678
1303
|
return model.slice(provider.length + 1);
|
|
679
1304
|
}
|
|
680
1305
|
|
|
681
|
-
// src/tasks/custom/request.ts
|
|
682
|
-
async function request(args, options) {
|
|
683
|
-
const { url, info } = await makeRequestOptions(args, options);
|
|
684
|
-
const response = await (options?.fetch ?? fetch)(url, info);
|
|
685
|
-
if (options?.retry_on_error !== false && response.status === 503) {
|
|
686
|
-
return request(args, options);
|
|
687
|
-
}
|
|
688
|
-
if (!response.ok) {
|
|
689
|
-
const contentType = response.headers.get("Content-Type");
|
|
690
|
-
if (["application/json", "application/problem+json"].some((ct) => contentType?.startsWith(ct))) {
|
|
691
|
-
const output = await response.json();
|
|
692
|
-
if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
|
|
693
|
-
throw new Error(
|
|
694
|
-
`Server ${args.model} does not seem to support chat completion. Error: ${JSON.stringify(output.error)}`
|
|
695
|
-
);
|
|
696
|
-
}
|
|
697
|
-
if (output.error || output.detail) {
|
|
698
|
-
throw new Error(JSON.stringify(output.error ?? output.detail));
|
|
699
|
-
} else {
|
|
700
|
-
throw new Error(output);
|
|
701
|
-
}
|
|
702
|
-
}
|
|
703
|
-
const message = contentType?.startsWith("text/plain;") ? await response.text() : void 0;
|
|
704
|
-
throw new Error(message ?? "An error occurred while fetching the blob");
|
|
705
|
-
}
|
|
706
|
-
if (response.headers.get("Content-Type")?.startsWith("application/json")) {
|
|
707
|
-
return await response.json();
|
|
708
|
-
}
|
|
709
|
-
return await response.blob();
|
|
710
|
-
}
|
|
711
|
-
|
|
712
1306
|
// src/vendor/fetch-event-source/parse.ts
|
|
713
1307
|
function getLines(onLine) {
|
|
714
1308
|
let buffer;
|
|
@@ -808,12 +1402,44 @@ function newMessage() {
|
|
|
808
1402
|
};
|
|
809
1403
|
}
|
|
810
1404
|
|
|
811
|
-
// src/
|
|
812
|
-
async function
|
|
1405
|
+
// src/utils/request.ts
|
|
1406
|
+
async function innerRequest(args, options) {
|
|
1407
|
+
const { url, info } = await makeRequestOptions(args, options);
|
|
1408
|
+
const response = await (options?.fetch ?? fetch)(url, info);
|
|
1409
|
+
const requestContext = { url, info };
|
|
1410
|
+
if (options?.retry_on_error !== false && response.status === 503) {
|
|
1411
|
+
return innerRequest(args, options);
|
|
1412
|
+
}
|
|
1413
|
+
if (!response.ok) {
|
|
1414
|
+
const contentType = response.headers.get("Content-Type");
|
|
1415
|
+
if (["application/json", "application/problem+json"].some((ct) => contentType?.startsWith(ct))) {
|
|
1416
|
+
const output = await response.json();
|
|
1417
|
+
if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
|
|
1418
|
+
throw new Error(
|
|
1419
|
+
`Server ${args.model} does not seem to support chat completion. Error: ${JSON.stringify(output.error)}`
|
|
1420
|
+
);
|
|
1421
|
+
}
|
|
1422
|
+
if (output.error || output.detail) {
|
|
1423
|
+
throw new Error(JSON.stringify(output.error ?? output.detail));
|
|
1424
|
+
} else {
|
|
1425
|
+
throw new Error(output);
|
|
1426
|
+
}
|
|
1427
|
+
}
|
|
1428
|
+
const message = contentType?.startsWith("text/plain;") ? await response.text() : void 0;
|
|
1429
|
+
throw new Error(message ?? "An error occurred while fetching the blob");
|
|
1430
|
+
}
|
|
1431
|
+
if (response.headers.get("Content-Type")?.startsWith("application/json")) {
|
|
1432
|
+
const data = await response.json();
|
|
1433
|
+
return { data, requestContext };
|
|
1434
|
+
}
|
|
1435
|
+
const blob = await response.blob();
|
|
1436
|
+
return { data: blob, requestContext };
|
|
1437
|
+
}
|
|
1438
|
+
async function* innerStreamingRequest(args, options) {
|
|
813
1439
|
const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
|
|
814
1440
|
const response = await (options?.fetch ?? fetch)(url, info);
|
|
815
1441
|
if (options?.retry_on_error !== false && response.status === 503) {
|
|
816
|
-
return yield*
|
|
1442
|
+
return yield* innerStreamingRequest(args, options);
|
|
817
1443
|
}
|
|
818
1444
|
if (!response.ok) {
|
|
819
1445
|
if (response.headers.get("Content-Type")?.startsWith("application/json")) {
|
|
@@ -827,6 +1453,9 @@ async function* streamingRequest(args, options) {
|
|
|
827
1453
|
if (output.error && "message" in output.error && typeof output.error.message === "string") {
|
|
828
1454
|
throw new Error(output.error.message);
|
|
829
1455
|
}
|
|
1456
|
+
if (typeof output.message === "string") {
|
|
1457
|
+
throw new Error(output.message);
|
|
1458
|
+
}
|
|
830
1459
|
}
|
|
831
1460
|
throw new Error(`Server response contains error: ${response.status}`);
|
|
832
1461
|
}
|
|
@@ -879,28 +1508,21 @@ async function* streamingRequest(args, options) {
|
|
|
879
1508
|
}
|
|
880
1509
|
}
|
|
881
1510
|
|
|
882
|
-
// src/
|
|
883
|
-
function
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
...props.map((prop) => {
|
|
887
|
-
if (o[prop] !== void 0) {
|
|
888
|
-
return { [prop]: o[prop] };
|
|
889
|
-
}
|
|
890
|
-
})
|
|
1511
|
+
// src/tasks/custom/request.ts
|
|
1512
|
+
async function request(args, options) {
|
|
1513
|
+
console.warn(
|
|
1514
|
+
"The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
|
|
891
1515
|
);
|
|
1516
|
+
const result = await innerRequest(args, options);
|
|
1517
|
+
return result.data;
|
|
892
1518
|
}
|
|
893
1519
|
|
|
894
|
-
// src/
|
|
895
|
-
function
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
function omit(o, props) {
|
|
901
|
-
const propsArr = Array.isArray(props) ? props : [props];
|
|
902
|
-
const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
|
|
903
|
-
return pick(o, letsKeep);
|
|
1520
|
+
// src/tasks/custom/streamingRequest.ts
|
|
1521
|
+
async function* streamingRequest(args, options) {
|
|
1522
|
+
console.warn(
|
|
1523
|
+
"The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
|
|
1524
|
+
);
|
|
1525
|
+
yield* innerStreamingRequest(args, options);
|
|
904
1526
|
}
|
|
905
1527
|
|
|
906
1528
|
// src/tasks/audio/utils.ts
|
|
@@ -913,16 +1535,24 @@ function preparePayload(args) {
|
|
|
913
1535
|
|
|
914
1536
|
// src/tasks/audio/audioClassification.ts
|
|
915
1537
|
async function audioClassification(args, options) {
|
|
1538
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
|
|
916
1539
|
const payload = preparePayload(args);
|
|
917
|
-
const res = await
|
|
1540
|
+
const { data: res } = await innerRequest(payload, {
|
|
918
1541
|
...options,
|
|
919
1542
|
task: "audio-classification"
|
|
920
1543
|
});
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
1544
|
+
return providerHelper.getResponse(res);
|
|
1545
|
+
}
|
|
1546
|
+
|
|
1547
|
+
// src/tasks/audio/audioToAudio.ts
|
|
1548
|
+
async function audioToAudio(args, options) {
|
|
1549
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
|
|
1550
|
+
const payload = preparePayload(args);
|
|
1551
|
+
const { data: res } = await innerRequest(payload, {
|
|
1552
|
+
...options,
|
|
1553
|
+
task: "audio-to-audio"
|
|
1554
|
+
});
|
|
1555
|
+
return providerHelper.getResponse(res);
|
|
926
1556
|
}
|
|
927
1557
|
|
|
928
1558
|
// src/utils/base64FromBytes.ts
|
|
@@ -940,8 +1570,9 @@ function base64FromBytes(arr) {
|
|
|
940
1570
|
|
|
941
1571
|
// src/tasks/audio/automaticSpeechRecognition.ts
|
|
942
1572
|
async function automaticSpeechRecognition(args, options) {
|
|
1573
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
|
|
943
1574
|
const payload = await buildPayload(args);
|
|
944
|
-
const res = await
|
|
1575
|
+
const { data: res } = await innerRequest(payload, {
|
|
945
1576
|
...options,
|
|
946
1577
|
task: "automatic-speech-recognition"
|
|
947
1578
|
});
|
|
@@ -949,9 +1580,8 @@ async function automaticSpeechRecognition(args, options) {
|
|
|
949
1580
|
if (!isValidOutput) {
|
|
950
1581
|
throw new InferenceOutputError("Expected {text: string}");
|
|
951
1582
|
}
|
|
952
|
-
return res;
|
|
1583
|
+
return providerHelper.getResponse(res);
|
|
953
1584
|
}
|
|
954
|
-
var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
|
|
955
1585
|
async function buildPayload(args) {
|
|
956
1586
|
if (args.provider === "fal-ai") {
|
|
957
1587
|
const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : void 0;
|
|
@@ -976,57 +1606,17 @@ async function buildPayload(args) {
|
|
|
976
1606
|
} else {
|
|
977
1607
|
return preparePayload(args);
|
|
978
1608
|
}
|
|
979
|
-
}
|
|
980
|
-
|
|
981
|
-
// src/tasks/audio/textToSpeech.ts
|
|
982
|
-
async function textToSpeech(args, options) {
|
|
983
|
-
const
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
text: args.inputs
|
|
987
|
-
} : args;
|
|
988
|
-
const res = await request(payload, {
|
|
989
|
-
...options,
|
|
990
|
-
task: "text-to-speech"
|
|
991
|
-
});
|
|
992
|
-
if (res instanceof Blob) {
|
|
993
|
-
return res;
|
|
994
|
-
}
|
|
995
|
-
if (res && typeof res === "object") {
|
|
996
|
-
if ("output" in res) {
|
|
997
|
-
if (typeof res.output === "string") {
|
|
998
|
-
const urlResponse = await fetch(res.output);
|
|
999
|
-
const blob = await urlResponse.blob();
|
|
1000
|
-
return blob;
|
|
1001
|
-
} else if (Array.isArray(res.output)) {
|
|
1002
|
-
const urlResponse = await fetch(res.output[0]);
|
|
1003
|
-
const blob = await urlResponse.blob();
|
|
1004
|
-
return blob;
|
|
1005
|
-
}
|
|
1006
|
-
}
|
|
1007
|
-
}
|
|
1008
|
-
throw new InferenceOutputError("Expected Blob or object with output");
|
|
1009
|
-
}
|
|
1010
|
-
|
|
1011
|
-
// src/tasks/audio/audioToAudio.ts
|
|
1012
|
-
async function audioToAudio(args, options) {
|
|
1013
|
-
const payload = preparePayload(args);
|
|
1014
|
-
const res = await request(payload, {
|
|
1609
|
+
}
|
|
1610
|
+
|
|
1611
|
+
// src/tasks/audio/textToSpeech.ts
|
|
1612
|
+
async function textToSpeech(args, options) {
|
|
1613
|
+
const provider = args.provider ?? "hf-inference";
|
|
1614
|
+
const providerHelper = getProviderHelper(provider, "text-to-speech");
|
|
1615
|
+
const { data: res } = await innerRequest(args, {
|
|
1015
1616
|
...options,
|
|
1016
|
-
task: "
|
|
1617
|
+
task: "text-to-speech"
|
|
1017
1618
|
});
|
|
1018
|
-
return
|
|
1019
|
-
}
|
|
1020
|
-
function validateOutput(output) {
|
|
1021
|
-
if (!Array.isArray(output)) {
|
|
1022
|
-
throw new InferenceOutputError("Expected Array");
|
|
1023
|
-
}
|
|
1024
|
-
if (!output.every((elem) => {
|
|
1025
|
-
return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
|
|
1026
|
-
})) {
|
|
1027
|
-
throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
|
|
1028
|
-
}
|
|
1029
|
-
return output;
|
|
1619
|
+
return providerHelper.getResponse(res);
|
|
1030
1620
|
}
|
|
1031
1621
|
|
|
1032
1622
|
// src/tasks/cv/utils.ts
|
|
@@ -1036,183 +1626,95 @@ function preparePayload2(args) {
|
|
|
1036
1626
|
|
|
1037
1627
|
// src/tasks/cv/imageClassification.ts
|
|
1038
1628
|
async function imageClassification(args, options) {
|
|
1629
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
|
|
1039
1630
|
const payload = preparePayload2(args);
|
|
1040
|
-
const res = await
|
|
1631
|
+
const { data: res } = await innerRequest(payload, {
|
|
1041
1632
|
...options,
|
|
1042
1633
|
task: "image-classification"
|
|
1043
1634
|
});
|
|
1044
|
-
|
|
1045
|
-
if (!isValidOutput) {
|
|
1046
|
-
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
1047
|
-
}
|
|
1048
|
-
return res;
|
|
1635
|
+
return providerHelper.getResponse(res);
|
|
1049
1636
|
}
|
|
1050
1637
|
|
|
1051
1638
|
// src/tasks/cv/imageSegmentation.ts
|
|
1052
1639
|
async function imageSegmentation(args, options) {
|
|
1640
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
|
|
1053
1641
|
const payload = preparePayload2(args);
|
|
1054
|
-
const res = await
|
|
1642
|
+
const { data: res } = await innerRequest(payload, {
|
|
1055
1643
|
...options,
|
|
1056
1644
|
task: "image-segmentation"
|
|
1057
1645
|
});
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1646
|
+
return providerHelper.getResponse(res);
|
|
1647
|
+
}
|
|
1648
|
+
|
|
1649
|
+
// src/tasks/cv/imageToImage.ts
|
|
1650
|
+
async function imageToImage(args, options) {
|
|
1651
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
|
|
1652
|
+
let reqArgs;
|
|
1653
|
+
if (!args.parameters) {
|
|
1654
|
+
reqArgs = {
|
|
1655
|
+
accessToken: args.accessToken,
|
|
1656
|
+
model: args.model,
|
|
1657
|
+
data: args.inputs
|
|
1658
|
+
};
|
|
1659
|
+
} else {
|
|
1660
|
+
reqArgs = {
|
|
1661
|
+
...args,
|
|
1662
|
+
inputs: base64FromBytes(
|
|
1663
|
+
new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await args.inputs.arrayBuffer())
|
|
1664
|
+
)
|
|
1665
|
+
};
|
|
1061
1666
|
}
|
|
1062
|
-
|
|
1667
|
+
const { data: res } = await innerRequest(reqArgs, {
|
|
1668
|
+
...options,
|
|
1669
|
+
task: "image-to-image"
|
|
1670
|
+
});
|
|
1671
|
+
return providerHelper.getResponse(res);
|
|
1063
1672
|
}
|
|
1064
1673
|
|
|
1065
1674
|
// src/tasks/cv/imageToText.ts
|
|
1066
1675
|
async function imageToText(args, options) {
|
|
1676
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
|
|
1067
1677
|
const payload = preparePayload2(args);
|
|
1068
|
-
const res =
|
|
1678
|
+
const { data: res } = await innerRequest(payload, {
|
|
1069
1679
|
...options,
|
|
1070
1680
|
task: "image-to-text"
|
|
1071
|
-
})
|
|
1072
|
-
|
|
1073
|
-
throw new InferenceOutputError("Expected {generated_text: string}");
|
|
1074
|
-
}
|
|
1075
|
-
return res;
|
|
1681
|
+
});
|
|
1682
|
+
return providerHelper.getResponse(res[0]);
|
|
1076
1683
|
}
|
|
1077
1684
|
|
|
1078
1685
|
// src/tasks/cv/objectDetection.ts
|
|
1079
1686
|
async function objectDetection(args, options) {
|
|
1687
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
|
|
1080
1688
|
const payload = preparePayload2(args);
|
|
1081
|
-
const res = await
|
|
1689
|
+
const { data: res } = await innerRequest(payload, {
|
|
1082
1690
|
...options,
|
|
1083
1691
|
task: "object-detection"
|
|
1084
1692
|
});
|
|
1085
|
-
|
|
1086
|
-
(x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
|
|
1087
|
-
);
|
|
1088
|
-
if (!isValidOutput) {
|
|
1089
|
-
throw new InferenceOutputError(
|
|
1090
|
-
"Expected Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>"
|
|
1091
|
-
);
|
|
1092
|
-
}
|
|
1093
|
-
return res;
|
|
1693
|
+
return providerHelper.getResponse(res);
|
|
1094
1694
|
}
|
|
1095
1695
|
|
|
1096
1696
|
// src/tasks/cv/textToImage.ts
|
|
1097
|
-
function getResponseFormatArg(provider) {
|
|
1098
|
-
switch (provider) {
|
|
1099
|
-
case "fal-ai":
|
|
1100
|
-
return { sync_mode: true };
|
|
1101
|
-
case "nebius":
|
|
1102
|
-
return { response_format: "b64_json" };
|
|
1103
|
-
case "replicate":
|
|
1104
|
-
return void 0;
|
|
1105
|
-
case "together":
|
|
1106
|
-
return { response_format: "base64" };
|
|
1107
|
-
default:
|
|
1108
|
-
return void 0;
|
|
1109
|
-
}
|
|
1110
|
-
}
|
|
1111
1697
|
async function textToImage(args, options) {
|
|
1112
|
-
const
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
...getResponseFormatArg(args.provider),
|
|
1116
|
-
prompt: args.inputs
|
|
1117
|
-
};
|
|
1118
|
-
const res = await request(payload, {
|
|
1698
|
+
const provider = args.provider ?? "hf-inference";
|
|
1699
|
+
const providerHelper = getProviderHelper(provider, "text-to-image");
|
|
1700
|
+
const { data: res } = await innerRequest(args, {
|
|
1119
1701
|
...options,
|
|
1120
1702
|
task: "text-to-image"
|
|
1121
1703
|
});
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
return await pollBflResponse(res.polling_url, options?.outputType);
|
|
1125
|
-
}
|
|
1126
|
-
if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
|
|
1127
|
-
if (options?.outputType === "url") {
|
|
1128
|
-
return res.images[0].url;
|
|
1129
|
-
} else {
|
|
1130
|
-
const image = await fetch(res.images[0].url);
|
|
1131
|
-
return await image.blob();
|
|
1132
|
-
}
|
|
1133
|
-
}
|
|
1134
|
-
if (args.provider === "hyperbolic" && "images" in res && Array.isArray(res.images) && res.images[0] && typeof res.images[0].image === "string") {
|
|
1135
|
-
if (options?.outputType === "url") {
|
|
1136
|
-
return `data:image/jpeg;base64,${res.images[0].image}`;
|
|
1137
|
-
}
|
|
1138
|
-
const base64Response = await fetch(`data:image/jpeg;base64,${res.images[0].image}`);
|
|
1139
|
-
return await base64Response.blob();
|
|
1140
|
-
}
|
|
1141
|
-
if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
|
|
1142
|
-
const base64Data = res.data[0].b64_json;
|
|
1143
|
-
if (options?.outputType === "url") {
|
|
1144
|
-
return `data:image/jpeg;base64,${base64Data}`;
|
|
1145
|
-
}
|
|
1146
|
-
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
|
|
1147
|
-
return await base64Response.blob();
|
|
1148
|
-
}
|
|
1149
|
-
if ("output" in res && Array.isArray(res.output)) {
|
|
1150
|
-
if (options?.outputType === "url") {
|
|
1151
|
-
return res.output[0];
|
|
1152
|
-
}
|
|
1153
|
-
const urlResponse = await fetch(res.output[0]);
|
|
1154
|
-
const blob = await urlResponse.blob();
|
|
1155
|
-
return blob;
|
|
1156
|
-
}
|
|
1157
|
-
}
|
|
1158
|
-
const isValidOutput = res && res instanceof Blob;
|
|
1159
|
-
if (!isValidOutput) {
|
|
1160
|
-
throw new InferenceOutputError("Expected Blob");
|
|
1161
|
-
}
|
|
1162
|
-
if (options?.outputType === "url") {
|
|
1163
|
-
const b64 = await res.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
|
|
1164
|
-
return `data:image/jpeg;base64,${b64}`;
|
|
1165
|
-
}
|
|
1166
|
-
return res;
|
|
1167
|
-
}
|
|
1168
|
-
async function pollBflResponse(url, outputType) {
|
|
1169
|
-
const urlObj = new URL(url);
|
|
1170
|
-
for (let step = 0; step < 5; step++) {
|
|
1171
|
-
await delay(1e3);
|
|
1172
|
-
console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
|
|
1173
|
-
urlObj.searchParams.set("attempt", step.toString(10));
|
|
1174
|
-
const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
|
|
1175
|
-
if (!resp.ok) {
|
|
1176
|
-
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
|
|
1177
|
-
}
|
|
1178
|
-
const payload = await resp.json();
|
|
1179
|
-
if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
|
|
1180
|
-
if (outputType === "url") {
|
|
1181
|
-
return payload.result.sample;
|
|
1182
|
-
}
|
|
1183
|
-
const image = await fetch(payload.result.sample);
|
|
1184
|
-
return await image.blob();
|
|
1185
|
-
}
|
|
1186
|
-
}
|
|
1187
|
-
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
|
|
1704
|
+
const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-image" });
|
|
1705
|
+
return providerHelper.getResponse(res, url, info.headers, options?.outputType);
|
|
1188
1706
|
}
|
|
1189
1707
|
|
|
1190
|
-
// src/tasks/cv/
|
|
1191
|
-
async function
|
|
1192
|
-
|
|
1193
|
-
|
|
1194
|
-
|
|
1195
|
-
accessToken: args.accessToken,
|
|
1196
|
-
model: args.model,
|
|
1197
|
-
data: args.inputs
|
|
1198
|
-
};
|
|
1199
|
-
} else {
|
|
1200
|
-
reqArgs = {
|
|
1201
|
-
...args,
|
|
1202
|
-
inputs: base64FromBytes(
|
|
1203
|
-
new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await args.inputs.arrayBuffer())
|
|
1204
|
-
)
|
|
1205
|
-
};
|
|
1206
|
-
}
|
|
1207
|
-
const res = await request(reqArgs, {
|
|
1708
|
+
// src/tasks/cv/textToVideo.ts
|
|
1709
|
+
async function textToVideo(args, options) {
|
|
1710
|
+
const provider = args.provider ?? "hf-inference";
|
|
1711
|
+
const providerHelper = getProviderHelper(provider, "text-to-video");
|
|
1712
|
+
const { data: response } = await innerRequest(args, {
|
|
1208
1713
|
...options,
|
|
1209
|
-
task: "
|
|
1714
|
+
task: "text-to-video"
|
|
1210
1715
|
});
|
|
1211
|
-
const
|
|
1212
|
-
|
|
1213
|
-
throw new InferenceOutputError("Expected Blob");
|
|
1214
|
-
}
|
|
1215
|
-
return res;
|
|
1716
|
+
const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-video" });
|
|
1717
|
+
return providerHelper.getResponse(response, url, info.headers);
|
|
1216
1718
|
}
|
|
1217
1719
|
|
|
1218
1720
|
// src/tasks/cv/zeroShotImageClassification.ts
|
|
@@ -1238,235 +1740,117 @@ async function preparePayload3(args) {
|
|
|
1238
1740
|
}
|
|
1239
1741
|
}
|
|
1240
1742
|
async function zeroShotImageClassification(args, options) {
|
|
1743
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
|
|
1241
1744
|
const payload = await preparePayload3(args);
|
|
1242
|
-
const res = await
|
|
1745
|
+
const { data: res } = await innerRequest(payload, {
|
|
1243
1746
|
...options,
|
|
1244
1747
|
task: "zero-shot-image-classification"
|
|
1245
1748
|
});
|
|
1246
|
-
|
|
1247
|
-
if (!isValidOutput) {
|
|
1248
|
-
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
1249
|
-
}
|
|
1250
|
-
return res;
|
|
1749
|
+
return providerHelper.getResponse(res);
|
|
1251
1750
|
}
|
|
1252
1751
|
|
|
1253
|
-
// src/tasks/
|
|
1254
|
-
|
|
1255
|
-
|
|
1256
|
-
|
|
1257
|
-
throw new Error(
|
|
1258
|
-
`textToVideo inference is only supported for the following providers: ${SUPPORTED_PROVIDERS.join(", ")}`
|
|
1259
|
-
);
|
|
1260
|
-
}
|
|
1261
|
-
const payload = args.provider === "fal-ai" || args.provider === "replicate" || args.provider === "novita" ? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs } : args;
|
|
1262
|
-
const res = await request(payload, {
|
|
1752
|
+
// src/tasks/nlp/chatCompletion.ts
|
|
1753
|
+
async function chatCompletion(args, options) {
|
|
1754
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
|
|
1755
|
+
const { data: response } = await innerRequest(args, {
|
|
1263
1756
|
...options,
|
|
1264
|
-
task: "
|
|
1757
|
+
task: "conversational"
|
|
1758
|
+
});
|
|
1759
|
+
return providerHelper.getResponse(response);
|
|
1760
|
+
}
|
|
1761
|
+
|
|
1762
|
+
// src/tasks/nlp/chatCompletionStream.ts
|
|
1763
|
+
async function* chatCompletionStream(args, options) {
|
|
1764
|
+
yield* innerStreamingRequest(args, {
|
|
1765
|
+
...options,
|
|
1766
|
+
task: "conversational"
|
|
1265
1767
|
});
|
|
1266
|
-
if (args.provider === "fal-ai") {
|
|
1267
|
-
const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-video" });
|
|
1268
|
-
return await pollFalResponse(res, url, info.headers);
|
|
1269
|
-
} else if (args.provider === "novita") {
|
|
1270
|
-
const isValidOutput = typeof res === "object" && !!res && "video" in res && typeof res.video === "object" && !!res.video && "video_url" in res.video && typeof res.video.video_url === "string" && isUrl(res.video.video_url);
|
|
1271
|
-
if (!isValidOutput) {
|
|
1272
|
-
throw new InferenceOutputError("Expected { video: { video_url: string } }");
|
|
1273
|
-
}
|
|
1274
|
-
const urlResponse = await fetch(res.video.video_url);
|
|
1275
|
-
return await urlResponse.blob();
|
|
1276
|
-
} else {
|
|
1277
|
-
const isValidOutput = typeof res === "object" && !!res && "output" in res && typeof res.output === "string" && isUrl(res.output);
|
|
1278
|
-
if (!isValidOutput) {
|
|
1279
|
-
throw new InferenceOutputError("Expected { output: string }");
|
|
1280
|
-
}
|
|
1281
|
-
const urlResponse = await fetch(res.output);
|
|
1282
|
-
return await urlResponse.blob();
|
|
1283
|
-
}
|
|
1284
1768
|
}
|
|
1285
1769
|
|
|
1286
1770
|
// src/tasks/nlp/featureExtraction.ts
|
|
1287
1771
|
async function featureExtraction(args, options) {
|
|
1288
|
-
const
|
|
1772
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "feature-extraction");
|
|
1773
|
+
const { data: res } = await innerRequest(args, {
|
|
1289
1774
|
...options,
|
|
1290
1775
|
task: "feature-extraction"
|
|
1291
1776
|
});
|
|
1292
|
-
|
|
1293
|
-
const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
|
|
1294
|
-
if (curDepth > maxDepth)
|
|
1295
|
-
return false;
|
|
1296
|
-
if (arr.every((x) => Array.isArray(x))) {
|
|
1297
|
-
return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
|
|
1298
|
-
} else {
|
|
1299
|
-
return arr.every((x) => typeof x === "number");
|
|
1300
|
-
}
|
|
1301
|
-
};
|
|
1302
|
-
isValidOutput = Array.isArray(res) && isNumArrayRec(res, 3, 0);
|
|
1303
|
-
if (!isValidOutput) {
|
|
1304
|
-
throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
|
|
1305
|
-
}
|
|
1306
|
-
return res;
|
|
1777
|
+
return providerHelper.getResponse(res);
|
|
1307
1778
|
}
|
|
1308
1779
|
|
|
1309
1780
|
// src/tasks/nlp/fillMask.ts
|
|
1310
1781
|
async function fillMask(args, options) {
|
|
1311
|
-
const
|
|
1782
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "fill-mask");
|
|
1783
|
+
const { data: res } = await innerRequest(args, {
|
|
1312
1784
|
...options,
|
|
1313
1785
|
task: "fill-mask"
|
|
1314
1786
|
});
|
|
1315
|
-
|
|
1316
|
-
(x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
|
|
1317
|
-
);
|
|
1318
|
-
if (!isValidOutput) {
|
|
1319
|
-
throw new InferenceOutputError(
|
|
1320
|
-
"Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
|
|
1321
|
-
);
|
|
1322
|
-
}
|
|
1323
|
-
return res;
|
|
1787
|
+
return providerHelper.getResponse(res);
|
|
1324
1788
|
}
|
|
1325
1789
|
|
|
1326
1790
|
// src/tasks/nlp/questionAnswering.ts
|
|
1327
1791
|
async function questionAnswering(args, options) {
|
|
1328
|
-
const
|
|
1792
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering");
|
|
1793
|
+
const { data: res } = await innerRequest(args, {
|
|
1329
1794
|
...options,
|
|
1330
1795
|
task: "question-answering"
|
|
1331
1796
|
});
|
|
1332
|
-
|
|
1333
|
-
(elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
|
|
1334
|
-
) : typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";
|
|
1335
|
-
if (!isValidOutput) {
|
|
1336
|
-
throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
|
|
1337
|
-
}
|
|
1338
|
-
return Array.isArray(res) ? res[0] : res;
|
|
1797
|
+
return providerHelper.getResponse(res);
|
|
1339
1798
|
}
|
|
1340
1799
|
|
|
1341
1800
|
// src/tasks/nlp/sentenceSimilarity.ts
|
|
1342
1801
|
async function sentenceSimilarity(args, options) {
|
|
1343
|
-
const
|
|
1802
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "sentence-similarity");
|
|
1803
|
+
const { data: res } = await innerRequest(args, {
|
|
1344
1804
|
...options,
|
|
1345
1805
|
task: "sentence-similarity"
|
|
1346
1806
|
});
|
|
1347
|
-
|
|
1348
|
-
if (!isValidOutput) {
|
|
1349
|
-
throw new InferenceOutputError("Expected number[]");
|
|
1350
|
-
}
|
|
1351
|
-
return res;
|
|
1352
|
-
}
|
|
1353
|
-
function prepareInput(args) {
|
|
1354
|
-
return {
|
|
1355
|
-
...omit(args, ["inputs", "parameters"]),
|
|
1356
|
-
inputs: { ...omit(args.inputs, "sourceSentence") },
|
|
1357
|
-
parameters: { source_sentence: args.inputs.sourceSentence, ...args.parameters }
|
|
1358
|
-
};
|
|
1807
|
+
return providerHelper.getResponse(res);
|
|
1359
1808
|
}
|
|
1360
1809
|
|
|
1361
1810
|
// src/tasks/nlp/summarization.ts
|
|
1362
1811
|
async function summarization(args, options) {
|
|
1363
|
-
const
|
|
1812
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization");
|
|
1813
|
+
const { data: res } = await innerRequest(args, {
|
|
1364
1814
|
...options,
|
|
1365
1815
|
task: "summarization"
|
|
1366
1816
|
});
|
|
1367
|
-
|
|
1368
|
-
if (!isValidOutput) {
|
|
1369
|
-
throw new InferenceOutputError("Expected Array<{summary_text: string}>");
|
|
1370
|
-
}
|
|
1371
|
-
return res?.[0];
|
|
1817
|
+
return providerHelper.getResponse(res);
|
|
1372
1818
|
}
|
|
1373
1819
|
|
|
1374
1820
|
// src/tasks/nlp/tableQuestionAnswering.ts
|
|
1375
1821
|
async function tableQuestionAnswering(args, options) {
|
|
1376
|
-
const
|
|
1822
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
|
|
1823
|
+
const { data: res } = await innerRequest(args, {
|
|
1377
1824
|
...options,
|
|
1378
1825
|
task: "table-question-answering"
|
|
1379
1826
|
});
|
|
1380
|
-
|
|
1381
|
-
if (!isValidOutput) {
|
|
1382
|
-
throw new InferenceOutputError(
|
|
1383
|
-
"Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
|
|
1384
|
-
);
|
|
1385
|
-
}
|
|
1386
|
-
return Array.isArray(res) ? res[0] : res;
|
|
1387
|
-
}
|
|
1388
|
-
function validate(elem) {
|
|
1389
|
-
return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
|
|
1390
|
-
(coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
|
|
1391
|
-
);
|
|
1827
|
+
return providerHelper.getResponse(res);
|
|
1392
1828
|
}
|
|
1393
1829
|
|
|
1394
1830
|
// src/tasks/nlp/textClassification.ts
|
|
1395
1831
|
async function textClassification(args, options) {
|
|
1396
|
-
const
|
|
1832
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
|
|
1833
|
+
const { data: res } = await innerRequest(args, {
|
|
1397
1834
|
...options,
|
|
1398
1835
|
task: "text-classification"
|
|
1399
|
-
})
|
|
1400
|
-
|
|
1401
|
-
if (!isValidOutput) {
|
|
1402
|
-
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
1403
|
-
}
|
|
1404
|
-
return res;
|
|
1405
|
-
}
|
|
1406
|
-
|
|
1407
|
-
// src/utils/toArray.ts
|
|
1408
|
-
function toArray(obj) {
|
|
1409
|
-
if (Array.isArray(obj)) {
|
|
1410
|
-
return obj;
|
|
1411
|
-
}
|
|
1412
|
-
return [obj];
|
|
1836
|
+
});
|
|
1837
|
+
return providerHelper.getResponse(res);
|
|
1413
1838
|
}
|
|
1414
1839
|
|
|
1415
1840
|
// src/tasks/nlp/textGeneration.ts
|
|
1416
1841
|
async function textGeneration(args, options) {
|
|
1417
|
-
|
|
1418
|
-
|
|
1419
|
-
|
|
1420
|
-
|
|
1421
|
-
|
|
1422
|
-
|
|
1423
|
-
|
|
1424
|
-
if (!isValidOutput) {
|
|
1425
|
-
throw new InferenceOutputError("Expected ChatCompletionOutput");
|
|
1426
|
-
}
|
|
1427
|
-
const completion = raw.choices[0];
|
|
1428
|
-
return {
|
|
1429
|
-
generated_text: completion.text
|
|
1430
|
-
};
|
|
1431
|
-
} else if (args.provider === "hyperbolic") {
|
|
1432
|
-
const payload = {
|
|
1433
|
-
messages: [{ content: args.inputs, role: "user" }],
|
|
1434
|
-
...args.parameters ? {
|
|
1435
|
-
max_tokens: args.parameters.max_new_tokens,
|
|
1436
|
-
...omit(args.parameters, "max_new_tokens")
|
|
1437
|
-
} : void 0,
|
|
1438
|
-
...omit(args, ["inputs", "parameters"])
|
|
1439
|
-
};
|
|
1440
|
-
const raw = await request(payload, {
|
|
1441
|
-
...options,
|
|
1442
|
-
task: "text-generation"
|
|
1443
|
-
});
|
|
1444
|
-
const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
|
|
1445
|
-
if (!isValidOutput) {
|
|
1446
|
-
throw new InferenceOutputError("Expected ChatCompletionOutput");
|
|
1447
|
-
}
|
|
1448
|
-
const completion = raw.choices[0];
|
|
1449
|
-
return {
|
|
1450
|
-
generated_text: completion.message.content
|
|
1451
|
-
};
|
|
1452
|
-
} else {
|
|
1453
|
-
const res = toArray(
|
|
1454
|
-
await request(args, {
|
|
1455
|
-
...options,
|
|
1456
|
-
task: "text-generation"
|
|
1457
|
-
})
|
|
1458
|
-
);
|
|
1459
|
-
const isValidOutput = Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
|
|
1460
|
-
if (!isValidOutput) {
|
|
1461
|
-
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
|
|
1462
|
-
}
|
|
1463
|
-
return res?.[0];
|
|
1464
|
-
}
|
|
1842
|
+
const provider = args.provider ?? "hf-inference";
|
|
1843
|
+
const providerHelper = getProviderHelper(provider, "text-generation");
|
|
1844
|
+
const { data: response } = await innerRequest(args, {
|
|
1845
|
+
...options,
|
|
1846
|
+
task: "text-generation"
|
|
1847
|
+
});
|
|
1848
|
+
return providerHelper.getResponse(response);
|
|
1465
1849
|
}
|
|
1466
1850
|
|
|
1467
1851
|
// src/tasks/nlp/textGenerationStream.ts
|
|
1468
1852
|
async function* textGenerationStream(args, options) {
|
|
1469
|
-
yield*
|
|
1853
|
+
yield* innerStreamingRequest(args, {
|
|
1470
1854
|
...options,
|
|
1471
1855
|
task: "text-generation"
|
|
1472
1856
|
});
|
|
@@ -1474,79 +1858,37 @@ async function* textGenerationStream(args, options) {
|
|
|
1474
1858
|
|
|
1475
1859
|
// src/tasks/nlp/tokenClassification.ts
|
|
1476
1860
|
async function tokenClassification(args, options) {
|
|
1477
|
-
const
|
|
1478
|
-
|
|
1479
|
-
|
|
1480
|
-
|
|
1481
|
-
|
|
1482
|
-
);
|
|
1483
|
-
const isValidOutput = Array.isArray(res) && res.every(
|
|
1484
|
-
(x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
|
|
1485
|
-
);
|
|
1486
|
-
if (!isValidOutput) {
|
|
1487
|
-
throw new InferenceOutputError(
|
|
1488
|
-
"Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
|
|
1489
|
-
);
|
|
1490
|
-
}
|
|
1491
|
-
return res;
|
|
1861
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "token-classification");
|
|
1862
|
+
const { data: res } = await innerRequest(args, {
|
|
1863
|
+
...options,
|
|
1864
|
+
task: "token-classification"
|
|
1865
|
+
});
|
|
1866
|
+
return providerHelper.getResponse(res);
|
|
1492
1867
|
}
|
|
1493
1868
|
|
|
1494
1869
|
// src/tasks/nlp/translation.ts
|
|
1495
1870
|
async function translation(args, options) {
|
|
1496
|
-
const
|
|
1871
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation");
|
|
1872
|
+
const { data: res } = await innerRequest(args, {
|
|
1497
1873
|
...options,
|
|
1498
1874
|
task: "translation"
|
|
1499
1875
|
});
|
|
1500
|
-
|
|
1501
|
-
if (!isValidOutput) {
|
|
1502
|
-
throw new InferenceOutputError("Expected type Array<{translation_text: string}>");
|
|
1503
|
-
}
|
|
1504
|
-
return res?.length === 1 ? res?.[0] : res;
|
|
1876
|
+
return providerHelper.getResponse(res);
|
|
1505
1877
|
}
|
|
1506
1878
|
|
|
1507
1879
|
// src/tasks/nlp/zeroShotClassification.ts
|
|
1508
1880
|
async function zeroShotClassification(args, options) {
|
|
1509
|
-
const
|
|
1510
|
-
|
|
1511
|
-
...options,
|
|
1512
|
-
task: "zero-shot-classification"
|
|
1513
|
-
})
|
|
1514
|
-
);
|
|
1515
|
-
const isValidOutput = Array.isArray(res) && res.every(
|
|
1516
|
-
(x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
|
|
1517
|
-
);
|
|
1518
|
-
if (!isValidOutput) {
|
|
1519
|
-
throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
|
|
1520
|
-
}
|
|
1521
|
-
return res;
|
|
1522
|
-
}
|
|
1523
|
-
|
|
1524
|
-
// src/tasks/nlp/chatCompletion.ts
|
|
1525
|
-
async function chatCompletion(args, options) {
|
|
1526
|
-
const res = await request(args, {
|
|
1527
|
-
...options,
|
|
1528
|
-
task: "text-generation",
|
|
1529
|
-
chatCompletion: true
|
|
1530
|
-
});
|
|
1531
|
-
const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
|
|
1532
|
-
(res.system_fingerprint === void 0 || res.system_fingerprint === null || typeof res.system_fingerprint === "string") && typeof res?.usage === "object";
|
|
1533
|
-
if (!isValidOutput) {
|
|
1534
|
-
throw new InferenceOutputError("Expected ChatCompletionOutput");
|
|
1535
|
-
}
|
|
1536
|
-
return res;
|
|
1537
|
-
}
|
|
1538
|
-
|
|
1539
|
-
// src/tasks/nlp/chatCompletionStream.ts
|
|
1540
|
-
async function* chatCompletionStream(args, options) {
|
|
1541
|
-
yield* streamingRequest(args, {
|
|
1881
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification");
|
|
1882
|
+
const { data: res } = await innerRequest(args, {
|
|
1542
1883
|
...options,
|
|
1543
|
-
task: "
|
|
1544
|
-
chatCompletion: true
|
|
1884
|
+
task: "zero-shot-classification"
|
|
1545
1885
|
});
|
|
1886
|
+
return providerHelper.getResponse(res);
|
|
1546
1887
|
}
|
|
1547
1888
|
|
|
1548
1889
|
// src/tasks/multimodal/documentQuestionAnswering.ts
|
|
1549
1890
|
async function documentQuestionAnswering(args, options) {
|
|
1891
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "document-question-answering");
|
|
1550
1892
|
const reqArgs = {
|
|
1551
1893
|
...args,
|
|
1552
1894
|
inputs: {
|
|
@@ -1555,23 +1897,19 @@ async function documentQuestionAnswering(args, options) {
|
|
|
1555
1897
|
image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer()))
|
|
1556
1898
|
}
|
|
1557
1899
|
};
|
|
1558
|
-
const res =
|
|
1559
|
-
|
|
1900
|
+
const { data: res } = await innerRequest(
|
|
1901
|
+
reqArgs,
|
|
1902
|
+
{
|
|
1560
1903
|
...options,
|
|
1561
1904
|
task: "document-question-answering"
|
|
1562
|
-
}
|
|
1563
|
-
);
|
|
1564
|
-
const isValidOutput = Array.isArray(res) && res.every(
|
|
1565
|
-
(elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
|
|
1905
|
+
}
|
|
1566
1906
|
);
|
|
1567
|
-
|
|
1568
|
-
throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>");
|
|
1569
|
-
}
|
|
1570
|
-
return res[0];
|
|
1907
|
+
return providerHelper.getResponse(res);
|
|
1571
1908
|
}
|
|
1572
1909
|
|
|
1573
1910
|
// src/tasks/multimodal/visualQuestionAnswering.ts
|
|
1574
1911
|
async function visualQuestionAnswering(args, options) {
|
|
1912
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "visual-question-answering");
|
|
1575
1913
|
const reqArgs = {
|
|
1576
1914
|
...args,
|
|
1577
1915
|
inputs: {
|
|
@@ -1580,43 +1918,31 @@ async function visualQuestionAnswering(args, options) {
|
|
|
1580
1918
|
image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer()))
|
|
1581
1919
|
}
|
|
1582
1920
|
};
|
|
1583
|
-
const res = await
|
|
1921
|
+
const { data: res } = await innerRequest(reqArgs, {
|
|
1584
1922
|
...options,
|
|
1585
1923
|
task: "visual-question-answering"
|
|
1586
1924
|
});
|
|
1587
|
-
|
|
1588
|
-
(elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
|
|
1589
|
-
);
|
|
1590
|
-
if (!isValidOutput) {
|
|
1591
|
-
throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
|
|
1592
|
-
}
|
|
1593
|
-
return res[0];
|
|
1925
|
+
return providerHelper.getResponse(res);
|
|
1594
1926
|
}
|
|
1595
1927
|
|
|
1596
|
-
// src/tasks/tabular/
|
|
1597
|
-
async function
|
|
1598
|
-
const
|
|
1928
|
+
// src/tasks/tabular/tabularClassification.ts
|
|
1929
|
+
async function tabularClassification(args, options) {
|
|
1930
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification");
|
|
1931
|
+
const { data: res } = await innerRequest(args, {
|
|
1599
1932
|
...options,
|
|
1600
|
-
task: "tabular-
|
|
1933
|
+
task: "tabular-classification"
|
|
1601
1934
|
});
|
|
1602
|
-
|
|
1603
|
-
if (!isValidOutput) {
|
|
1604
|
-
throw new InferenceOutputError("Expected number[]");
|
|
1605
|
-
}
|
|
1606
|
-
return res;
|
|
1935
|
+
return providerHelper.getResponse(res);
|
|
1607
1936
|
}
|
|
1608
1937
|
|
|
1609
|
-
// src/tasks/tabular/
|
|
1610
|
-
async function
|
|
1611
|
-
const
|
|
1938
|
+
// src/tasks/tabular/tabularRegression.ts
|
|
1939
|
+
async function tabularRegression(args, options) {
|
|
1940
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression");
|
|
1941
|
+
const { data: res } = await innerRequest(args, {
|
|
1612
1942
|
...options,
|
|
1613
|
-
task: "tabular-
|
|
1943
|
+
task: "tabular-regression"
|
|
1614
1944
|
});
|
|
1615
|
-
|
|
1616
|
-
if (!isValidOutput) {
|
|
1617
|
-
throw new InferenceOutputError("Expected number[]");
|
|
1618
|
-
}
|
|
1619
|
-
return res;
|
|
1945
|
+
return providerHelper.getResponse(res);
|
|
1620
1946
|
}
|
|
1621
1947
|
|
|
1622
1948
|
// src/InferenceClient.ts
|
|
@@ -1685,11 +2011,11 @@ __export(snippets_exports, {
|
|
|
1685
2011
|
});
|
|
1686
2012
|
|
|
1687
2013
|
// src/snippets/getInferenceSnippets.ts
|
|
2014
|
+
import { Template } from "@huggingface/jinja";
|
|
1688
2015
|
import {
|
|
1689
|
-
|
|
1690
|
-
|
|
2016
|
+
getModelInputSnippet,
|
|
2017
|
+
inferenceSnippetLanguages
|
|
1691
2018
|
} from "@huggingface/tasks";
|
|
1692
|
-
import { Template } from "@huggingface/jinja";
|
|
1693
2019
|
|
|
1694
2020
|
// src/snippets/templates.exported.ts
|
|
1695
2021
|
var templates = {
|
|
@@ -1699,7 +2025,7 @@ var templates = {
|
|
|
1699
2025
|
"basicAudio": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "audio/flac"\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
|
|
1700
2026
|
"basicImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "image/jpeg"\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
|
|
1701
2027
|
"textToAudio": '{% if model.library_name == "transformers" %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n // Returns a byte object of the Audio wavform. Use it directly!\n});\n{% else %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});\n{% endif %} ',
|
|
1702
|
-
"textToImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\nquery({
|
|
2028
|
+
"textToImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\n\nquery({ {{ providerInputs.asTsString }} }).then((response) => {\n // Use image\n});',
|
|
1703
2029
|
"zeroShotClassification": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({\n inputs: {{ providerInputs.asObj.inputs }},\n parameters: { candidate_labels: ["refund", "legal", "faq"] }\n}).then((response) => {\n console.log(JSON.stringify(response));\n});'
|
|
1704
2030
|
},
|
|
1705
2031
|
"huggingface.js": {
|
|
@@ -1732,7 +2058,7 @@ const image = await client.textToVideo({
|
|
|
1732
2058
|
},
|
|
1733
2059
|
"openai": {
|
|
1734
2060
|
"conversational": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\nconst chatCompletion = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n});\n\nconsole.log(chatCompletion.choices[0].message);',
|
|
1735
|
-
"conversationalStream": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\
|
|
2061
|
+
"conversationalStream": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\nconst stream = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n stream: true,\n});\n\nfor await (const chunk of stream) {\n process.stdout.write(chunk.choices[0]?.delta?.content || "");\n}'
|
|
1736
2062
|
}
|
|
1737
2063
|
},
|
|
1738
2064
|
"python": {
|
|
@@ -1864,15 +2190,23 @@ var HF_JS_METHODS = {
|
|
|
1864
2190
|
};
|
|
1865
2191
|
var snippetGenerator = (templateName, inputPreparationFn) => {
|
|
1866
2192
|
return (model, accessToken, provider, providerModelId, opts) => {
|
|
2193
|
+
let task = model.pipeline_tag;
|
|
1867
2194
|
if (model.pipeline_tag && ["text-generation", "image-text-to-text"].includes(model.pipeline_tag) && model.tags.includes("conversational")) {
|
|
1868
2195
|
templateName = opts?.streaming ? "conversationalStream" : "conversational";
|
|
1869
2196
|
inputPreparationFn = prepareConversationalInput;
|
|
2197
|
+
task = "conversational";
|
|
1870
2198
|
}
|
|
1871
2199
|
const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) };
|
|
1872
2200
|
const request2 = makeRequestOptionsFromResolvedModel(
|
|
1873
2201
|
providerModelId ?? model.id,
|
|
1874
|
-
{
|
|
1875
|
-
|
|
2202
|
+
{
|
|
2203
|
+
accessToken,
|
|
2204
|
+
provider,
|
|
2205
|
+
...inputs
|
|
2206
|
+
},
|
|
2207
|
+
{
|
|
2208
|
+
task
|
|
2209
|
+
}
|
|
1876
2210
|
);
|
|
1877
2211
|
let providerInputs = inputs;
|
|
1878
2212
|
const bodyAsObj = request2.info.body;
|
|
@@ -1959,7 +2293,7 @@ var prepareConversationalInput = (model, opts) => {
|
|
|
1959
2293
|
return {
|
|
1960
2294
|
messages: opts?.messages ?? getModelInputSnippet(model),
|
|
1961
2295
|
...opts?.temperature ? { temperature: opts?.temperature } : void 0,
|
|
1962
|
-
max_tokens: opts?.max_tokens ??
|
|
2296
|
+
max_tokens: opts?.max_tokens ?? 512,
|
|
1963
2297
|
...opts?.top_p ? { top_p: opts?.top_p } : void 0
|
|
1964
2298
|
};
|
|
1965
2299
|
};
|