@huggingface/inference 3.7.0 → 3.8.0
This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- package/dist/index.cjs +1369 -941
- package/dist/index.js +1371 -943
- package/dist/src/lib/getInferenceProviderMapping.d.ts +21 -0
- package/dist/src/lib/getInferenceProviderMapping.d.ts.map +1 -0
- package/dist/src/lib/getProviderHelper.d.ts +37 -0
- package/dist/src/lib/getProviderHelper.d.ts.map +1 -0
- package/dist/src/lib/makeRequestOptions.d.ts +5 -5
- package/dist/src/lib/makeRequestOptions.d.ts.map +1 -1
- package/dist/src/providers/black-forest-labs.d.ts +14 -18
- package/dist/src/providers/black-forest-labs.d.ts.map +1 -1
- package/dist/src/providers/cerebras.d.ts +4 -2
- package/dist/src/providers/cerebras.d.ts.map +1 -1
- package/dist/src/providers/cohere.d.ts +5 -2
- package/dist/src/providers/cohere.d.ts.map +1 -1
- package/dist/src/providers/consts.d.ts +2 -3
- package/dist/src/providers/consts.d.ts.map +1 -1
- package/dist/src/providers/fal-ai.d.ts +50 -3
- package/dist/src/providers/fal-ai.d.ts.map +1 -1
- package/dist/src/providers/fireworks-ai.d.ts +5 -2
- package/dist/src/providers/fireworks-ai.d.ts.map +1 -1
- package/dist/src/providers/hf-inference.d.ts +126 -2
- package/dist/src/providers/hf-inference.d.ts.map +1 -1
- package/dist/src/providers/hyperbolic.d.ts +31 -2
- package/dist/src/providers/hyperbolic.d.ts.map +1 -1
- package/dist/src/providers/nebius.d.ts +20 -18
- package/dist/src/providers/nebius.d.ts.map +1 -1
- package/dist/src/providers/novita.d.ts +21 -18
- package/dist/src/providers/novita.d.ts.map +1 -1
- package/dist/src/providers/openai.d.ts +4 -2
- package/dist/src/providers/openai.d.ts.map +1 -1
- package/dist/src/providers/providerHelper.d.ts +182 -0
- package/dist/src/providers/providerHelper.d.ts.map +1 -0
- package/dist/src/providers/replicate.d.ts +23 -19
- package/dist/src/providers/replicate.d.ts.map +1 -1
- package/dist/src/providers/sambanova.d.ts +4 -2
- package/dist/src/providers/sambanova.d.ts.map +1 -1
- package/dist/src/providers/together.d.ts +32 -2
- package/dist/src/providers/together.d.ts.map +1 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts +2 -1
- package/dist/src/snippets/getInferenceSnippets.d.ts.map +1 -1
- package/dist/src/tasks/audio/audioClassification.d.ts.map +1 -1
- package/dist/src/tasks/audio/automaticSpeechRecognition.d.ts.map +1 -1
- package/dist/src/tasks/audio/textToSpeech.d.ts.map +1 -1
- package/dist/src/tasks/audio/utils.d.ts +2 -1
- package/dist/src/tasks/audio/utils.d.ts.map +1 -1
- package/dist/src/tasks/custom/request.d.ts +0 -2
- package/dist/src/tasks/custom/request.d.ts.map +1 -1
- package/dist/src/tasks/custom/streamingRequest.d.ts +0 -2
- package/dist/src/tasks/custom/streamingRequest.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageClassification.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageSegmentation.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/imageToText.d.ts.map +1 -1
- package/dist/src/tasks/cv/objectDetection.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToImage.d.ts.map +1 -1
- package/dist/src/tasks/cv/textToVideo.d.ts.map +1 -1
- package/dist/src/tasks/cv/zeroShotImageClassification.d.ts.map +1 -1
- package/dist/src/tasks/index.d.ts +6 -6
- package/dist/src/tasks/index.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/documentQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/multimodal/visualQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletion.d.ts.map +1 -1
- package/dist/src/tasks/nlp/chatCompletionStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/featureExtraction.d.ts.map +1 -1
- package/dist/src/tasks/nlp/fillMask.d.ts.map +1 -1
- package/dist/src/tasks/nlp/questionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/sentenceSimilarity.d.ts.map +1 -1
- package/dist/src/tasks/nlp/summarization.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tableQuestionAnswering.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGeneration.d.ts.map +1 -1
- package/dist/src/tasks/nlp/textGenerationStream.d.ts.map +1 -1
- package/dist/src/tasks/nlp/tokenClassification.d.ts.map +1 -1
- package/dist/src/tasks/nlp/translation.d.ts.map +1 -1
- package/dist/src/tasks/nlp/zeroShotClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularClassification.d.ts.map +1 -1
- package/dist/src/tasks/tabular/tabularRegression.d.ts.map +1 -1
- package/dist/src/types.d.ts +5 -13
- package/dist/src/types.d.ts.map +1 -1
- package/dist/src/utils/request.d.ts +3 -2
- package/dist/src/utils/request.d.ts.map +1 -1
- package/package.json +3 -3
- package/src/lib/getInferenceProviderMapping.ts +96 -0
- package/src/lib/getProviderHelper.ts +270 -0
- package/src/lib/makeRequestOptions.ts +78 -97
- package/src/providers/black-forest-labs.ts +73 -22
- package/src/providers/cerebras.ts +6 -27
- package/src/providers/cohere.ts +9 -28
- package/src/providers/consts.ts +5 -2
- package/src/providers/fal-ai.ts +224 -77
- package/src/providers/fireworks-ai.ts +8 -29
- package/src/providers/hf-inference.ts +557 -34
- package/src/providers/hyperbolic.ts +107 -29
- package/src/providers/nebius.ts +65 -29
- package/src/providers/novita.ts +68 -32
- package/src/providers/openai.ts +6 -32
- package/src/providers/providerHelper.ts +354 -0
- package/src/providers/replicate.ts +124 -34
- package/src/providers/sambanova.ts +5 -30
- package/src/providers/together.ts +92 -28
- package/src/snippets/getInferenceSnippets.ts +39 -14
- package/src/snippets/templates.exported.ts +25 -25
- package/src/tasks/audio/audioClassification.ts +5 -8
- package/src/tasks/audio/audioToAudio.ts +4 -27
- package/src/tasks/audio/automaticSpeechRecognition.ts +5 -4
- package/src/tasks/audio/textToSpeech.ts +5 -29
- package/src/tasks/audio/utils.ts +2 -1
- package/src/tasks/custom/request.ts +3 -3
- package/src/tasks/custom/streamingRequest.ts +4 -3
- package/src/tasks/cv/imageClassification.ts +4 -8
- package/src/tasks/cv/imageSegmentation.ts +4 -9
- package/src/tasks/cv/imageToImage.ts +4 -7
- package/src/tasks/cv/imageToText.ts +4 -7
- package/src/tasks/cv/objectDetection.ts +4 -19
- package/src/tasks/cv/textToImage.ts +9 -137
- package/src/tasks/cv/textToVideo.ts +17 -64
- package/src/tasks/cv/zeroShotImageClassification.ts +4 -8
- package/src/tasks/index.ts +6 -6
- package/src/tasks/multimodal/documentQuestionAnswering.ts +4 -19
- package/src/tasks/multimodal/visualQuestionAnswering.ts +4 -12
- package/src/tasks/nlp/chatCompletion.ts +5 -20
- package/src/tasks/nlp/chatCompletionStream.ts +4 -3
- package/src/tasks/nlp/featureExtraction.ts +4 -19
- package/src/tasks/nlp/fillMask.ts +4 -17
- package/src/tasks/nlp/questionAnswering.ts +11 -26
- package/src/tasks/nlp/sentenceSimilarity.ts +4 -8
- package/src/tasks/nlp/summarization.ts +4 -7
- package/src/tasks/nlp/tableQuestionAnswering.ts +10 -30
- package/src/tasks/nlp/textClassification.ts +4 -9
- package/src/tasks/nlp/textGeneration.ts +11 -79
- package/src/tasks/nlp/textGenerationStream.ts +3 -1
- package/src/tasks/nlp/tokenClassification.ts +11 -23
- package/src/tasks/nlp/translation.ts +4 -7
- package/src/tasks/nlp/zeroShotClassification.ts +11 -21
- package/src/tasks/tabular/tabularClassification.ts +4 -7
- package/src/tasks/tabular/tabularRegression.ts +4 -7
- package/src/types.ts +5 -14
- package/src/utils/request.ts +7 -4
- package/dist/src/lib/getProviderModelId.d.ts +0 -10
- package/dist/src/lib/getProviderModelId.d.ts.map +0 -1
- package/src/lib/getProviderModelId.ts +0 -74
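
The headline change in 3.8.0 is visible in the file list above: the per-provider request-building config objects (makeBaseUrl/makeBody/makeHeaders/makeUrl, plus the removed getProviderModelId.ts) are replaced by a class hierarchy of task helpers defined in the new src/providers/providerHelper.ts and resolved per provider and per task in the new src/lib/getProviderHelper.ts. A minimal sketch of the pattern as it appears in the bundled diff below — the provider name and URL here are hypothetical, and only the class and method names come from the diff:

    // Each provider/task pair subclasses TaskProviderHelper (or a shared base
    // such as BaseConversationalTask) and overrides only what differs.
    class AcmeConversationalTask extends BaseConversationalTask {
      constructor() {
        // "acme" is a hypothetical provider; real subclasses pass values like
        // ("cerebras", "https://api.cerebras.ai")
        super("acme", "https://api.acme.example");
      }
      // override when the provider's chat route differs from the default
      // "v1/chat/completions"
      makeRoute() {
        return "v2/chat/completions";
      }
    }
    // Helpers are then looked up from the per-provider, per-task map in
    // getProviderHelper.ts, e.g. PROVIDERS["fal-ai"]["text-to-video"] holds a
    // FalAITextToVideoTask instance.

The base class supplies the shared plumbing — makeUrl (the Hugging Face router URL when routed, the provider's own base URL with a provider key), makeBody, and prepareHeaders — while each subclass validates the provider's output shape in getResponse.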
package/dist/index.js
CHANGED
@@ -41,91 +41,211 @@ __export(tasks_exports, {
   zeroShotImageClassification: () => zeroShotImageClassification
 });

+// src/lib/InferenceOutputError.ts
+var InferenceOutputError = class extends TypeError {
+  constructor(message) {
+    super(
+      `Invalid inference output: ${message}. Use the 'request' method with the same parameters to do a custom call with no type checking.`
+    );
+    this.name = "InferenceOutputError";
+  }
+};
+
+// src/utils/delay.ts
+function delay(ms) {
+  return new Promise((resolve) => {
+    setTimeout(() => resolve(), ms);
+  });
+}
+
+// src/utils/pick.ts
+function pick(o, props) {
+  return Object.assign(
+    {},
+    ...props.map((prop) => {
+      if (o[prop] !== void 0) {
+        return { [prop]: o[prop] };
+      }
+    })
+  );
+}
+
+// src/utils/typedInclude.ts
+function typedInclude(arr, v) {
+  return arr.includes(v);
+}
+
+// src/utils/omit.ts
+function omit(o, props) {
+  const propsArr = Array.isArray(props) ? props : [props];
+  const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
+  return pick(o, letsKeep);
+}
+
 // src/config.ts
 var HF_HUB_URL = "https://huggingface.co";
 var HF_ROUTER_URL = "https://router.huggingface.co";
 var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";

-// src/
-
-
-
-}
-
-
+// src/utils/toArray.ts
+function toArray(obj) {
+  if (Array.isArray(obj)) {
+    return obj;
+  }
+  return [obj];
+}
+
+// src/providers/providerHelper.ts
+var TaskProviderHelper = class {
+  constructor(provider, baseUrl, clientSideRoutingOnly = false) {
+    this.provider = provider;
+    this.baseUrl = baseUrl;
+    this.clientSideRoutingOnly = clientSideRoutingOnly;
+  }
+  /**
+   * Prepare the base URL for the request
+   */
+  makeBaseUrl(params) {
+    return params.authMethod !== "provider-key" ? `${HF_ROUTER_URL}/${this.provider}` : this.baseUrl;
+  }
+  /**
+   * Prepare the body for the request
+   */
+  makeBody(params) {
+    if ("data" in params.args && !!params.args.data) {
+      return params.args.data;
+    }
+    return JSON.stringify(this.preparePayload(params));
+  }
+  /**
+   * Prepare the URL for the request
+   */
+  makeUrl(params) {
+    const baseUrl = this.makeBaseUrl(params);
+    const route = this.makeRoute(params).replace(/^\/+/, "");
+    return `${baseUrl}/${route}`;
+  }
+  /**
+   * Prepare the headers for the request
+   */
+  prepareHeaders(params, isBinary) {
+    const headers = { Authorization: `Bearer ${params.accessToken}` };
+    if (!isBinary) {
+      headers["Content-Type"] = "application/json";
+    }
+    return headers;
+  }
 };
-var
-
-
-}
-
+var BaseConversationalTask = class extends TaskProviderHelper {
+  constructor(provider, baseUrl, clientSideRoutingOnly = false) {
+    super(provider, baseUrl, clientSideRoutingOnly);
+  }
+  makeRoute() {
+    return "v1/chat/completions";
+  }
+  preparePayload(params) {
+    return {
+      ...params.args,
+      model: params.model
+    };
+  }
+  async getResponse(response) {
+    if (typeof response === "object" && Array.isArray(response?.choices) && typeof response?.created === "number" && typeof response?.id === "string" && typeof response?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
+    (response.system_fingerprint === void 0 || response.system_fingerprint === null || typeof response.system_fingerprint === "string") && typeof response?.usage === "object") {
+      return response;
+    }
+    throw new InferenceOutputError("Expected ChatCompletionOutput");
   }
 };
-var
-
+var BaseTextGenerationTask = class extends TaskProviderHelper {
+  constructor(provider, baseUrl, clientSideRoutingOnly = false) {
+    super(provider, baseUrl, clientSideRoutingOnly);
+  }
+  preparePayload(params) {
+    return {
+      ...params.args,
+      model: params.model
+    };
+  }
+  makeRoute() {
+    return "v1/completions";
+  }
+  async getResponse(response) {
+    const res = toArray(response);
+    if (Array.isArray(res) && res.length > 0 && res.every(
+      (x) => typeof x === "object" && !!x && "generated_text" in x && typeof x.generated_text === "string"
+    )) {
+      return res[0];
+    }
+    throw new InferenceOutputError("Expected Array<{generated_text: string}>");
+  }
 };
-
-
-
-
-
+
+// src/providers/black-forest-labs.ts
+var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
+var BlackForestLabsTextToImageTask = class extends TaskProviderHelper {
+  constructor() {
+    super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
+  }
+  preparePayload(params) {
+    return {
+      ...omit(params.args, ["inputs", "parameters"]),
+      ...params.args.parameters,
+      prompt: params.args.inputs
+    };
+  }
+  prepareHeaders(params, binary) {
+    const headers = {
+      Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`
+    };
+    if (!binary) {
+      headers["Content-Type"] = "application/json";
+    }
+    return headers;
+  }
+  makeRoute(params) {
+    if (!params) {
+      throw new Error("Params are required");
+    }
+    return `/v1/${params.model}`;
+  }
+  async getResponse(response, url, headers, outputType) {
+    const urlObj = new URL(response.polling_url);
+    for (let step = 0; step < 5; step++) {
+      await delay(1e3);
+      console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
+      urlObj.searchParams.set("attempt", step.toString(10));
+      const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
+      if (!resp.ok) {
+        throw new InferenceOutputError("Failed to fetch result from black forest labs API");
+      }
+      const payload = await resp.json();
+      if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
+        if (outputType === "url") {
+          return payload.result.sample;
+        }
+        const image = await fetch(payload.result.sample);
+        return await image.blob();
+      }
+    }
+    throw new InferenceOutputError("Failed to fetch result from black forest labs API");
+  }
 };

 // src/providers/cerebras.ts
-var
-
-
-}
-var makeBody2 = (params) => {
-  return {
-    ...params.args,
-    model: params.model
-  };
-};
-var makeHeaders2 = (params) => {
-  return { Authorization: `Bearer ${params.accessToken}` };
-};
-var makeUrl2 = (params) => {
-  return `${params.baseUrl}/v1/chat/completions`;
-};
-var CEREBRAS_CONFIG = {
-  makeBaseUrl: makeBaseUrl2,
-  makeBody: makeBody2,
-  makeHeaders: makeHeaders2,
-  makeUrl: makeUrl2
+var CerebrasConversationalTask = class extends BaseConversationalTask {
+  constructor() {
+    super("cerebras", "https://api.cerebras.ai");
+  }
 };

 // src/providers/cohere.ts
-var
-
-
-}
-
-
-    ...params.args,
-    model: params.model
-  };
-};
-var makeHeaders3 = (params) => {
-  return { Authorization: `Bearer ${params.accessToken}` };
-};
-var makeUrl3 = (params) => {
-  return `${params.baseUrl}/compatibility/v1/chat/completions`;
-};
-var COHERE_CONFIG = {
-  makeBaseUrl: makeBaseUrl3,
-  makeBody: makeBody3,
-  makeHeaders: makeHeaders3,
-  makeUrl: makeUrl3
-};
-
-// src/lib/InferenceOutputError.ts
-var InferenceOutputError = class extends TypeError {
-  constructor(message) {
-    super(
-      `Invalid inference output: ${message}. Use the 'request' method with the same parameters to do a custom call with no type checking.`
-    );
-    this.name = "InferenceOutputError";
+var CohereConversationalTask = class extends BaseConversationalTask {
+  constructor() {
+    super("cohere", "https://api.cohere.com");
+  }
+  makeRoute() {
+    return "/compatibility/v1/chat/completions";
   }
 };

@@ -134,352 +254,902 @@ function isUrl(modelOrUrl) {
   return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
 }

-// src/utils/delay.ts
-function delay(ms) {
-  return new Promise((resolve) => {
-    setTimeout(() => resolve(), ms);
-  });
-}
-
 // src/providers/fal-ai.ts
-var
-var
-
-
+var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
+var FalAITask = class extends TaskProviderHelper {
+  constructor(url) {
+    super("fal-ai", url || "https://fal.run");
+  }
+  preparePayload(params) {
+    return params.args;
+  }
+  makeRoute(params) {
+    return `/${params.model}`;
+  }
+  prepareHeaders(params, binary) {
+    const headers = {
+      Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}`
+    };
+    if (!binary) {
+      headers["Content-Type"] = "application/json";
+    }
+    return headers;
+  }
 };
-
-  return
+function buildLoraPath(modelId, adapterWeightsPath) {
+  return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
+}
+var FalAITextToImageTask = class extends FalAITask {
+  preparePayload(params) {
+    const payload = {
+      ...omit(params.args, ["inputs", "parameters"]),
+      ...params.args.parameters,
+      sync_mode: true,
+      prompt: params.args.inputs,
+      ...params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath ? {
+        loras: [
+          {
+            path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
+            scale: 1
+          }
+        ]
+      } : void 0
+    };
+    if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
+      payload.loras = [
+        {
+          path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
+          scale: 1
+        }
+      ];
+      if (params.mapping.providerId === "fal-ai/lora") {
+        payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0";
+      }
+    }
+    return payload;
+  }
+  async getResponse(response, outputType) {
+    if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images.length > 0 && "url" in response.images[0] && typeof response.images[0].url === "string") {
+      if (outputType === "url") {
+        return response.images[0].url;
+      }
+      const urlResponse = await fetch(response.images[0].url);
+      return await urlResponse.blob();
+    }
+    throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
+  }
 };
-var
-
-
-}
+var FalAITextToVideoTask = class extends FalAITask {
+  constructor() {
+    super("https://queue.fal.run");
+  }
+  makeRoute(params) {
+    if (params.authMethod !== "provider-key") {
+      return `/${params.model}?_subdomain=queue`;
+    }
+    return `/${params.model}`;
+  }
+  preparePayload(params) {
+    return {
+      ...omit(params.args, ["inputs", "parameters"]),
+      ...params.args.parameters,
+      prompt: params.args.inputs
+    };
+  }
+  async getResponse(response, url, headers) {
+    if (!url || !headers) {
+      throw new InferenceOutputError("URL and headers are required for text-to-video task");
+    }
+    const requestId = response.request_id;
+    if (!requestId) {
+      throw new InferenceOutputError("No request ID found in the response");
+    }
+    let status = response.status;
+    const parsedUrl = new URL(url);
+    const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
+    const modelId = new URL(response.response_url).pathname;
+    const queryParams = parsedUrl.search;
+    const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
+    const resultUrl = `${baseUrl}${modelId}${queryParams}`;
+    while (status !== "COMPLETED") {
+      await delay(500);
+      const statusResponse = await fetch(statusUrl, { headers });
+      if (!statusResponse.ok) {
+        throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
+      }
+      try {
+        status = (await statusResponse.json()).status;
+      } catch (error) {
+        throw new InferenceOutputError("Failed to parse status response from fal-ai API");
+      }
+    }
+    const resultResponse = await fetch(resultUrl, { headers });
+    let result;
+    try {
+      result = await resultResponse.json();
+    } catch (error) {
+      throw new InferenceOutputError("Failed to parse result response from fal-ai API");
+    }
+    if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
+      const urlResponse = await fetch(result.video.url);
+      return await urlResponse.blob();
+    } else {
+      throw new InferenceOutputError(
+        "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
+      );
+    }
+  }
+};
+var FalAIAutomaticSpeechRecognitionTask = class extends FalAITask {
+  prepareHeaders(params, binary) {
+    const headers = super.prepareHeaders(params, binary);
+    headers["Content-Type"] = "application/json";
+    return headers;
+  }
+  async getResponse(response) {
+    const res = response;
+    if (typeof res?.text !== "string") {
+      throw new InferenceOutputError(
+        `Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
+      );
+    }
+    return { text: res.text };
+  }
 };
-var
-
-
-
-
-
-};
-
-
-
-
-}
-
-  const requestId = res.request_id;
-  if (!requestId) {
-    throw new InferenceOutputError("No request ID found in the response");
-  }
-  let status = res.status;
-  const parsedUrl = new URL(url);
-  const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
-  const modelId = new URL(res.response_url).pathname;
-  const queryParams = parsedUrl.search;
-  const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
-  const resultUrl = `${baseUrl}${modelId}${queryParams}`;
-  while (status !== "COMPLETED") {
-    await delay(500);
-    const statusResponse = await fetch(statusUrl, { headers });
-    if (!statusResponse.ok) {
-      throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
+var FalAITextToSpeechTask = class extends FalAITask {
+  preparePayload(params) {
+    return {
+      ...omit(params.args, ["inputs", "parameters"]),
+      ...params.args.parameters,
+      lyrics: params.args.inputs
+    };
+  }
+  async getResponse(response) {
+    const res = response;
+    if (typeof res?.audio?.url !== "string") {
+      throw new InferenceOutputError(
+        `Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
+      );
     }
     try {
-
+      const urlResponse = await fetch(res.audio.url);
+      if (!urlResponse.ok) {
+        throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`);
+      }
+      return await urlResponse.blob();
     } catch (error) {
-      throw new InferenceOutputError(
+      throw new InferenceOutputError(
+        `Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${error instanceof Error ? error.message : String(error)}`
+      );
     }
   }
-
-
-
-
-
-
+};
+
+// src/providers/fireworks-ai.ts
+var FireworksConversationalTask = class extends BaseConversationalTask {
+  constructor() {
+    super("fireworks-ai", "https://api.fireworks.ai");
   }
-
-
-
-
+  makeRoute() {
+    return "/inference/v1/chat/completions";
+  }
+};
+
+// src/providers/hf-inference.ts
+var EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = ["feature-extraction", "sentence-similarity"];
+var HFInferenceTask = class extends TaskProviderHelper {
+  constructor() {
+    super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
+  }
+  preparePayload(params) {
+    return params.args;
+  }
+  makeUrl(params) {
+    if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
+      return params.model;
+    }
+    return super.makeUrl(params);
+  }
+  makeRoute(params) {
+    if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
+      return `pipeline/${params.task}/${params.model}`;
+    }
+    return `models/${params.model}`;
+  }
+  async getResponse(response) {
+    return response;
+  }
+};
+var HFInferenceTextToImageTask = class extends HFInferenceTask {
+  async getResponse(response, url, headers, outputType) {
+    if (!response) {
+      throw new InferenceOutputError("response is undefined");
+    }
+    if (typeof response == "object") {
+      if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
+        const base64Data = response.data[0].b64_json;
+        if (outputType === "url") {
+          return `data:image/jpeg;base64,${base64Data}`;
+        }
+        const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
+        return await base64Response.blob();
+      }
+      if ("output" in response && Array.isArray(response.output)) {
+        if (outputType === "url") {
+          return response.output[0];
+        }
+        const urlResponse = await fetch(response.output[0]);
+        const blob = await urlResponse.blob();
+        return blob;
+      }
+    }
+    if (response instanceof Blob) {
+      if (outputType === "url") {
+        const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
+        return `data:image/jpeg;base64,${b64}`;
+      }
+      return response;
+    }
+    throw new InferenceOutputError("Expected a Blob ");
+  }
+};
+var HFInferenceConversationalTask = class extends HFInferenceTask {
+  makeUrl(params) {
+    let url;
+    if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
+      url = params.model.trim();
+    } else {
+      url = `${this.makeBaseUrl(params)}/models/${params.model}`;
+    }
+    url = url.replace(/\/+$/, "");
+    if (url.endsWith("/v1")) {
+      url += "/chat/completions";
+    } else if (!url.endsWith("/chat/completions")) {
+      url += "/v1/chat/completions";
+    }
+    return url;
+  }
+  preparePayload(params) {
+    return {
+      ...params.args,
+      model: params.model
+    };
+  }
+  async getResponse(response) {
+    return response;
+  }
+};
+var HFInferenceTextGenerationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    const res = toArray(response);
+    if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) {
+      return res?.[0];
+    }
+    throw new InferenceOutputError("Expected Array<{generated_text: string}>");
+  }
+};
+var HFInferenceAudioClassificationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every(
+      (x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
+    )) {
+      return response;
+    }
+    throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
+  }
+};
+var HFInferenceAutomaticSpeechRecognitionTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    return response;
+  }
+};
+var HFInferenceAudioToAudioTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (!Array.isArray(response)) {
+      throw new InferenceOutputError("Expected Array");
+    }
+    if (!response.every((elem) => {
+      return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
+    })) {
+      throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
+    }
+    return response;
+  }
+};
+var HFInferenceDocumentQuestionAnsweringTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every(
+      (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
+    )) {
+      return response[0];
+    }
+    throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
+  }
+};
+var HFInferenceFeatureExtractionTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
+      if (curDepth > maxDepth)
+        return false;
+      if (arr.every((x) => Array.isArray(x))) {
+        return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
+      } else {
+        return arr.every((x) => typeof x === "number");
+      }
+    };
+    if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) {
+      return response;
+    }
+    throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
+  }
+};
+var HFInferenceImageClassificationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
+      return response;
+    }
+    throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
+  }
+};
+var HFInferenceImageSegmentationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number")) {
+      return response;
+    }
+    throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
+  }
+};
+var HFInferenceImageToTextTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (typeof response?.generated_text !== "string") {
+      throw new InferenceOutputError("Expected {generated_text: string}");
+    }
+    return response;
+  }
+};
+var HFInferenceImageToImageTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (response instanceof Blob) {
+      return response;
+    }
+    throw new InferenceOutputError("Expected Blob");
+  }
+};
+var HFInferenceObjectDetectionTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every(
+      (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
+    )) {
+      return response;
+    }
     throw new InferenceOutputError(
-      "Expected {
+      "Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
     );
   }
-}
-
-// src/providers/fireworks-ai.ts
-var FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai";
-var makeBaseUrl5 = () => {
-  return FIREWORKS_AI_API_BASE_URL;
 };
-var
-
-
-
-
+var HFInferenceZeroShotImageClassificationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
+      return response;
+    }
+    throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
+  }
 };
-var
-
+var HFInferenceTextClassificationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    const output = response?.[0];
+    if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
+      return output;
+    }
+    throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
+  }
 };
-var
-
-
+var HFInferenceQuestionAnsweringTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) ? response.every(
+      (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
+    ) : typeof response === "object" && !!response && typeof response.answer === "string" && typeof response.end === "number" && typeof response.score === "number" && typeof response.start === "number") {
+      return Array.isArray(response) ? response[0] : response;
+    }
+    throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
   }
-  return `${params.baseUrl}/inference`;
 };
-var
-
-
-
-
+var HFInferenceFillMaskTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every(
+      (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
+    )) {
+      return response;
+    }
+    throw new InferenceOutputError(
+      "Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
+    );
+  }
 };
-
-
-
-
+var HFInferenceZeroShotClassificationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every(
+      (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
+    )) {
+      return response;
+    }
+    throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
+  }
 };
-var
-
-
-
-
+var HFInferenceSentenceSimilarityTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
+      return response;
+    }
+    throw new InferenceOutputError("Expected Array<number>");
+  }
 };
-var
-
+var HFInferenceTableQuestionAnsweringTask = class extends HFInferenceTask {
+  static validate(elem) {
+    return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
+      (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
+    );
+  }
+  async getResponse(response) {
+    if (Array.isArray(response) && Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response)) {
+      return Array.isArray(response) ? response[0] : response;
+    }
+    throw new InferenceOutputError(
+      "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
+    );
+  }
 };
-var
-
-
+var HFInferenceTokenClassificationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every(
+      (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
+    )) {
+      return response;
+    }
+    throw new InferenceOutputError(
+      "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
+    );
   }
-
-
+};
+var HFInferenceTranslationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) {
+      return response?.length === 1 ? response?.[0] : response;
+    }
+    throw new InferenceOutputError("Expected Array<{translation_text: string}>");
+  }
+};
+var HFInferenceSummarizationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) {
+      return response?.[0];
+    }
+    throw new InferenceOutputError("Expected Array<{summary_text: string}>");
+  }
+};
+var HFInferenceTextToSpeechTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    return response;
   }
-  return `${params.baseUrl}/models/${params.model}`;
 };
-var
-
-
-
-
+var HFInferenceTabularClassificationTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
+      return response;
+    }
+    throw new InferenceOutputError("Expected Array<number>");
+  }
+};
+var HFInferenceVisualQuestionAnsweringTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every(
+      (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
+    )) {
+      return response[0];
+    }
+    throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
+  }
+};
+var HFInferenceTabularRegressionTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
+      return response;
+    }
+    throw new InferenceOutputError("Expected Array<number>");
+  }
+};
+var HFInferenceTextToAudioTask = class extends HFInferenceTask {
+  async getResponse(response) {
+    return response;
+  }
 };

268
767
|
// src/providers/hyperbolic.ts
|
|
269
768
|
var HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
|
|
270
|
-
var
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
return {
|
|
275
|
-
...params.args,
|
|
276
|
-
...params.task === "text-to-image" ? { model_name: params.model } : { model: params.model }
|
|
277
|
-
};
|
|
278
|
-
};
|
|
279
|
-
var makeHeaders7 = (params) => {
|
|
280
|
-
return { Authorization: `Bearer ${params.accessToken}` };
|
|
769
|
+
var HyperbolicConversationalTask = class extends BaseConversationalTask {
|
|
770
|
+
constructor() {
|
|
771
|
+
super("hyperbolic", HYPERBOLIC_API_BASE_URL);
|
|
772
|
+
}
|
|
281
773
|
};
|
|
282
|
-
var
|
|
283
|
-
|
|
284
|
-
|
|
774
|
+
var HyperbolicTextGenerationTask = class extends BaseTextGenerationTask {
|
|
775
|
+
constructor() {
|
|
776
|
+
super("hyperbolic", HYPERBOLIC_API_BASE_URL);
|
|
777
|
+
}
|
|
778
|
+
makeRoute() {
|
|
779
|
+
return "v1/chat/completions";
|
|
780
|
+
}
|
|
781
|
+
preparePayload(params) {
|
|
782
|
+
return {
|
|
783
|
+
messages: [{ content: params.args.inputs, role: "user" }],
|
|
784
|
+
...params.args.parameters ? {
|
|
785
|
+
max_tokens: params.args.parameters.max_new_tokens,
|
|
786
|
+
...omit(params.args.parameters, "max_new_tokens")
|
|
787
|
+
} : void 0,
|
|
788
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
789
|
+
model: params.model
|
|
790
|
+
};
|
|
791
|
+
}
|
|
792
|
+
async getResponse(response) {
|
|
793
|
+
if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
|
|
794
|
+
const completion = response.choices[0];
|
|
795
|
+
return {
|
|
796
|
+
generated_text: completion.message.content
|
|
797
|
+
};
|
|
798
|
+
}
|
|
799
|
+
throw new InferenceOutputError("Expected Hyperbolic text generation response format");
|
|
285
800
|
}
|
|
286
|
-
return `${params.baseUrl}/v1/chat/completions`;
|
|
287
801
|
};
|
|
288
|
-
var
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
802
|
+
var HyperbolicTextToImageTask = class extends TaskProviderHelper {
|
|
803
|
+
constructor() {
|
|
804
|
+
super("hyperbolic", HYPERBOLIC_API_BASE_URL);
|
|
805
|
+
}
|
|
806
|
+
makeRoute(params) {
|
|
807
|
+
return `/v1/images/generations`;
|
|
808
|
+
}
|
|
809
|
+
preparePayload(params) {
|
|
810
|
+
return {
|
|
811
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
812
|
+
...params.args.parameters,
|
|
813
|
+
prompt: params.args.inputs,
|
|
814
|
+
model_name: params.model
|
|
815
|
+
};
|
|
816
|
+
}
|
|
817
|
+
async getResponse(response, url, headers, outputType) {
|
|
818
|
+
if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images[0] && typeof response.images[0].image === "string") {
|
|
819
|
+
if (outputType === "url") {
|
|
820
|
+
return `data:image/jpeg;base64,${response.images[0].image}`;
|
|
821
|
+
}
|
|
822
|
+
return fetch(`data:image/jpeg;base64,${response.images[0].image}`).then((res) => res.blob());
|
|
823
|
+
}
|
|
824
|
+
throw new InferenceOutputError("Expected Hyperbolic text-to-image response format");
|
|
825
|
+
}
|
|
293
826
|
};
|
|
294
827
|
|
|
295
828
|
// src/providers/nebius.ts
|
|
296
829
|
var NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
|
|
297
|
-
var
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
return {
|
|
302
|
-
...params.args,
|
|
303
|
-
model: params.model
|
|
304
|
-
};
|
|
830
|
+
var NebiusConversationalTask = class extends BaseConversationalTask {
|
|
831
|
+
constructor() {
|
|
832
|
+
super("nebius", NEBIUS_API_BASE_URL);
|
|
833
|
+
}
|
|
305
834
|
};
|
|
306
|
-
var
|
|
307
|
-
|
|
835
|
+
var NebiusTextGenerationTask = class extends BaseTextGenerationTask {
|
|
836
|
+
constructor() {
|
|
837
|
+
super("nebius", NEBIUS_API_BASE_URL);
|
|
838
|
+
}
|
|
308
839
|
};
|
|
309
|
-
var
|
|
310
|
-
|
|
311
|
-
|
|
840
|
+
var NebiusTextToImageTask = class extends TaskProviderHelper {
|
|
841
|
+
constructor() {
|
|
842
|
+
super("nebius", NEBIUS_API_BASE_URL);
|
|
312
843
|
}
|
|
313
|
-
|
|
314
|
-
return
|
|
844
|
+
preparePayload(params) {
|
|
845
|
+
return {
|
|
846
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
847
|
+
...params.args.parameters,
|
|
848
|
+
response_format: "b64_json",
|
|
849
|
+
prompt: params.args.inputs,
|
|
850
|
+
model: params.model
|
|
851
|
+
};
|
|
315
852
|
}
|
|
316
|
-
|
|
317
|
-
return
|
|
853
|
+
makeRoute(params) {
|
|
854
|
+
return "v1/images/generations";
|
|
855
|
+
}
|
|
856
|
+
async getResponse(response, url, headers, outputType) {
|
|
857
|
+
if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
|
|
858
|
+
const base64Data = response.data[0].b64_json;
|
|
859
|
+
if (outputType === "url") {
|
|
860
|
+
return `data:image/jpeg;base64,${base64Data}`;
|
|
861
|
+
}
|
|
862
|
+
return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
|
|
863
|
+
}
|
|
864
|
+
throw new InferenceOutputError("Expected Nebius text-to-image response format");
|
|
318
865
|
}
|
|
319
|
-
return params.baseUrl;
|
|
320
|
-
};
|
|
321
|
-
var NEBIUS_CONFIG = {
|
|
322
|
-
makeBaseUrl: makeBaseUrl8,
|
|
323
|
-
makeBody: makeBody8,
|
|
324
|
-
makeHeaders: makeHeaders8,
|
|
325
|
-
makeUrl: makeUrl8
|
|
326
866
|
};
|
|
327
867
|
|
|
328
868
|
// src/providers/novita.ts
|
|
329
869
|
var NOVITA_API_BASE_URL = "https://api.novita.ai";
|
|
330
|
-
var
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
};
|
|
338
|
-
};
|
|
339
|
-
var makeHeaders9 = (params) => {
|
|
340
|
-
return { Authorization: `Bearer ${params.accessToken}` };
|
|
870
|
+
var NovitaTextGenerationTask = class extends BaseTextGenerationTask {
|
|
871
|
+
constructor() {
|
|
872
|
+
super("novita", NOVITA_API_BASE_URL);
|
|
873
|
+
}
|
|
874
|
+
makeRoute() {
|
|
875
|
+
return "/v3/openai/chat/completions";
|
|
876
|
+
}
|
|
341
877
|
};
|
|
342
|
-
var
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
}
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
return `${params.baseUrl}/v3/hf/${params.model}`;
|
|
878
|
+
var NovitaConversationalTask = class extends BaseConversationalTask {
|
|
879
|
+
constructor() {
|
|
880
|
+
super("novita", NOVITA_API_BASE_URL);
|
|
881
|
+
}
|
|
882
|
+
makeRoute() {
|
|
883
|
+
return "/v3/openai/chat/completions";
|
|
349
884
|
}
|
|
350
|
-
return params.baseUrl;
|
|
351
885
|
};
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
886
|
+
|
|
887
|
+
// src/providers/openai.ts
|
|
888
|
+
var OPENAI_API_BASE_URL = "https://api.openai.com";
|
|
889
|
+
var OpenAIConversationalTask = class extends BaseConversationalTask {
|
|
890
|
+
constructor() {
|
|
891
|
+
super("openai", OPENAI_API_BASE_URL, true);
|
|
892
|
+
}
|
|
357
893
|
};
|
|
358
894
|
|
|
359
895
|
// src/providers/replicate.ts
|
|
360
|
-
var
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
}
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
896
|
+
var ReplicateTask = class extends TaskProviderHelper {
|
|
897
|
+
constructor(url) {
|
|
898
|
+
super("replicate", url || "https://api.replicate.com");
|
|
899
|
+
}
|
|
900
|
+
makeRoute(params) {
|
|
901
|
+
if (params.model.includes(":")) {
|
|
902
|
+
return "v1/predictions";
|
|
903
|
+
}
|
|
904
|
+
return `v1/models/${params.model}/predictions`;
|
|
905
|
+
}
|
|
906
|
+
preparePayload(params) {
|
|
907
|
+
return {
|
|
908
|
+
input: {
|
|
909
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
910
|
+
...params.args.parameters,
|
|
911
|
+
prompt: params.args.inputs
|
|
912
|
+
},
|
|
913
|
+
version: params.model.includes(":") ? params.model.split(":")[1] : void 0
|
|
914
|
+
};
|
|
915
|
+
}
|
|
916
|
+
prepareHeaders(params, binary) {
|
|
917
|
+
const headers = { Authorization: `Bearer ${params.accessToken}`, Prefer: "wait" };
|
|
918
|
+
if (!binary) {
|
|
919
|
+
headers["Content-Type"] = "application/json";
|
|
920
|
+
}
|
|
921
|
+
return headers;
|
|
922
|
+
}
|
|
923
|
+
makeUrl(params) {
|
|
924
|
+
const baseUrl = this.makeBaseUrl(params);
|
|
925
|
+
if (params.model.includes(":")) {
|
|
926
|
+
return `${baseUrl}/v1/predictions`;
|
|
927
|
+
}
|
|
928
|
+
return `${baseUrl}/v1/models/${params.model}/predictions`;
|
|
929
|
+
}
|
|
369
930
|
};
|
|
370
|
-
var
|
|
371
|
-
|
|
931
|
+
var ReplicateTextToImageTask = class extends ReplicateTask {
|
|
932
|
+
async getResponse(res, url, headers, outputType) {
|
|
933
|
+
if (typeof res === "object" && "output" in res && Array.isArray(res.output) && res.output.length > 0 && typeof res.output[0] === "string") {
|
|
934
|
+
if (outputType === "url") {
|
|
935
|
+
return res.output[0];
|
|
936
|
+
}
|
|
937
|
+
const urlResponse = await fetch(res.output[0]);
|
|
938
|
+
return await urlResponse.blob();
|
|
939
|
+
}
|
|
940
|
+
throw new InferenceOutputError("Expected Replicate text-to-image response format");
|
|
941
|
+
}
|
|
372
942
|
};
|
|
373
|
-
var
|
|
374
|
-
|
|
375
|
-
|
|
943
|
+
var ReplicateTextToSpeechTask = class extends ReplicateTask {
|
|
944
|
+
preparePayload(params) {
|
|
945
|
+
const payload = super.preparePayload(params);
|
|
946
|
+
const input = payload["input"];
|
|
947
|
+
if (typeof input === "object" && input !== null && "prompt" in input) {
|
|
948
|
+
const inputObj = input;
|
|
949
|
+
inputObj["text"] = inputObj["prompt"];
|
|
950
|
+
delete inputObj["prompt"];
|
|
951
|
+
}
|
|
952
|
+
return payload;
|
|
953
|
+
}
|
|
954
|
+
async getResponse(response) {
|
|
955
|
+
if (response instanceof Blob) {
|
|
956
|
+
return response;
|
|
957
|
+
}
|
|
958
|
+
if (response && typeof response === "object") {
|
|
959
|
+
if ("output" in response) {
|
|
960
|
+
if (typeof response.output === "string") {
|
|
961
|
+
const urlResponse = await fetch(response.output);
|
|
962
|
+
return await urlResponse.blob();
|
|
963
|
+
} else if (Array.isArray(response.output)) {
|
|
964
|
+
const urlResponse = await fetch(response.output[0]);
|
|
965
|
+
return await urlResponse.blob();
|
|
966
|
+
}
|
|
967
|
+
}
|
|
968
|
+
}
|
|
969
|
+
throw new InferenceOutputError("Expected Blob or object with output");
|
|
376
970
|
}
|
|
377
|
-
return `${params.baseUrl}/v1/models/${params.model}/predictions`;
|
|
378
971
|
};
|
|
379
|
-
var
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
972
|
+
var ReplicateTextToVideoTask = class extends ReplicateTask {
|
|
973
|
+
async getResponse(response) {
|
|
974
|
+
if (typeof response === "object" && !!response && "output" in response && typeof response.output === "string" && isUrl(response.output)) {
|
|
975
|
+
const urlResponse = await fetch(response.output);
|
|
976
|
+
return await urlResponse.blob();
|
|
977
|
+
}
|
|
978
|
+
throw new InferenceOutputError("Expected { output: string }");
|
|
979
|
+
}
|
|
384
980
|
};
|
|
385
981
|
|
|
386
982
|
// src/providers/sambanova.ts
|
|
387
|
-
var
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
};
|
|
391
|
-
var makeBody11 = (params) => {
|
|
392
|
-
return {
|
|
393
|
-
...params.args,
|
|
394
|
-
...params.chatCompletion ? { model: params.model } : void 0
|
|
395
|
-
};
|
|
396
|
-
};
|
|
397
|
-
var makeHeaders11 = (params) => {
|
|
398
|
-
return { Authorization: `Bearer ${params.accessToken}` };
|
|
399
|
-
};
|
|
400
|
-
var makeUrl11 = (params) => {
|
|
401
|
-
if (params.chatCompletion) {
|
|
402
|
-
return `${params.baseUrl}/v1/chat/completions`;
|
|
983
|
+
var SambanovaConversationalTask = class extends BaseConversationalTask {
|
|
984
|
+
constructor() {
|
|
985
|
+
super("sambanova", "https://api.sambanova.ai");
|
|
403
986
|
}
|
|
404
|
-
return params.baseUrl;
|
|
405
|
-
};
|
|
406
|
-
var SAMBANOVA_CONFIG = {
|
|
407
|
-
makeBaseUrl: makeBaseUrl11,
|
|
408
|
-
makeBody: makeBody11,
|
|
409
|
-
makeHeaders: makeHeaders11,
|
|
410
|
-
makeUrl: makeUrl11
|
|
411
987
|
};
|
|
412
988
|
|
|
413
989
|
// src/providers/together.ts
|
|
414
990
|
var TOGETHER_API_BASE_URL = "https://api.together.xyz";
|
|
415
|
-
var
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
return {
|
|
420
|
-
...params.args,
|
|
421
|
-
model: params.model
|
|
422
|
-
};
|
|
423
|
-
};
|
|
424
|
-
var makeHeaders12 = (params) => {
|
|
425
|
-
return { Authorization: `Bearer ${params.accessToken}` };
|
|
991
|
+
var TogetherConversationalTask = class extends BaseConversationalTask {
|
|
992
|
+
constructor() {
|
|
993
|
+
super("together", TOGETHER_API_BASE_URL);
|
|
994
|
+
}
|
|
426
995
|
};
|
|
427
|
-
var
|
|
428
|
-
|
|
429
|
-
|
|
996
|
+
var TogetherTextGenerationTask = class extends BaseTextGenerationTask {
|
|
997
|
+
constructor() {
|
|
998
|
+
super("together", TOGETHER_API_BASE_URL);
|
|
430
999
|
}
|
|
431
|
-
|
|
432
|
-
return
|
|
1000
|
+
preparePayload(params) {
|
|
1001
|
+
return {
|
|
1002
|
+
model: params.model,
|
|
1003
|
+
...params.args,
|
|
1004
|
+
prompt: params.args.inputs
|
|
1005
|
+
};
|
|
433
1006
|
}
|
|
434
|
-
|
|
435
|
-
|
|
1007
|
+
async getResponse(response) {
|
|
1008
|
+
if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
|
|
1009
|
+
const completion = response.choices[0];
|
|
1010
|
+
return {
|
|
1011
|
+
generated_text: completion.text
|
|
1012
|
+
};
|
|
1013
|
+
}
|
|
1014
|
+
throw new InferenceOutputError("Expected Together text generation response format");
|
|
436
1015
|
}
|
|
437
|
-
return params.baseUrl;
|
|
438
1016
|
};
|
|
439
|
-
var
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
1017
|
+
var TogetherTextToImageTask = class extends TaskProviderHelper {
|
|
1018
|
+
constructor() {
|
|
1019
|
+
super("together", TOGETHER_API_BASE_URL);
|
|
1020
|
+
}
|
|
1021
|
+
makeRoute() {
|
|
1022
|
+
return "v1/images/generations";
|
|
1023
|
+
}
|
|
1024
|
+
preparePayload(params) {
|
|
1025
|
+
return {
|
|
1026
|
+
...omit(params.args, ["inputs", "parameters"]),
|
|
1027
|
+
...params.args.parameters,
|
|
1028
|
+
prompt: params.args.inputs,
|
|
1029
|
+
response_format: "base64",
|
|
1030
|
+
model: params.model
|
|
1031
|
+
};
|
|
1032
|
+
}
|
|
1033
|
+
async getResponse(response, outputType) {
|
|
1034
|
+
if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
|
|
1035
|
+
const base64Data = response.data[0].b64_json;
|
|
1036
|
+
if (outputType === "url") {
|
|
1037
|
+
return `data:image/jpeg;base64,${base64Data}`;
|
|
1038
|
+
}
|
|
1039
|
+
return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
|
|
1040
|
+
}
|
|
1041
|
+
throw new InferenceOutputError("Expected Together text-to-image response format");
|
|
1042
|
+
}
|
|
444
1043
|
};
|
|
445
1044
|
|
|
446
|
-
// src/
|
|
447
|
-
var
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
}
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
1045
|
+
// src/lib/getProviderHelper.ts
|
|
1046
|
+
var PROVIDERS = {
|
|
1047
|
+
"black-forest-labs": {
|
|
1048
|
+
"text-to-image": new BlackForestLabsTextToImageTask()
|
|
1049
|
+
},
|
|
1050
|
+
cerebras: {
|
|
1051
|
+
conversational: new CerebrasConversationalTask()
|
|
1052
|
+
},
|
|
1053
|
+
cohere: {
|
|
1054
|
+
conversational: new CohereConversationalTask()
|
|
1055
|
+
},
|
|
1056
|
+
"fal-ai": {
|
|
1057
|
+
"text-to-image": new FalAITextToImageTask(),
|
|
1058
|
+
"text-to-speech": new FalAITextToSpeechTask(),
|
|
1059
|
+
"text-to-video": new FalAITextToVideoTask(),
|
|
1060
|
+
"automatic-speech-recognition": new FalAIAutomaticSpeechRecognitionTask()
|
|
1061
|
+
},
|
|
1062
|
+
"hf-inference": {
|
|
1063
|
+
"text-to-image": new HFInferenceTextToImageTask(),
|
|
1064
|
+
conversational: new HFInferenceConversationalTask(),
|
|
1065
|
+
"text-generation": new HFInferenceTextGenerationTask(),
|
|
1066
|
+
"text-classification": new HFInferenceTextClassificationTask(),
|
|
1067
|
+
"question-answering": new HFInferenceQuestionAnsweringTask(),
|
|
1068
|
+
"audio-classification": new HFInferenceAudioClassificationTask(),
|
|
1069
|
+
"automatic-speech-recognition": new HFInferenceAutomaticSpeechRecognitionTask(),
|
|
1070
|
+
"fill-mask": new HFInferenceFillMaskTask(),
|
|
1071
|
+
"feature-extraction": new HFInferenceFeatureExtractionTask(),
|
|
1072
|
+
"image-classification": new HFInferenceImageClassificationTask(),
|
|
1073
|
+
"image-segmentation": new HFInferenceImageSegmentationTask(),
|
|
1074
|
+
"document-question-answering": new HFInferenceDocumentQuestionAnsweringTask(),
|
|
1075
|
+
"image-to-text": new HFInferenceImageToTextTask(),
|
|
1076
|
+
"object-detection": new HFInferenceObjectDetectionTask(),
|
|
1077
|
+
"audio-to-audio": new HFInferenceAudioToAudioTask(),
|
|
1078
|
+
"zero-shot-image-classification": new HFInferenceZeroShotImageClassificationTask(),
|
|
1079
|
+
"zero-shot-classification": new HFInferenceZeroShotClassificationTask(),
|
|
1080
|
+
"image-to-image": new HFInferenceImageToImageTask(),
|
|
1081
|
+
"sentence-similarity": new HFInferenceSentenceSimilarityTask(),
|
|
1082
|
+
"table-question-answering": new HFInferenceTableQuestionAnsweringTask(),
|
|
1083
|
+
"tabular-classification": new HFInferenceTabularClassificationTask(),
|
|
1084
|
+
"text-to-speech": new HFInferenceTextToSpeechTask(),
|
|
1085
|
+
"token-classification": new HFInferenceTokenClassificationTask(),
|
|
1086
|
+
translation: new HFInferenceTranslationTask(),
|
|
1087
|
+
summarization: new HFInferenceSummarizationTask(),
|
|
1088
|
+
"visual-question-answering": new HFInferenceVisualQuestionAnsweringTask(),
|
|
1089
|
+
"tabular-regression": new HFInferenceTabularRegressionTask(),
|
|
1090
|
+
"text-to-audio": new HFInferenceTextToAudioTask()
|
|
1091
|
+
},
|
|
1092
|
+
"fireworks-ai": {
|
|
1093
|
+
conversational: new FireworksConversationalTask()
|
|
1094
|
+
},
|
|
1095
|
+
hyperbolic: {
|
|
1096
|
+
"text-to-image": new HyperbolicTextToImageTask(),
|
|
1097
|
+
conversational: new HyperbolicConversationalTask(),
|
|
1098
|
+
"text-generation": new HyperbolicTextGenerationTask()
|
|
1099
|
+
},
|
|
1100
|
+
nebius: {
|
|
1101
|
+
"text-to-image": new NebiusTextToImageTask(),
|
|
1102
|
+
conversational: new NebiusConversationalTask(),
|
|
1103
|
+
"text-generation": new NebiusTextGenerationTask()
|
|
1104
|
+
},
|
|
1105
|
+
novita: {
|
|
1106
|
+
conversational: new NovitaConversationalTask(),
|
|
1107
|
+
"text-generation": new NovitaTextGenerationTask()
|
|
1108
|
+
},
|
|
1109
|
+
openai: {
|
|
1110
|
+
conversational: new OpenAIConversationalTask()
|
|
1111
|
+
},
|
|
1112
|
+
replicate: {
|
|
1113
|
+
"text-to-image": new ReplicateTextToImageTask(),
|
|
1114
|
+
"text-to-speech": new ReplicateTextToSpeechTask(),
|
|
1115
|
+
"text-to-video": new ReplicateTextToVideoTask()
|
|
1116
|
+
},
|
|
1117
|
+
sambanova: {
|
|
1118
|
+
conversational: new SambanovaConversationalTask()
|
|
1119
|
+
},
|
|
1120
|
+
together: {
|
|
1121
|
+
"text-to-image": new TogetherTextToImageTask(),
|
|
1122
|
+
conversational: new TogetherConversationalTask(),
|
|
1123
|
+
"text-generation": new TogetherTextGenerationTask()
|
|
454
1124
|
}
|
|
455
|
-
return {
|
|
456
|
-
...params.args,
|
|
457
|
-
model: params.model
|
|
458
|
-
};
|
|
459
1125
|
};
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
throw new Error("OpenAI only supports chat completions.");
|
|
1126
|
+
function getProviderHelper(provider, task) {
|
|
1127
|
+
if (provider === "hf-inference") {
|
|
1128
|
+
if (!task) {
|
|
1129
|
+
return new HFInferenceTask();
|
|
1130
|
+
}
|
|
466
1131
|
}
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
1132
|
+
if (!task) {
|
|
1133
|
+
throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'");
|
|
1134
|
+
}
|
|
1135
|
+
if (!(provider in PROVIDERS)) {
|
|
1136
|
+
throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`);
|
|
1137
|
+
}
|
|
1138
|
+
const providerTasks = PROVIDERS[provider];
|
|
1139
|
+
if (!providerTasks || !(task in providerTasks)) {
|
|
1140
|
+
throw new Error(
|
|
1141
|
+
`Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(providerTasks ?? {})}`
|
|
1142
|
+
);
|
|
1143
|
+
}
|
|
1144
|
+
return providerTasks[task];
|
|
1145
|
+
}
|
|
476
1146
|
|
|
477
1147
|
// package.json
|
|
478
1148
|
var name = "@huggingface/inference";
|
|
479
|
-
var version = "3.
|
|
1149
|
+
var version = "3.8.0";
|
|
480
1150
|
|
|
481
1151
|
// src/providers/consts.ts
|
|
482
|
-
var
|
|
1152
|
+
var HARDCODED_MODEL_INFERENCE_MAPPING = {
|
|
483
1153
|
/**
|
|
484
1154
|
* "HF model ID" => "Model ID on Inference Provider's side"
|
|
485
1155
|
*
|
|
@@ -501,106 +1171,127 @@ var HARDCODED_MODEL_ID_MAPPING = {
|
|
|
501
1171
|
together: {}
|
|
502
1172
|
};
|
|
503
1173
|
|
|
504
|
-
// src/lib/
|
|
1174
|
+
// src/lib/getInferenceProviderMapping.ts
|
|
505
1175
|
var inferenceProviderMappingCache = /* @__PURE__ */ new Map();
|
|
506
|
-
async function
|
|
507
|
-
if (params.provider
|
|
508
|
-
return params.
|
|
509
|
-
}
|
|
510
|
-
if (!options.task) {
|
|
511
|
-
throw new Error("task must be specified when using a third-party provider");
|
|
512
|
-
}
|
|
513
|
-
const task = options.task === "text-generation" && options.chatCompletion ? "conversational" : options.task;
|
|
514
|
-
if (HARDCODED_MODEL_ID_MAPPING[params.provider]?.[params.model]) {
|
|
515
|
-
return HARDCODED_MODEL_ID_MAPPING[params.provider][params.model];
|
|
1176
|
+
async function getInferenceProviderMapping(params, options) {
|
|
1177
|
+
if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
|
|
1178
|
+
return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
|
|
516
1179
|
}
|
|
517
1180
|
let inferenceProviderMapping;
|
|
518
|
-
if (inferenceProviderMappingCache.has(params.
|
|
519
|
-
inferenceProviderMapping = inferenceProviderMappingCache.get(params.
|
|
1181
|
+
if (inferenceProviderMappingCache.has(params.modelId)) {
|
|
1182
|
+
inferenceProviderMapping = inferenceProviderMappingCache.get(params.modelId);
|
|
520
1183
|
} else {
|
|
521
|
-
|
|
522
|
-
`${HF_HUB_URL}/api/models/${params.
|
|
1184
|
+
const resp = await (options?.fetch ?? fetch)(
|
|
1185
|
+
`${HF_HUB_URL}/api/models/${params.modelId}?expand[]=inferenceProviderMapping`,
|
|
523
1186
|
{
|
|
524
|
-
headers:
|
|
1187
|
+
headers: params.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${params.accessToken}` } : {}
|
|
525
1188
|
}
|
|
526
|
-
)
|
|
1189
|
+
);
|
|
1190
|
+
if (resp.status === 404) {
|
|
1191
|
+
throw new Error(`Model ${params.modelId} does not exist`);
|
|
1192
|
+
}
|
|
1193
|
+
inferenceProviderMapping = await resp.json().then((json) => json.inferenceProviderMapping).catch(() => null);
|
|
527
1194
|
}
|
|
528
1195
|
if (!inferenceProviderMapping) {
|
|
529
|
-
throw new Error(`We have not been able to find inference provider information for model ${params.
|
|
1196
|
+
throw new Error(`We have not been able to find inference provider information for model ${params.modelId}.`);
|
|
530
1197
|
}
|
|
531
1198
|
const providerMapping = inferenceProviderMapping[params.provider];
|
|
532
1199
|
if (providerMapping) {
|
|
533
|
-
|
|
1200
|
+
const equivalentTasks = params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task) ? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS : [params.task];
|
|
1201
|
+
if (!typedInclude(equivalentTasks, providerMapping.task)) {
|
|
534
1202
|
throw new Error(
|
|
535
|
-
`Model ${params.
|
|
1203
|
+
`Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
|
|
536
1204
|
);
|
|
537
1205
|
}
|
|
538
1206
|
if (providerMapping.status === "staging") {
|
|
539
1207
|
console.warn(
|
|
540
|
-
`Model ${params.
|
|
1208
|
+
`Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
|
|
541
1209
|
);
|
|
542
1210
|
}
|
|
543
|
-
|
|
1211
|
+
if (providerMapping.adapter === "lora") {
|
|
1212
|
+
const treeResp = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${params.modelId}/tree/main`);
|
|
1213
|
+
if (!treeResp.ok) {
|
|
1214
|
+
throw new Error(`Unable to fetch the model tree for ${params.modelId}.`);
|
|
1215
|
+
}
|
|
1216
|
+
const tree = await treeResp.json();
|
|
1217
|
+
const adapterWeightsPath = tree.find(({ type, path }) => type === "file" && path.endsWith(".safetensors"))?.path;
|
|
1218
|
+
if (!adapterWeightsPath) {
|
|
1219
|
+
throw new Error(`No .safetensors file found in the model tree for ${params.modelId}.`);
|
|
1220
|
+
}
|
|
1221
|
+
return {
|
|
1222
|
+
...providerMapping,
|
|
1223
|
+
hfModelId: params.modelId,
|
|
1224
|
+
adapterWeightsPath
|
|
1225
|
+
};
|
|
1226
|
+
}
|
|
1227
|
+
return { ...providerMapping, hfModelId: params.modelId };
|
|
544
1228
|
}
|
|
545
|
-
|
|
1229
|
+
return null;
|
|
546
1230
|
}
|
|
547
1231
|
|
|
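
On success, `getInferenceProviderMapping` returns the Hub's per-provider entry augmented with `hfModelId` (plus `adapterWeightsPath` for LoRA adapters); when the provider has no entry for the model it returns `null`. The shape, inferred from the reads and writes above (field optionality is our reading, not a published type):

```ts
// Shape inferred from the code above; optionality is our reading of it,
// not a published type.
interface InferenceProviderModelMapping {
  providerId: string;          // model ID on the provider's side
  hfModelId: string;           // the original Hub model ID
  status: "live" | "staging";  // "staging" triggers the console warning above
  task: string;                // e.g. "conversational", "text-to-image"
  adapter?: "lora";            // set for adapter repos
  adapterWeightsPath?: string; // resolved from the repo tree when adapter === "lora"
}
```
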
 // src/lib/makeRequestOptions.ts
-var HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
 var tasks = null;
-
-  "black-forest-labs": BLACK_FOREST_LABS_CONFIG,
-  cerebras: CEREBRAS_CONFIG,
-  cohere: COHERE_CONFIG,
-  "fal-ai": FAL_AI_CONFIG,
-  "fireworks-ai": FIREWORKS_AI_CONFIG,
-  "hf-inference": HF_INFERENCE_CONFIG,
-  hyperbolic: HYPERBOLIC_CONFIG,
-  openai: OPENAI_CONFIG,
-  nebius: NEBIUS_CONFIG,
-  novita: NOVITA_CONFIG,
-  replicate: REPLICATE_CONFIG,
-  sambanova: SAMBANOVA_CONFIG,
-  together: TOGETHER_CONFIG
-};
-async function makeRequestOptions(args, options) {
+async function makeRequestOptions(args, providerHelper, options) {
   const { provider: maybeProvider, model: maybeModel } = args;
   const provider = maybeProvider ?? "hf-inference";
-  const
-  const { task, chatCompletion: chatCompletion2 } = options ?? {};
+  const { task } = options ?? {};
   if (args.endpointUrl && provider !== "hf-inference") {
     throw new Error(`Cannot use endpointUrl with a third-party provider.`);
   }
   if (maybeModel && isUrl(maybeModel)) {
     throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
   }
+  if (args.endpointUrl) {
+    return makeRequestOptionsFromResolvedModel(
+      maybeModel ?? args.endpointUrl,
+      providerHelper,
+      args,
+      void 0,
+      options
+    );
+  }
   if (!maybeModel && !task) {
     throw new Error("No model provided, and no task has been specified.");
   }
-
-
-  }
-  if (providerConfig.clientSideRoutingOnly && !maybeModel) {
+  const hfModel = maybeModel ?? await loadDefaultModel(task);
+  if (providerHelper.clientSideRoutingOnly && !maybeModel) {
     throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
   }
-  const
-  const resolvedModel = providerConfig.clientSideRoutingOnly ? (
+  const inferenceProviderMapping = providerHelper.clientSideRoutingOnly ? {
     // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
-    removeProviderPrefix(maybeModel, provider)
-
-
-
-
-
-
+    providerId: removeProviderPrefix(maybeModel, provider),
+    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+    hfModelId: maybeModel,
+    status: "live",
+    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+    task
+  } : await getInferenceProviderMapping(
+    {
+      modelId: hfModel,
+      // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+      task,
+      provider,
+      accessToken: args.accessToken
+    },
+    { fetch: options?.fetch }
+  );
+  if (!inferenceProviderMapping) {
+    throw new Error(`We have not been able to find inference provider information for model ${hfModel}.`);
+  }
+  return makeRequestOptionsFromResolvedModel(
+    inferenceProviderMapping.providerId,
+    providerHelper,
+    args,
+    inferenceProviderMapping,
+    options
+  );
 }
-function makeRequestOptionsFromResolvedModel(resolvedModel, args, options) {
+function makeRequestOptionsFromResolvedModel(resolvedModel, providerHelper, args, mapping, options) {
   const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
   const provider = maybeProvider ?? "hf-inference";
-  const
-  const { includeCredentials, task, chatCompletion: chatCompletion2, signal, billTo } = options ?? {};
+  const { includeCredentials, task, signal, billTo } = options ?? {};
   const authMethod = (() => {
-    if (
+    if (providerHelper.clientSideRoutingOnly) {
       if (accessToken && accessToken.startsWith("hf_")) {
         throw new Error(`Provider ${provider} is closed-source and does not support HF tokens.`);
       }
@@ -614,35 +1305,31 @@ function makeRequestOptionsFromResolvedModel(resolvedModel, args, options) {
     }
     return "none";
   })();
-  const
+  const modelId = endpointUrl ?? resolvedModel;
+  const url = providerHelper.makeUrl({
     authMethod,
-
-    model: resolvedModel,
-    chatCompletion: chatCompletion2,
+    model: modelId,
     task
   });
-  const
-
-
-
-
+  const headers = providerHelper.prepareHeaders(
+    {
+      accessToken,
+      authMethod
+    },
+    "data" in args && !!args.data
+  );
   if (billTo) {
     headers[HF_HEADER_X_BILL_TO] = billTo;
   }
-  if (!binary) {
-    headers["Content-Type"] = "application/json";
-  }
   const ownUserAgent = `${name}/${version}`;
   const userAgent = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : void 0].filter((x) => x !== void 0).join(" ");
   headers["User-Agent"] = userAgent;
-  const body =
-
-
-
-
-
-  })
-  );
+  const body = providerHelper.makeBody({
+    args: remainingArgs,
+    model: resolvedModel,
+    task,
+    mapping
+  });
   let credentials;
   if (typeof includeCredentials === "string") {
     credentials = includeCredentials;
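
`makeRequestOptionsFromResolvedModel` now delegates every provider-specific decision to the helper: `makeUrl` builds the endpoint, `prepareHeaders` handles auth and the binary-body case that used to be special-cased inline, and `makeBody` serializes the payload. A condensed sketch of that flow; `RequestHelper` is our own reduction of the interface, and the literal values are illustrative:

```ts
// Condensed sketch of the assembly flow above; `RequestHelper` is our own
// reduction of the TaskProviderHelper surface, and all literal values
// here are illustrative.
interface RequestHelper {
  makeUrl(p: { authMethod: string; model: string; task?: string }): string;
  prepareHeaders(p: { accessToken?: string; authMethod: string }, binary: boolean): Record<string, string>;
  makeBody(p: { args: Record<string, unknown>; model: string; task?: string }): unknown;
}

function buildRequest(helper: RequestHelper) {
  const url = helper.makeUrl({ authMethod: "none", model: "provider/model-id", task: "text-to-image" });
  const headers = helper.prepareHeaders({ accessToken: undefined, authMethod: "none" }, false);
  const body = helper.makeBody({ args: { inputs: "a cat" }, model: "provider/model-id", task: "text-to-image" });
  return { url, headers, body };
}
```
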
@@ -782,12 +1469,12 @@ function newMessage() {
 }
 
 // src/utils/request.ts
-async function innerRequest(args, options) {
-  const { url, info } = await makeRequestOptions(args, options);
+async function innerRequest(args, providerHelper, options) {
+  const { url, info } = await makeRequestOptions(args, providerHelper, options);
   const response = await (options?.fetch ?? fetch)(url, info);
   const requestContext = { url, info };
   if (options?.retry_on_error !== false && response.status === 503) {
-    return innerRequest(args, options);
+    return innerRequest(args, providerHelper, options);
   }
   if (!response.ok) {
     const contentType = response.headers.get("Content-Type");
@@ -814,11 +1501,11 @@ async function innerRequest(args, options) {
   const blob = await response.blob();
   return { data: blob, requestContext };
 }
-async function* innerStreamingRequest(args, options) {
-  const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
+async function* innerStreamingRequest(args, providerHelper, options) {
+  const { url, info } = await makeRequestOptions({ ...args, stream: true }, providerHelper, options);
   const response = await (options?.fetch ?? fetch)(url, info);
   if (options?.retry_on_error !== false && response.status === 503) {
-    return yield* innerStreamingRequest(args, options);
+    return yield* innerStreamingRequest(args, providerHelper, options);
   }
   if (!response.ok) {
     if (response.headers.get("Content-Type")?.startsWith("application/json")) {
@@ -892,7 +1579,8 @@ async function request(args, options) {
   console.warn(
     "The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
   );
-  const
+  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
+  const result = await innerRequest(args, providerHelper, options);
   return result.data;
 }
 
@@ -901,31 +1589,8 @@ async function* streamingRequest(args, options) {
   console.warn(
     "The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
   );
-
-
-
-// src/utils/pick.ts
-function pick(o, props) {
-  return Object.assign(
-    {},
-    ...props.map((prop) => {
-      if (o[prop] !== void 0) {
-        return { [prop]: o[prop] };
-      }
-    })
-  );
-}
-
-// src/utils/typedInclude.ts
-function typedInclude(arr, v) {
-  return arr.includes(v);
-}
-
-// src/utils/omit.ts
-function omit(o, props) {
-  const propsArr = Array.isArray(props) ? props : [props];
-  const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
-  return pick(o, letsKeep);
+  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
+  yield* innerStreamingRequest(args, providerHelper, options);
 }
 
 // src/tasks/audio/utils.ts
@@ -938,16 +1603,24 @@ function preparePayload(args) {
 
 // src/tasks/audio/audioClassification.ts
 async function audioClassification(args, options) {
+  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
   const payload = preparePayload(args);
-  const { data: res } = await innerRequest(payload, {
+  const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
     task: "audio-classification"
   });
-
-
-
-
-
+  return providerHelper.getResponse(res);
+}
+
+// src/tasks/audio/audioToAudio.ts
+async function audioToAudio(args, options) {
+  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
+  const payload = preparePayload(args);
+  const { data: res } = await innerRequest(payload, providerHelper, {
+    ...options,
+    task: "audio-to-audio"
+  });
+  return providerHelper.getResponse(res);
 }
 
 // src/utils/base64FromBytes.ts
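
From here on, every task function follows the same three-step template visible in `audioClassification` above: resolve the helper for `(provider, task)`, fire the request through `innerRequest`, then let the helper validate and shape the output. Written out once as a sketch (the `declare`s mirror internal signatures in this diff; they are not public exports):

```ts
// The recurring template, written out once. The declarations mirror the
// internal signatures in this diff; they are not public exports.
declare function getProviderHelper(
  provider: string,
  task: string
): { getResponse(res: unknown): unknown };
declare function innerRequest(
  args: unknown,
  providerHelper: unknown,
  options: { task: string }
): Promise<{ data: unknown }>;

async function runTask(args: { provider?: string }, task: string): Promise<unknown> {
  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", task);
  const { data: res } = await innerRequest(args, providerHelper, { task });
  return providerHelper.getResponse(res); // helper-specific validation/shaping
}
```
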
@@ -965,8 +1638,9 @@ function base64FromBytes(arr) {
 
 // src/tasks/audio/automaticSpeechRecognition.ts
 async function automaticSpeechRecognition(args, options) {
+  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
   const payload = await buildPayload(args);
-  const { data: res } = await innerRequest(payload, {
+  const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
     task: "automatic-speech-recognition"
   });
@@ -974,9 +1648,8 @@ async function automaticSpeechRecognition(args, options) {
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected {text: string}");
   }
-  return res;
+  return providerHelper.getResponse(res);
 }
-var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
 async function buildPayload(args) {
   if (args.provider === "fal-ai") {
     const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : void 0;
@@ -1005,215 +1678,45 @@ async function buildPayload(args) {
 
 // src/tasks/audio/textToSpeech.ts
 async function textToSpeech(args, options) {
-  const
-
-
-    text: args.inputs
-  } : args;
-  const { data: res } = await innerRequest(payload, {
+  const provider = args.provider ?? "hf-inference";
+  const providerHelper = getProviderHelper(provider, "text-to-speech");
+  const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
     task: "text-to-speech"
   });
-
-    return res;
-  }
-  if (res && typeof res === "object") {
-    if ("output" in res) {
-      if (typeof res.output === "string") {
-        const urlResponse = await fetch(res.output);
-        const blob = await urlResponse.blob();
-        return blob;
-      } else if (Array.isArray(res.output)) {
-        const urlResponse = await fetch(res.output[0]);
-        const blob = await urlResponse.blob();
-        return blob;
-      }
-    }
-  }
-  throw new InferenceOutputError("Expected Blob or object with output");
-}
-
-// src/tasks/audio/audioToAudio.ts
-async function audioToAudio(args, options) {
-  const payload = preparePayload(args);
-  const { data: res } = await innerRequest(payload, {
-    ...options,
-    task: "audio-to-audio"
-  });
-  return validateOutput(res);
-}
-function validateOutput(output) {
-  if (!Array.isArray(output)) {
-    throw new InferenceOutputError("Expected Array");
-  }
-  if (!output.every((elem) => {
-    return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
-  })) {
-    throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
-  }
-  return output;
-}
-
-// src/tasks/cv/utils.ts
-function preparePayload2(args) {
-  return "data" in args ? args : { ...omit(args, "inputs"), data: args.inputs };
-}
-
-// src/tasks/cv/imageClassification.ts
-async function imageClassification(args, options) {
-  const payload = preparePayload2(args);
-  const { data: res } = await innerRequest(payload, {
-    ...options,
-    task: "image-classification"
-  });
-  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
-  if (!isValidOutput) {
-    throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
-  }
-  return res;
-}
-
-// src/tasks/cv/imageSegmentation.ts
-async function imageSegmentation(args, options) {
-  const payload = preparePayload2(args);
-  const { data: res } = await innerRequest(payload, {
-    ...options,
-    task: "image-segmentation"
-  });
-  const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
-  if (!isValidOutput) {
-    throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
-  }
-  return res;
+  return providerHelper.getResponse(res);
 }
-
-// src/tasks/cv/
-
-
-  const { data: res } = await innerRequest(payload, {
-    ...options,
-    task: "image-to-text"
-  });
-  if (typeof res?.[0]?.generated_text !== "string") {
-    throw new InferenceOutputError("Expected {generated_text: string}");
-  }
-  return res?.[0];
+
+// src/tasks/cv/utils.ts
+function preparePayload2(args) {
+  return "data" in args ? args : { ...omit(args, "inputs"), data: args.inputs };
 }
 
-// src/tasks/cv/
-async function
+// src/tasks/cv/imageClassification.ts
+async function imageClassification(args, options) {
+  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
   const payload = preparePayload2(args);
-  const { data: res } = await innerRequest(payload, {
+  const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
-    task: "
+    task: "image-classification"
   });
-
-    (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
-  );
-  if (!isValidOutput) {
-    throw new InferenceOutputError(
-      "Expected Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>"
-    );
-  }
-  return res;
+  return providerHelper.getResponse(res);
 }
 
-// src/tasks/cv/
-function
-
-
-
-    case "nebius":
-      return { response_format: "b64_json" };
-    case "replicate":
-      return void 0;
-    case "together":
-      return { response_format: "base64" };
-    default:
-      return void 0;
-  }
-}
-async function textToImage(args, options) {
-  const payload = !args.provider || args.provider === "hf-inference" || args.provider === "sambanova" ? args : {
-    ...omit(args, ["inputs", "parameters"]),
-    ...args.parameters,
-    ...getResponseFormatArg(args.provider),
-    prompt: args.inputs
-  };
-  const { data: res } = await innerRequest(payload, {
+// src/tasks/cv/imageSegmentation.ts
+async function imageSegmentation(args, options) {
+  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
+  const payload = preparePayload2(args);
+  const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
-    task: "
+    task: "image-segmentation"
   });
-
-  if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") {
-    return await pollBflResponse(res.polling_url, options?.outputType);
-  }
-  if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
-    if (options?.outputType === "url") {
-      return res.images[0].url;
-    } else {
-      const image = await fetch(res.images[0].url);
-      return await image.blob();
-    }
-  }
-  if (args.provider === "hyperbolic" && "images" in res && Array.isArray(res.images) && res.images[0] && typeof res.images[0].image === "string") {
-    if (options?.outputType === "url") {
-      return `data:image/jpeg;base64,${res.images[0].image}`;
-    }
-    const base64Response = await fetch(`data:image/jpeg;base64,${res.images[0].image}`);
-    return await base64Response.blob();
-  }
-  if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
-    const base64Data = res.data[0].b64_json;
-    if (options?.outputType === "url") {
-      return `data:image/jpeg;base64,${base64Data}`;
-    }
-    const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
-    return await base64Response.blob();
-  }
-  if ("output" in res && Array.isArray(res.output)) {
-    if (options?.outputType === "url") {
-      return res.output[0];
-    }
-    const urlResponse = await fetch(res.output[0]);
-    const blob = await urlResponse.blob();
-    return blob;
-  }
-  }
-  const isValidOutput = res && res instanceof Blob;
-  if (!isValidOutput) {
-    throw new InferenceOutputError("Expected Blob");
-  }
-  if (options?.outputType === "url") {
-    const b64 = await res.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
-    return `data:image/jpeg;base64,${b64}`;
-  }
-  return res;
-}
-async function pollBflResponse(url, outputType) {
-  const urlObj = new URL(url);
-  for (let step = 0; step < 5; step++) {
-    await delay(1e3);
-    console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
-    urlObj.searchParams.set("attempt", step.toString(10));
-    const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
-    if (!resp.ok) {
-      throw new InferenceOutputError("Failed to fetch result from black forest labs API");
-    }
-    const payload = await resp.json();
-    if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
-      if (outputType === "url") {
-        return payload.result.sample;
-      }
-      const image = await fetch(payload.result.sample);
-      return await image.blob();
-    }
-  }
-  throw new InferenceOutputError("Failed to fetch result from black forest labs API");
+  return providerHelper.getResponse(res);
 }
 
 // src/tasks/cv/imageToImage.ts
 async function imageToImage(args, options) {
+  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
   let reqArgs;
   if (!args.parameters) {
     reqArgs = {
@@ -1229,15 +1732,61 @@ async function imageToImage(args, options) {
     )
   };
   }
-  const { data: res } = await innerRequest(reqArgs, {
+  const { data: res } = await innerRequest(reqArgs, providerHelper, {
     ...options,
     task: "image-to-image"
   });
-
-
-
-
-
+  return providerHelper.getResponse(res);
+}
+
+// src/tasks/cv/imageToText.ts
+async function imageToText(args, options) {
+  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
+  const payload = preparePayload2(args);
+  const { data: res } = await innerRequest(payload, providerHelper, {
+    ...options,
+    task: "image-to-text"
+  });
+  return providerHelper.getResponse(res[0]);
+}
+
+// src/tasks/cv/objectDetection.ts
+async function objectDetection(args, options) {
+  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
+  const payload = preparePayload2(args);
+  const { data: res } = await innerRequest(payload, providerHelper, {
+    ...options,
+    task: "object-detection"
+  });
+  return providerHelper.getResponse(res);
+}
+
+// src/tasks/cv/textToImage.ts
+async function textToImage(args, options) {
+  const provider = args.provider ?? "hf-inference";
+  const providerHelper = getProviderHelper(provider, "text-to-image");
+  const { data: res } = await innerRequest(args, providerHelper, {
+    ...options,
+    task: "text-to-image"
+  });
+  const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "text-to-image" });
+  return providerHelper.getResponse(res, url, info.headers, options?.outputType);
+}
+
+// src/tasks/cv/textToVideo.ts
+async function textToVideo(args, options) {
+  const provider = args.provider ?? "hf-inference";
+  const providerHelper = getProviderHelper(provider, "text-to-video");
+  const { data: response } = await innerRequest(
+    args,
+    providerHelper,
+    {
+      ...options,
+      task: "text-to-video"
+    }
+  );
+  const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "text-to-video" });
+  return providerHelper.getResponse(response, url, info.headers);
 }
 
 // src/tasks/cv/zeroShotImageClassification.ts
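
With the rewrite, `textToImage` no longer switches on the provider inline; it forwards `options?.outputType` to the helper's `getResponse`, so the caller picks between a URL string and a `Blob`. A hedged usage sketch (the model, provider, and token values are placeholders):

```ts
// Usage sketch based on the code paths above; the model, provider, and
// token values are placeholders.
import { textToImage } from "@huggingface/inference";

const imageUrl = await textToImage(
  {
    accessToken: "hf_...", // your Hugging Face token
    provider: "together",
    model: "black-forest-labs/FLUX.1-schnell",
    inputs: "an astronaut riding a horse",
  },
  { outputType: "url" } // omit to receive a Blob instead
);
```
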
@@ -1263,231 +1812,126 @@ async function preparePayload3(args) {
   }
 }
 async function zeroShotImageClassification(args, options) {
+  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
   const payload = await preparePayload3(args);
-  const { data: res } = await innerRequest(payload, {
+  const { data: res } = await innerRequest(payload, providerHelper, {
     ...options,
     task: "zero-shot-image-classification"
   });
-
-  if (!isValidOutput) {
-    throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
-  }
-  return res;
+  return providerHelper.getResponse(res);
 }
 
-// src/tasks/
-
-
-
-    throw new Error(
-      `textToVideo inference is only supported for the following providers: ${SUPPORTED_PROVIDERS.join(", ")}`
-    );
-  }
-  const payload = args.provider === "fal-ai" || args.provider === "replicate" || args.provider === "novita" ? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs } : args;
-  const { data, requestContext } = await innerRequest(payload, {
+// src/tasks/nlp/chatCompletion.ts
+async function chatCompletion(args, options) {
+  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
+  const { data: response } = await innerRequest(args, providerHelper, {
     ...options,
-    task: "
+    task: "conversational"
+  });
+  return providerHelper.getResponse(response);
+}
+
+// src/tasks/nlp/chatCompletionStream.ts
+async function* chatCompletionStream(args, options) {
+  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
+  yield* innerStreamingRequest(args, providerHelper, {
+    ...options,
+    task: "conversational"
   });
-  if (args.provider === "fal-ai") {
-    return await pollFalResponse(
-      data,
-      requestContext.url,
-      requestContext.info.headers
-    );
-  } else if (args.provider === "novita") {
-    const isValidOutput = typeof data === "object" && !!data && "video" in data && typeof data.video === "object" && !!data.video && "video_url" in data.video && typeof data.video.video_url === "string" && isUrl(data.video.video_url);
-    if (!isValidOutput) {
-      throw new InferenceOutputError("Expected { video: { video_url: string } }");
-    }
-    const urlResponse = await fetch(data.video.video_url);
-    return await urlResponse.blob();
-  } else {
-    const isValidOutput = typeof data === "object" && !!data && "output" in data && typeof data.output === "string" && isUrl(data.output);
-    if (!isValidOutput) {
-      throw new InferenceOutputError("Expected { output: string }");
-    }
-    const urlResponse = await fetch(data.output);
-    return await urlResponse.blob();
-  }
 }
 
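
`chatCompletion` and `chatCompletionStream` now resolve a dedicated "conversational" helper instead of piggybacking on "text-generation" with a `chatCompletion: true` flag. Usage is unchanged at the call site; a sketch (token and model values are placeholders):

```ts
// Usage sketch; token and model values are placeholders.
import { chatCompletion } from "@huggingface/inference";

const out = await chatCompletion({
  accessToken: "hf_...",
  provider: "sambanova",
  model: "meta-llama/Llama-3.1-8B-Instruct",
  messages: [{ role: "user", content: "Hello!" }],
});
console.log(out.choices[0].message);
```
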
1314
1843
|
// src/tasks/nlp/featureExtraction.ts
|
|
1315
1844
|
async function featureExtraction(args, options) {
|
|
1316
|
-
const
|
|
1845
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "feature-extraction");
|
|
1846
|
+
const { data: res } = await innerRequest(args, providerHelper, {
|
|
1317
1847
|
...options,
|
|
1318
1848
|
task: "feature-extraction"
|
|
1319
1849
|
});
|
|
1320
|
-
|
|
1321
|
-
const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
|
|
1322
|
-
if (curDepth > maxDepth)
|
|
1323
|
-
return false;
|
|
1324
|
-
if (arr.every((x) => Array.isArray(x))) {
|
|
1325
|
-
return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
|
|
1326
|
-
} else {
|
|
1327
|
-
return arr.every((x) => typeof x === "number");
|
|
1328
|
-
}
|
|
1329
|
-
};
|
|
1330
|
-
isValidOutput = Array.isArray(res) && isNumArrayRec(res, 3, 0);
|
|
1331
|
-
if (!isValidOutput) {
|
|
1332
|
-
throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
|
|
1333
|
-
}
|
|
1334
|
-
return res;
|
|
1850
|
+
return providerHelper.getResponse(res);
|
|
1335
1851
|
}
|
|
1336
1852
|
|
|
1337
1853
|
// src/tasks/nlp/fillMask.ts
|
|
1338
1854
|
async function fillMask(args, options) {
|
|
1339
|
-
const
|
|
1855
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "fill-mask");
|
|
1856
|
+
const { data: res } = await innerRequest(args, providerHelper, {
|
|
1340
1857
|
...options,
|
|
1341
1858
|
task: "fill-mask"
|
|
1342
1859
|
});
|
|
1343
|
-
|
|
1344
|
-
(x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
|
|
1345
|
-
);
|
|
1346
|
-
if (!isValidOutput) {
|
|
1347
|
-
throw new InferenceOutputError(
|
|
1348
|
-
"Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
|
|
1349
|
-
);
|
|
1350
|
-
}
|
|
1351
|
-
return res;
|
|
1860
|
+
return providerHelper.getResponse(res);
|
|
1352
1861
|
}
|
|
1353
1862
|
|
|
1354
1863
|
// src/tasks/nlp/questionAnswering.ts
|
|
1355
1864
|
async function questionAnswering(args, options) {
|
|
1356
|
-
const
|
|
1357
|
-
|
|
1358
|
-
|
|
1359
|
-
|
|
1360
|
-
|
|
1361
|
-
|
|
1362
|
-
|
|
1363
|
-
|
|
1364
|
-
|
|
1365
|
-
|
|
1366
|
-
return Array.isArray(res) ? res[0] : res;
|
|
1865
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering");
|
|
1866
|
+
const { data: res } = await innerRequest(
|
|
1867
|
+
args,
|
|
1868
|
+
providerHelper,
|
|
1869
|
+
{
|
|
1870
|
+
...options,
|
|
1871
|
+
task: "question-answering"
|
|
1872
|
+
}
|
|
1873
|
+
);
|
|
1874
|
+
return providerHelper.getResponse(res);
|
|
1367
1875
|
}
|
|
1368
1876
|
|
|
1369
1877
|
// src/tasks/nlp/sentenceSimilarity.ts
|
|
1370
1878
|
async function sentenceSimilarity(args, options) {
|
|
1371
|
-
const
|
|
1879
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "sentence-similarity");
|
|
1880
|
+
const { data: res } = await innerRequest(args, providerHelper, {
|
|
1372
1881
|
...options,
|
|
1373
1882
|
task: "sentence-similarity"
|
|
1374
1883
|
});
|
|
1375
|
-
|
|
1376
|
-
if (!isValidOutput) {
|
|
1377
|
-
throw new InferenceOutputError("Expected number[]");
|
|
1378
|
-
}
|
|
1379
|
-
return res;
|
|
1884
|
+
return providerHelper.getResponse(res);
|
|
1380
1885
|
}
|
|
1381
1886
|
|
|
1382
1887
|
// src/tasks/nlp/summarization.ts
|
|
1383
1888
|
async function summarization(args, options) {
|
|
1384
|
-
const
|
|
1889
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization");
|
|
1890
|
+
const { data: res } = await innerRequest(args, providerHelper, {
|
|
1385
1891
|
...options,
|
|
1386
1892
|
task: "summarization"
|
|
1387
1893
|
});
|
|
1388
|
-
|
|
1389
|
-
if (!isValidOutput) {
|
|
1390
|
-
throw new InferenceOutputError("Expected Array<{summary_text: string}>");
|
|
1391
|
-
}
|
|
1392
|
-
return res?.[0];
|
|
1894
|
+
return providerHelper.getResponse(res);
|
|
1393
1895
|
}
|
|
1394
1896
|
|
|
1395
1897
|
// src/tasks/nlp/tableQuestionAnswering.ts
|
|
1396
1898
|
async function tableQuestionAnswering(args, options) {
|
|
1397
|
-
const
|
|
1398
|
-
|
|
1399
|
-
|
|
1400
|
-
|
|
1401
|
-
|
|
1402
|
-
|
|
1403
|
-
|
|
1404
|
-
|
|
1405
|
-
);
|
|
1406
|
-
}
|
|
1407
|
-
return Array.isArray(res) ? res[0] : res;
|
|
1408
|
-
}
|
|
1409
|
-
function validate(elem) {
|
|
1410
|
-
return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
|
|
1411
|
-
(coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
|
|
1899
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
|
|
1900
|
+
const { data: res } = await innerRequest(
|
|
1901
|
+
args,
|
|
1902
|
+
providerHelper,
|
|
1903
|
+
{
|
|
1904
|
+
...options,
|
|
1905
|
+
task: "table-question-answering"
|
|
1906
|
+
}
|
|
1412
1907
|
);
|
|
1908
|
+
return providerHelper.getResponse(res);
|
|
1413
1909
|
}
|
|
1414
1910
|
|
|
1415
1911
|
// src/tasks/nlp/textClassification.ts
|
|
1416
1912
|
async function textClassification(args, options) {
|
|
1417
|
-
const
|
|
1913
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
|
|
1914
|
+
const { data: res } = await innerRequest(args, providerHelper, {
|
|
1418
1915
|
...options,
|
|
1419
1916
|
task: "text-classification"
|
|
1420
1917
|
});
|
|
1421
|
-
|
|
1422
|
-
const isValidOutput = Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number");
|
|
1423
|
-
if (!isValidOutput) {
|
|
1424
|
-
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
|
1425
|
-
}
|
|
1426
|
-
return output;
|
|
1427
|
-
}
|
|
1428
|
-
|
|
1429
|
-
// src/utils/toArray.ts
|
|
1430
|
-
function toArray(obj) {
|
|
1431
|
-
if (Array.isArray(obj)) {
|
|
1432
|
-
return obj;
|
|
1433
|
-
}
|
|
1434
|
-
return [obj];
|
|
1918
|
+
return providerHelper.getResponse(res);
|
|
1435
1919
|
}
|
|
1436
1920
|
|
|
1437
1921
|
// src/tasks/nlp/textGeneration.ts
|
|
1438
1922
|
async function textGeneration(args, options) {
|
|
1439
|
-
|
|
1440
|
-
|
|
1441
|
-
|
|
1442
|
-
|
|
1443
|
-
|
|
1444
|
-
|
|
1445
|
-
const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
|
|
1446
|
-
if (!isValidOutput) {
|
|
1447
|
-
throw new InferenceOutputError("Expected ChatCompletionOutput");
|
|
1448
|
-
}
|
|
1449
|
-
const completion = raw.choices[0];
|
|
1450
|
-
return {
|
|
1451
|
-
generated_text: completion.text
|
|
1452
|
-
};
|
|
1453
|
-
} else if (args.provider === "hyperbolic") {
|
|
1454
|
-
const payload = {
|
|
1455
|
-
messages: [{ content: args.inputs, role: "user" }],
|
|
1456
|
-
...args.parameters ? {
|
|
1457
|
-
max_tokens: args.parameters.max_new_tokens,
|
|
1458
|
-
...omit(args.parameters, "max_new_tokens")
|
|
1459
|
-
} : void 0,
|
|
1460
|
-
...omit(args, ["inputs", "parameters"])
|
|
1461
|
-
};
|
|
1462
|
-
const raw = (await innerRequest(payload, {
|
|
1463
|
-
...options,
|
|
1464
|
-
task: "text-generation"
|
|
1465
|
-
})).data;
|
|
1466
|
-
const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
|
|
1467
|
-
if (!isValidOutput) {
|
|
1468
|
-
throw new InferenceOutputError("Expected ChatCompletionOutput");
|
|
1469
|
-
}
|
|
1470
|
-
const completion = raw.choices[0];
|
|
1471
|
-
return {
|
|
1472
|
-
generated_text: completion.message.content
|
|
1473
|
-
};
|
|
1474
|
-
} else {
|
|
1475
|
-
const { data: res } = await innerRequest(args, {
|
|
1476
|
-
...options,
|
|
1477
|
-
task: "text-generation"
|
|
1478
|
-
});
|
|
1479
|
-
const output = toArray(res);
|
|
1480
|
-
const isValidOutput = Array.isArray(output) && output.every((x) => "generated_text" in x && typeof x?.generated_text === "string");
|
|
1481
|
-
if (!isValidOutput) {
|
|
1482
|
-
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
|
|
1483
|
-
}
|
|
1484
|
-
return output?.[0];
|
|
1485
|
-
}
|
|
1923
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
|
|
1924
|
+
const { data: response } = await innerRequest(args, providerHelper, {
|
|
1925
|
+
...options,
|
|
1926
|
+
task: "text-generation"
|
|
1927
|
+
});
|
|
1928
|
+
return providerHelper.getResponse(response);
|
|
1486
1929
|
}
|
|
1487
1930
|
|
|
1488
1931
|
// src/tasks/nlp/textGenerationStream.ts
|
|
1489
1932
|
async function* textGenerationStream(args, options) {
|
|
1490
|
-
|
|
1933
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
|
|
1934
|
+
yield* innerStreamingRequest(args, providerHelper, {
|
|
1491
1935
|
...options,
|
|
1492
1936
|
task: "text-generation"
|
|
1493
1937
|
});
|
|
@@ -1495,77 +1939,45 @@ async function* textGenerationStream(args, options) {
|
|
|
1495
1939
|
|
|
1496
1940
|
// src/tasks/nlp/tokenClassification.ts
|
|
1497
1941
|
async function tokenClassification(args, options) {
|
|
1498
|
-
const
|
|
1499
|
-
|
|
1500
|
-
|
|
1501
|
-
|
|
1502
|
-
|
|
1503
|
-
|
|
1504
|
-
|
|
1942
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "token-classification");
|
|
1943
|
+
const { data: res } = await innerRequest(
|
|
1944
|
+
args,
|
|
1945
|
+
providerHelper,
|
|
1946
|
+
{
|
|
1947
|
+
...options,
|
|
1948
|
+
task: "token-classification"
|
|
1949
|
+
}
|
|
1505
1950
|
);
|
|
1506
|
-
|
|
1507
|
-
throw new InferenceOutputError(
|
|
1508
|
-
"Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
|
|
1509
|
-
);
|
|
1510
|
-
}
|
|
1511
|
-
return output;
|
|
1951
|
+
return providerHelper.getResponse(res);
|
|
1512
1952
|
}
|
|
1513
1953
|
|
|
1514
1954
|
// src/tasks/nlp/translation.ts
|
|
1515
1955
|
async function translation(args, options) {
|
|
1516
|
-
const
|
|
1956
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation");
|
|
1957
|
+
const { data: res } = await innerRequest(args, providerHelper, {
|
|
1517
1958
|
...options,
|
|
1518
1959
|
task: "translation"
|
|
1519
1960
|
});
|
|
1520
|
-
|
|
1521
|
-
if (!isValidOutput) {
|
|
1522
|
-
throw new InferenceOutputError("Expected type Array<{translation_text: string}>");
|
|
1523
|
-
}
|
|
1524
|
-
return res?.length === 1 ? res?.[0] : res;
|
|
1961
|
+
return providerHelper.getResponse(res);
|
|
1525
1962
|
}
|
|
1526
1963
|
|
|
1527
1964
|
// src/tasks/nlp/zeroShotClassification.ts
|
|
1528
1965
|
async function zeroShotClassification(args, options) {
|
|
1529
|
-
const
|
|
1530
|
-
|
|
1531
|
-
|
|
1532
|
-
|
|
1533
|
-
|
|
1534
|
-
|
|
1535
|
-
|
|
1966
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification");
|
|
1967
|
+
const { data: res } = await innerRequest(
|
|
1968
|
+
args,
|
|
1969
|
+
providerHelper,
|
|
1970
|
+
{
|
|
1971
|
+
...options,
|
|
1972
|
+
task: "zero-shot-classification"
|
|
1973
|
+
}
|
|
1536
1974
|
);
|
|
1537
|
-
|
|
1538
|
-
throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
|
|
1539
|
-
}
|
|
1540
|
-
return output;
|
|
1541
|
-
}
|
|
1542
|
-
|
|
1543
|
-
// src/tasks/nlp/chatCompletion.ts
|
|
1544
|
-
async function chatCompletion(args, options) {
|
|
1545
|
-
const { data: res } = await innerRequest(args, {
|
|
1546
|
-
...options,
|
|
1547
|
-
task: "text-generation",
|
|
1548
|
-
chatCompletion: true
|
|
1549
|
-
});
|
|
1550
|
-
const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
|
|
1551
|
-
(res.system_fingerprint === void 0 || res.system_fingerprint === null || typeof res.system_fingerprint === "string") && typeof res?.usage === "object";
|
|
1552
|
-
if (!isValidOutput) {
|
|
1553
|
-
throw new InferenceOutputError("Expected ChatCompletionOutput");
|
|
1554
|
-
}
|
|
1555
|
-
return res;
|
|
1556
|
-
}
|
|
1557
|
-
|
|
1558
|
-
// src/tasks/nlp/chatCompletionStream.ts
|
|
1559
|
-
async function* chatCompletionStream(args, options) {
|
|
1560
|
-
yield* innerStreamingRequest(args, {
|
|
1561
|
-
...options,
|
|
1562
|
-
task: "text-generation",
|
|
1563
|
-
chatCompletion: true
|
|
1564
|
-
});
|
|
1975
|
+
return providerHelper.getResponse(res);
|
|
1565
1976
|
}
|
|
1566
1977
|
|
|
1567
1978
|
// src/tasks/multimodal/documentQuestionAnswering.ts
|
|
1568
1979
|
async function documentQuestionAnswering(args, options) {
|
|
1980
|
+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "document-question-answering");
|
|
1569
1981
|
const reqArgs = {
|
|
1570
1982
|
...args,
|
|
1571
1983
|
inputs: {
|
|
@@ -1576,23 +1988,18 @@ async function documentQuestionAnswering(args, options) {
   };
   const { data: res } = await innerRequest(
     reqArgs,
+    providerHelper,
     {
       ...options,
       task: "document-question-answering"
     }
   );
-
-  const isValidOutput = Array.isArray(output) && output.every(
-    (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
-  );
-  if (!isValidOutput) {
-    throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>");
-  }
-  return output[0];
+  return providerHelper.getResponse(res);
 }

 // src/tasks/multimodal/visualQuestionAnswering.ts
 async function visualQuestionAnswering(args, options) {
+  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "visual-question-answering");
   const reqArgs = {
     ...args,
     inputs: {
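The removed `isValidOutput` blocks are not deleted outright; they move into the per-provider helpers. Reconstructed from the lines removed above, the hf-inference document-question-answering helper plausibly ends up looking like the following sketch (the class name and local error class are illustrative, not the package's actual exports):

class InferenceOutputError extends Error {} // stand-in for the library's error type

// Illustrative reconstruction of where the removed validation now lives.
class HfInferenceDocumentQuestionAnsweringHelper {
  getResponse(res: unknown) {
    const isValidOutput =
      Array.isArray(res) &&
      res.every(
        (elem) =>
          typeof elem === "object" &&
          !!elem &&
          typeof elem.answer === "string" &&
          (typeof elem.end === "number" || typeof elem.end === "undefined") &&
          (typeof elem.score === "number" || typeof elem.score === "undefined") &&
          (typeof elem.start === "number" || typeof elem.start === "undefined")
      );
    if (!isValidOutput) {
      throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>");
    }
    // Same single-element unwrap the task function used to perform inline.
    return (res as unknown[])[0];
  }
}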
@@ -1601,43 +2008,31 @@ async function visualQuestionAnswering(args, options) {
       image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer()))
     }
   };
-  const { data: res } = await innerRequest(reqArgs, {
+  const { data: res } = await innerRequest(reqArgs, providerHelper, {
     ...options,
     task: "visual-question-answering"
   });
-
-    (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
-  );
-  if (!isValidOutput) {
-    throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
-  }
-  return res[0];
+  return providerHelper.getResponse(res);
 }

-// src/tasks/tabular/
-async function
-const
+// src/tasks/tabular/tabularClassification.ts
+async function tabularClassification(args, options) {
+  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification");
+  const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
-    task: "tabular-
+    task: "tabular-classification"
   });
-
-  if (!isValidOutput) {
-    throw new InferenceOutputError("Expected number[]");
-  }
-  return res;
+  return providerHelper.getResponse(res);
 }

-// src/tasks/tabular/
-async function
-const
+// src/tasks/tabular/tabularRegression.ts
+async function tabularRegression(args, options) {
+  const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression");
+  const { data: res } = await innerRequest(args, providerHelper, {
     ...options,
-    task: "tabular-
+    task: "tabular-regression"
   });
-
-  if (!isValidOutput) {
-    throw new InferenceOutputError("Expected number[]");
-  }
-  return res;
+  return providerHelper.getResponse(res);
 }

 // src/InferenceClient.ts
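Each call site above defaults the provider to "hf-inference" and asks `getProviderHelper(provider, task)` for a helper; the snippet generator later in this file catches the error it throws for unsupported pairs. A plausible lookup sketch, under the assumption of a static provider-to-task registry (the registry name and helper class here are hypothetical; the real module assembles helpers from the per-provider files such as cerebras.ts and fal-ai.ts):

interface TaskProviderHelper {
  getResponse(response: unknown): unknown;
}

// Hypothetical helper; hf-inference tabular tasks return number[] per the removed checks.
class HfInferenceTabularHelper implements TaskProviderHelper {
  getResponse(response: unknown): number[] {
    if (!Array.isArray(response) || !response.every((x) => typeof x === "number")) {
      throw new Error("Expected number[]"); // mirrors the removed InferenceOutputError check
    }
    return response;
  }
}

const PROVIDER_HELPERS: Record<string, Record<string, TaskProviderHelper>> = {
  "hf-inference": {
    "tabular-classification": new HfInferenceTabularHelper(),
    "tabular-regression": new HfInferenceTabularHelper(),
  },
};

function getProviderHelper(provider: string, task: string): TaskProviderHelper {
  const helper = PROVIDER_HELPERS[provider]?.[task];
  if (!helper) {
    // snippetGenerator catches this and returns [] instead of crashing.
    throw new Error(`Provider ${provider} does not support task ${task}`);
  }
  return helper;
}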
@@ -1706,29 +2101,29 @@ __export(snippets_exports, {
 });

 // src/snippets/getInferenceSnippets.ts
+import { Template } from "@huggingface/jinja";
 import {
-
-
+  getModelInputSnippet,
+  inferenceSnippetLanguages
 } from "@huggingface/tasks";
-import { Template } from "@huggingface/jinja";

 // src/snippets/templates.exported.ts
 var templates = {
   "js": {
     "fetch": {
-      "basic": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
-      "basicAudio": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "audio/flac"\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
-      "basicImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "image/jpeg"\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
-      "textToAudio": '{% if model.library_name == "transformers" %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n // Returns a byte object of the Audio wavform. Use it directly!\n});\n{% else %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});\n{% endif %} ',
-      "textToImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\n\nquery({ {{ providerInputs.asTsString }} }).then((response) => {\n // Use image\n});',
-      "zeroShotClassification": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n
+      "basic": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
+      "basicAudio": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "audio/flac",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
+      "basicImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "image/jpeg",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});',
+      "textToAudio": '{% if model.library_name == "transformers" %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n // Returns a byte object of the Audio wavform. Use it directly!\n});\n{% else %}\nasync function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({ inputs: {{ providerInputs.asObj.inputs }} }).then((response) => {\n console.log(JSON.stringify(response));\n});\n{% endif %} ',
+      "textToImage": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.blob();\n return result;\n}\n\n\nquery({ {{ providerInputs.asTsString }} }).then((response) => {\n // Use image\n});',
+      "zeroShotClassification": 'async function query(data) {\n const response = await fetch(\n "{{ fullUrl }}",\n {\n headers: {\n Authorization: "{{ authorizationHeader }}",\n "Content-Type": "application/json",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}",\n{% endif %} },\n method: "POST",\n body: JSON.stringify(data),\n }\n );\n const result = await response.json();\n return result;\n}\n\nquery({\n inputs: {{ providerInputs.asObj.inputs }},\n parameters: { candidate_labels: ["refund", "legal", "faq"] }\n}).then((response) => {\n console.log(JSON.stringify(response));\n});'
     },
     "huggingface.js": {
-      "basic": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst output = await client.{{ methodName }}({\n model: "{{ model.id }}",\n inputs: {{ inputs.asObj.inputs }},\n provider: "{{ provider }}",\n});\n\nconsole.log(output);',
-      "basicAudio": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst data = fs.readFileSync({{inputs.asObj.inputs}});\n\nconst output = await client.{{ methodName }}({\n data,\n model: "{{ model.id }}",\n provider: "{{ provider }}",\n});\n\nconsole.log(output);',
-      "basicImage": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst data = fs.readFileSync({{inputs.asObj.inputs}});\n\nconst output = await client.{{ methodName }}({\n data,\n model: "{{ model.id }}",\n provider: "{{ provider }}",\n});\n\nconsole.log(output);',
-      "conversational": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst chatCompletion = await client.chatCompletion({\n provider: "{{ provider }}",\n model: "{{ model.id }}",\n{{ inputs.asTsString }}\n});\n\nconsole.log(chatCompletion.choices[0].message);',
-      "conversationalStream": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nlet out = "";\n\nconst stream = await client.chatCompletionStream({\n provider: "{{ provider }}",\n model: "{{ model.id }}",\n{{ inputs.asTsString }}\n});\n\nfor await (const chunk of stream) {\n if (chunk.choices && chunk.choices.length > 0) {\n const newContent = chunk.choices[0].delta.content;\n out += newContent;\n console.log(newContent);\n } \n}',
+      "basic": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst output = await client.{{ methodName }}({\n model: "{{ model.id }}",\n inputs: {{ inputs.asObj.inputs }},\n provider: "{{ provider }}",\n}{% if billTo %}, {\n billTo: "{{ billTo }}",\n}{% endif %});\n\nconsole.log(output);',
+      "basicAudio": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst data = fs.readFileSync({{inputs.asObj.inputs}});\n\nconst output = await client.{{ methodName }}({\n data,\n model: "{{ model.id }}",\n provider: "{{ provider }}",\n}{% if billTo %}, {\n billTo: "{{ billTo }}",\n}{% endif %});\n\nconsole.log(output);',
+      "basicImage": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst data = fs.readFileSync({{inputs.asObj.inputs}});\n\nconst output = await client.{{ methodName }}({\n data,\n model: "{{ model.id }}",\n provider: "{{ provider }}",\n}{% if billTo %}, {\n billTo: "{{ billTo }}",\n}{% endif %});\n\nconsole.log(output);',
+      "conversational": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nconst chatCompletion = await client.chatCompletion({\n provider: "{{ provider }}",\n model: "{{ model.id }}",\n{{ inputs.asTsString }}\n}{% if billTo %}, {\n billTo: "{{ billTo }}",\n}{% endif %});\n\nconsole.log(chatCompletion.choices[0].message);',
+      "conversationalStream": 'import { InferenceClient } from "@huggingface/inference";\n\nconst client = new InferenceClient("{{ accessToken }}");\n\nlet out = "";\n\nconst stream = await client.chatCompletionStream({\n provider: "{{ provider }}",\n model: "{{ model.id }}",\n{{ inputs.asTsString }}\n}{% if billTo %}, {\n billTo: "{{ billTo }}",\n}{% endif %});\n\nfor await (const chunk of stream) {\n if (chunk.choices && chunk.choices.length > 0) {\n const newContent = chunk.choices[0].delta.content;\n out += newContent;\n console.log(newContent);\n } \n}',
       "textToImage": `import { InferenceClient } from "@huggingface/inference";

 const client = new InferenceClient("{{ accessToken }}");
@@ -1738,7 +2133,9 @@ const image = await client.textToImage({
 model: "{{ model.id }}",
 inputs: {{ inputs.asObj.inputs }},
 parameters: { num_inference_steps: 5 },
-}
+}{% if billTo %}, {
+billTo: "{{ billTo }}",
+}{% endif %});
 /// Use the generated image (it's a Blob)`,
 "textToVideo": `import { InferenceClient } from "@huggingface/inference";

@@ -1748,12 +2145,14 @@ const image = await client.textToVideo({
 provider: "{{ provider }}",
 model: "{{ model.id }}",
 inputs: {{ inputs.asObj.inputs }},
-}
+}{% if billTo %}, {
+billTo: "{{ billTo }}",
+}{% endif %});
 // Use the generated video (it's a Blob)`
     },
     "openai": {
-      "conversational": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n});\n\nconst chatCompletion = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n});\n\nconsole.log(chatCompletion.choices[0].message);',
-      "conversationalStream": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n}
+      "conversational": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n{% if billTo %}\n defaultHeaders: {\n "X-HF-Bill-To": "{{ billTo }}" \n }\n{% endif %}\n});\n\nconst chatCompletion = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n});\n\nconsole.log(chatCompletion.choices[0].message);',
+      "conversationalStream": 'import { OpenAI } from "openai";\n\nconst client = new OpenAI({\n baseURL: "{{ baseUrl }}",\n apiKey: "{{ accessToken }}",\n{% if billTo %}\n defaultHeaders: {\n "X-HF-Bill-To": "{{ billTo }}" \n }\n{% endif %}\n});\n\nconst stream = await client.chat.completions.create({\n model: "{{ providerModelId }}",\n{{ inputs.asTsString }}\n stream: true,\n});\n\nfor await (const chunk of stream) {\n process.stdout.write(chunk.choices[0]?.delta?.content || "");\n}'
     }
   },
   "python": {
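Filled in, the new `{% if billTo %}` branches give the JS snippets an extra options argument (huggingface.js) or a default header (openai). For example, with placeholder provider, model, and org values, the huggingface.js `conversational` template renders to roughly:

import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient("hf_xxx");

// The trailing options argument only appears in the snippet when billTo is set;
// provider, model, and messages below are placeholder values.
const chatCompletion = await client.chatCompletion({
  provider: "sambanova",
  model: "meta-llama/Llama-3.3-70B-Instruct",
  messages: [{ role: "user", content: "What is the capital of France?" }],
}, {
  billTo: "my-org",
});

console.log(chatCompletion.choices[0].message);

The openai-client variant achieves the same effect by sending the X-HF-Bill-To header through `defaultHeaders`, as the template above shows.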
@@ -1768,13 +2167,13 @@ const image = await client.textToVideo({
       "conversationalStream": 'stream = client.chat.completions.create(\n model="{{ model.id }}",\n{{ inputs.asPythonString }}\n stream=True,\n)\n\nfor chunk in stream:\n print(chunk.choices[0].delta.content, end="") ',
       "documentQuestionAnswering": 'output = client.document_question_answering(\n "{{ inputs.asObj.image }}",\n question="{{ inputs.asObj.question }}",\n model="{{ model.id }}",\n) ',
       "imageToImage": '# output is a PIL.Image object\nimage = client.image_to_image(\n "{{ inputs.asObj.inputs }}",\n prompt="{{ inputs.asObj.parameters.prompt }}",\n model="{{ model.id }}",\n) ',
-      "importInferenceClient": 'from huggingface_hub import InferenceClient\n\nclient = InferenceClient(\n provider="{{ provider }}",\n api_key="{{ accessToken }}",\n)',
+      "importInferenceClient": 'from huggingface_hub import InferenceClient\n\nclient = InferenceClient(\n provider="{{ provider }}",\n api_key="{{ accessToken }}",\n{% if billTo %}\n bill_to="{{ billTo }}",\n{% endif %}\n)',
       "textToImage": '# output is a PIL.Image object\nimage = client.text_to_image(\n {{ inputs.asObj.inputs }},\n model="{{ model.id }}",\n) ',
       "textToVideo": 'video = client.text_to_video(\n {{ inputs.asObj.inputs }},\n model="{{ model.id }}",\n) '
     },
     "openai": {
-      "conversational": 'from openai import OpenAI\n\nclient = OpenAI(\n base_url="{{ baseUrl }}",\n api_key="{{ accessToken }}"\n)\n\ncompletion = client.chat.completions.create(\n model="{{ providerModelId }}",\n{{ inputs.asPythonString }}\n)\n\nprint(completion.choices[0].message) ',
-      "conversationalStream": 'from openai import OpenAI\n\nclient = OpenAI(\n base_url="{{ baseUrl }}",\n api_key="{{ accessToken }}"\n)\n\nstream = client.chat.completions.create(\n model="{{ providerModelId }}",\n{{ inputs.asPythonString }}\n stream=True,\n)\n\nfor chunk in stream:\n print(chunk.choices[0].delta.content, end="")'
+      "conversational": 'from openai import OpenAI\n\nclient = OpenAI(\n base_url="{{ baseUrl }}",\n api_key="{{ accessToken }}",\n{% if billTo %}\n default_headers={\n "X-HF-Bill-To": "{{ billTo }}"\n }\n{% endif %}\n)\n\ncompletion = client.chat.completions.create(\n model="{{ providerModelId }}",\n{{ inputs.asPythonString }}\n)\n\nprint(completion.choices[0].message) ',
+      "conversationalStream": 'from openai import OpenAI\n\nclient = OpenAI(\n base_url="{{ baseUrl }}",\n api_key="{{ accessToken }}",\n{% if billTo %}\n default_headers={\n "X-HF-Bill-To": "{{ billTo }}"\n }\n{% endif %}\n)\n\nstream = client.chat.completions.create(\n model="{{ providerModelId }}",\n{{ inputs.asPythonString }}\n stream=True,\n)\n\nfor chunk in stream:\n print(chunk.choices[0].delta.content, end="")'
     },
     "requests": {
       "basic": 'def query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\noutput = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n}) ',
@@ -1784,7 +2183,7 @@ const image = await client.textToVideo({
       "conversationalStream": 'def query(payload):\n response = requests.post(API_URL, headers=headers, json=payload, stream=True)\n for line in response.iter_lines():\n if not line.startswith(b"data:"):\n continue\n if line.strip() == b"data: [DONE]":\n return\n yield json.loads(line.decode("utf-8").lstrip("data:").rstrip("/n"))\n\nchunks = query({\n{{ providerInputs.asJsonString }},\n "stream": True,\n})\n\nfor chunk in chunks:\n print(chunk["choices"][0]["delta"]["content"], end="")',
       "documentQuestionAnswering": 'def query(payload):\n with open(payload["image"], "rb") as f:\n img = f.read()\n payload["image"] = base64.b64encode(img).decode("utf-8")\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\noutput = query({\n "inputs": {\n "image": "{{ inputs.asObj.image }}",\n "question": "{{ inputs.asObj.question }}",\n },\n}) ',
       "imageToImage": 'def query(payload):\n with open(payload["inputs"], "rb") as f:\n img = f.read()\n payload["inputs"] = base64.b64encode(img).decode("utf-8")\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nimage_bytes = query({\n{{ providerInputs.asJsonString }}\n})\n\n# You can access the image with PIL.Image for example\nimport io\nfrom PIL import Image\nimage = Image.open(io.BytesIO(image_bytes)) ',
-      "importRequests": '{% if importBase64 %}\nimport base64\n{% endif %}\n{% if importJson %}\nimport json\n{% endif %}\nimport requests\n\nAPI_URL = "{{ fullUrl }}"\nheaders = {"Authorization": "{{ authorizationHeader }}"}',
+      "importRequests": '{% if importBase64 %}\nimport base64\n{% endif %}\n{% if importJson %}\nimport json\n{% endif %}\nimport requests\n\nAPI_URL = "{{ fullUrl }}"\nheaders = {\n "Authorization": "{{ authorizationHeader }}",\n{% if billTo %}\n "X-HF-Bill-To": "{{ billTo }}"\n{% endif %}\n}',
       "tabular": 'def query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nresponse = query({\n "inputs": {\n "data": {{ providerInputs.asObj.inputs }}\n },\n}) ',
       "textToAudio": '{% if model.library_name == "transformers" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\naudio_bytes = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio_bytes)\n{% else %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.json()\n\naudio, sampling_rate = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n})\n# You can access the audio with IPython.display for example\nfrom IPython.display import Audio\nAudio(audio, rate=sampling_rate)\n{% endif %} ',
       "textToImage": '{% if provider == "hf-inference" %}\ndef query(payload):\n response = requests.post(API_URL, headers=headers, json=payload)\n return response.content\n\nimage_bytes = query({\n "inputs": {{ providerInputs.asObj.inputs }},\n})\n\n# You can access the image with PIL.Image for example\nimport io\nfrom PIL import Image\nimage = Image.open(io.BytesIO(image_bytes))\n{% endif %}',
@@ -1794,12 +2193,15 @@ const image = await client.textToVideo({
   },
   "sh": {
     "curl": {
-      "basic": "curl {{ fullUrl }} \\\n -X POST \\\n -H 'Authorization: {{ authorizationHeader }}' \\\n -H 'Content-Type: application/json' \\\n -d '{\n{{ providerInputs.asCurlString }}\n }'",
-      "basicAudio": "curl {{ fullUrl }} \\\n -X POST \\\n -H 'Authorization: {{ authorizationHeader }}' \\\n -H 'Content-Type: audio/flac' \\\n --data-binary @{{ providerInputs.asObj.inputs }}",
-      "basicImage": "curl {{ fullUrl }} \\\n -X POST \\\n -H 'Authorization: {{ authorizationHeader }}' \\\n -H 'Content-Type: image/jpeg' \\\n --data-binary @{{ providerInputs.asObj.inputs }}",
+      "basic": "curl {{ fullUrl }} \\\n -X POST \\\n -H 'Authorization: {{ authorizationHeader }}' \\\n -H 'Content-Type: application/json' \\\n{% if billTo %}\n -H 'X-HF-Bill-To: {{ billTo }}' \\\n{% endif %}\n -d '{\n{{ providerInputs.asCurlString }}\n }'",
+      "basicAudio": "curl {{ fullUrl }} \\\n -X POST \\\n -H 'Authorization: {{ authorizationHeader }}' \\\n -H 'Content-Type: audio/flac' \\\n{% if billTo %}\n -H 'X-HF-Bill-To: {{ billTo }}' \\\n{% endif %}\n --data-binary @{{ providerInputs.asObj.inputs }}",
+      "basicImage": "curl {{ fullUrl }} \\\n -X POST \\\n -H 'Authorization: {{ authorizationHeader }}' \\\n -H 'Content-Type: image/jpeg' \\\n{% if billTo %}\n -H 'X-HF-Bill-To: {{ billTo }}' \\\n{% endif %}\n --data-binary @{{ providerInputs.asObj.inputs }}",
       "conversational": `curl {{ fullUrl }} \\
 -H 'Authorization: {{ authorizationHeader }}' \\
 -H 'Content-Type: application/json' \\
+{% if billTo %}
+-H 'X-HF-Bill-To: {{ billTo }}' \\
+{% endif %}
 -d '{
 {{ providerInputs.asCurlString }},
 "stream": false
@@ -1807,6 +2209,9 @@ const image = await client.textToVideo({
 "conversationalStream": `curl {{ fullUrl }} \\
 -H 'Authorization: {{ authorizationHeader }}' \\
 -H 'Content-Type: application/json' \\
+{% if billTo %}
+-H 'X-HF-Bill-To: {{ billTo }}' \\
+{% endif %}
 -d '{
 {{ providerInputs.asCurlString }},
 "stream": true
@@ -1815,7 +2220,10 @@ const image = await client.textToVideo({
 -X POST \\
 -d '{"inputs": {{ providerInputs.asObj.inputs }}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\
 -H 'Content-Type: application/json' \\
--H 'Authorization: {{ authorizationHeader }}'
+-H 'Authorization: {{ authorizationHeader }}'
+{% if billTo %} \\
+-H 'X-HF-Bill-To: {{ billTo }}'
+{% endif %}`
     }
   }
 };
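These templates are plain Jinja, and the `Template` import added from @huggingface/jinja at the top of this group of hunks is what evaluates the `{% if billTo %}` branches. A minimal sketch of that render step over the `templates` object defined above, with placeholder context values (the actual call site is inside `snippetGenerator` below):

import { Template } from "@huggingface/jinja";

// Placeholder values for the variables the "basic" curl template references.
const rendered = new Template(templates["sh"]["curl"]["basic"]).render({
  fullUrl: "https://router.huggingface.co/hf-inference/models/my-model",
  authorizationHeader: "Bearer hf_xxx",
  providerInputs: { asCurlString: '"inputs": "Hello"' },
  billTo: "my-org", // switches on the X-HF-Bill-To header branch
});
console.log(rendered);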
@@ -1884,16 +2292,35 @@ var HF_JS_METHODS = {
   translation: "translation"
 };
 var snippetGenerator = (templateName, inputPreparationFn) => {
-  return (model, accessToken, provider,
+  return (model, accessToken, provider, inferenceProviderMapping, billTo, opts) => {
+    const providerModelId = inferenceProviderMapping?.providerId ?? model.id;
+    let task = model.pipeline_tag;
     if (model.pipeline_tag && ["text-generation", "image-text-to-text"].includes(model.pipeline_tag) && model.tags.includes("conversational")) {
       templateName = opts?.streaming ? "conversationalStream" : "conversational";
       inputPreparationFn = prepareConversationalInput;
+      task = "conversational";
+    }
+    let providerHelper;
+    try {
+      providerHelper = getProviderHelper(provider, task);
+    } catch (e) {
+      console.error(`Failed to get provider helper for ${provider} (${task})`, e);
+      return [];
     }
     const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) };
     const request2 = makeRequestOptionsFromResolvedModel(
-      providerModelId
-
-      {
+      providerModelId,
+      providerHelper,
+      {
+        accessToken,
+        provider,
+        ...inputs
+      },
+      inferenceProviderMapping,
+      {
+        task,
+        billTo
+      }
     );
     let providerInputs = inputs;
     const bodyAsObj = request2.info.body;
@@ -1925,7 +2352,8 @@ var snippetGenerator = (templateName, inputPreparationFn) => {
       },
       model,
       provider,
-      providerModelId: providerModelId ?? model.id
+      providerModelId: providerModelId ?? model.id,
+      billTo
     };
     return inferenceSnippetLanguages.map((language) => {
       return CLIENTS[language].map((client) => {
@@ -1980,7 +2408,7 @@ var prepareConversationalInput = (model, opts) => {
   return {
     messages: opts?.messages ?? getModelInputSnippet(model),
     ...opts?.temperature ? { temperature: opts?.temperature } : void 0,
-    max_tokens: opts?.max_tokens ??
+    max_tokens: opts?.max_tokens ?? 512,
     ...opts?.top_p ? { top_p: opts?.top_p } : void 0
   };
 };
@@ -2015,8 +2443,8 @@ var snippets = {
   "zero-shot-classification": snippetGenerator("zeroShotClassification"),
   "zero-shot-image-classification": snippetGenerator("zeroShotImageClassification")
 };
-function getInferenceSnippets(model, accessToken, provider,
-  return model.pipeline_tag && model.pipeline_tag in snippets ? snippets[model.pipeline_tag]?.(model, accessToken, provider,
+function getInferenceSnippets(model, accessToken, provider, inferenceProviderMapping, billTo, opts) {
+  return model.pipeline_tag && model.pipeline_tag in snippets ? snippets[model.pipeline_tag]?.(model, accessToken, provider, inferenceProviderMapping, billTo, opts) ?? [] : [];
 }
 function formatBody(obj, format) {
   switch (format) {