@huggingface/inference 1.8.0 → 2.0.0-rc2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. package/README.md +57 -8
  2. package/dist/index.js +440 -354
  3. package/dist/index.mjs +423 -353
  4. package/package.json +7 -9
  5. package/src/HfInference.ts +43 -1112
  6. package/src/index.ts +3 -1
  7. package/src/lib/InferenceOutputError.ts +8 -0
  8. package/src/lib/makeRequestOptions.ts +55 -0
  9. package/src/tasks/audio/audioClassification.ts +41 -0
  10. package/src/tasks/audio/automaticSpeechRecognition.ts +33 -0
  11. package/src/tasks/custom/request.ts +39 -0
  12. package/src/tasks/custom/streamingRequest.ts +76 -0
  13. package/src/tasks/cv/imageClassification.ts +40 -0
  14. package/src/tasks/cv/imageSegmentation.ts +45 -0
  15. package/src/tasks/cv/imageToText.ts +30 -0
  16. package/src/tasks/cv/objectDetection.ts +58 -0
  17. package/src/tasks/cv/textToImage.ts +48 -0
  18. package/src/tasks/index.ts +29 -0
  19. package/src/tasks/nlp/conversational.ts +81 -0
  20. package/src/tasks/nlp/featureExtraction.ts +51 -0
  21. package/src/tasks/nlp/fillMask.ts +48 -0
  22. package/src/tasks/nlp/questionAnswering.ts +48 -0
  23. package/src/tasks/nlp/sentenceSimilarity.ts +36 -0
  24. package/src/tasks/nlp/summarization.ts +59 -0
  25. package/src/tasks/nlp/tableQuestionAnswering.ts +58 -0
  26. package/src/tasks/nlp/textClassification.ts +37 -0
  27. package/src/tasks/nlp/textGeneration.ts +67 -0
  28. package/src/tasks/nlp/textGenerationStream.ts +92 -0
  29. package/src/tasks/nlp/tokenClassification.ts +78 -0
  30. package/src/tasks/nlp/translation.ts +29 -0
  31. package/src/tasks/nlp/zeroShotClassification.ts +55 -0
  32. package/src/types.ts +42 -0
  33. package/src/utils/distributive-omit.d.ts +15 -0
  34. package/dist/index.d.ts +0 -677
package/dist/index.mjs CHANGED
@@ -1,9 +1,94 @@
1
- // src/utils/toArray.ts
2
- function toArray(obj) {
3
- if (Array.isArray(obj)) {
4
- return obj;
1
+ var __defProp = Object.defineProperty;
2
+ var __export = (target, all) => {
3
+ for (var name in all)
4
+ __defProp(target, name, { get: all[name], enumerable: true });
5
+ };
6
+
7
+ // src/tasks/index.ts
8
+ var tasks_exports = {};
9
+ __export(tasks_exports, {
10
+ audioClassification: () => audioClassification,
11
+ automaticSpeechRecognition: () => automaticSpeechRecognition,
12
+ conversational: () => conversational,
13
+ featureExtraction: () => featureExtraction,
14
+ fillMask: () => fillMask,
15
+ imageClassification: () => imageClassification,
16
+ imageSegmentation: () => imageSegmentation,
17
+ imageToText: () => imageToText,
18
+ objectDetection: () => objectDetection,
19
+ questionAnswering: () => questionAnswering,
20
+ request: () => request,
21
+ sentenceSimilarity: () => sentenceSimilarity,
22
+ streamingRequest: () => streamingRequest,
23
+ summarization: () => summarization,
24
+ tableQuestionAnswering: () => tableQuestionAnswering,
25
+ textClassification: () => textClassification,
26
+ textGeneration: () => textGeneration,
27
+ textGenerationStream: () => textGenerationStream,
28
+ textToImage: () => textToImage,
29
+ tokenClassification: () => tokenClassification,
30
+ translation: () => translation,
31
+ zeroShotClassification: () => zeroShotClassification
32
+ });
33
+
34
+ // src/lib/makeRequestOptions.ts
35
+ var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co/models/";
36
+ function makeRequestOptions(args, options) {
37
+ const { model, accessToken, ...otherArgs } = args;
38
+ const headers = {};
39
+ if (accessToken) {
40
+ headers["Authorization"] = `Bearer ${accessToken}`;
5
41
  }
6
- return [obj];
42
+ const binary = "data" in args && !!args.data;
43
+ if (!binary) {
44
+ headers["Content-Type"] = "application/json";
45
+ } else {
46
+ if (options?.wait_for_model) {
47
+ headers["X-Wait-For-Model"] = "true";
48
+ }
49
+ if (options?.use_cache === false) {
50
+ headers["X-Use-Cache"] = "false";
51
+ }
52
+ if (options?.dont_load_model) {
53
+ headers["X-Load-Model"] = "0";
54
+ }
55
+ }
56
+ const url = /^http(s?):/.test(model) ? model : `${HF_INFERENCE_API_BASE_URL}${model}`;
57
+ const info = {
58
+ headers,
59
+ method: "POST",
60
+ body: binary ? args.data : JSON.stringify({
61
+ ...otherArgs,
62
+ options
63
+ }),
64
+ credentials: options?.includeCredentials ? "include" : "same-origin"
65
+ };
66
+ return { url, info };
67
+ }
68
+
69
+ // src/tasks/custom/request.ts
70
+ async function request(args, options) {
71
+ const { url, info } = makeRequestOptions(args, options);
72
+ const response = await fetch(url, info);
73
+ if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
74
+ return request(args, {
75
+ ...options,
76
+ wait_for_model: true
77
+ });
78
+ }
79
+ if (!response.ok) {
80
+ if (response.headers.get("Content-Type")?.startsWith("application/json")) {
81
+ const output = await response.json();
82
+ if (output.error) {
83
+ throw new Error(output.error);
84
+ }
85
+ }
86
+ throw new Error("An error occurred while fetching the blob");
87
+ }
88
+ if (response.headers.get("Content-Type")?.startsWith("application/json")) {
89
+ return await response.json();
90
+ }
91
+ return await response.blob();
7
92
  }
8
93
 
9
94
  // src/vendor/fetch-event-source/parse.ts
@@ -105,389 +190,374 @@ function newMessage() {
105
190
  };
106
191
  }
107
192
 
108
- // src/HfInference.ts
109
- var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co/models/";
110
- var TextGenerationStreamFinishReason = /* @__PURE__ */ ((TextGenerationStreamFinishReason2) => {
111
- TextGenerationStreamFinishReason2["Length"] = "length";
112
- TextGenerationStreamFinishReason2["EndOfSequenceToken"] = "eos_token";
113
- TextGenerationStreamFinishReason2["StopSequence"] = "stop_sequence";
114
- return TextGenerationStreamFinishReason2;
115
- })(TextGenerationStreamFinishReason || {});
116
- var HfInference = class {
117
- apiKey;
118
- defaultOptions;
119
- constructor(apiKey = "", defaultOptions = {}) {
120
- this.apiKey = apiKey;
121
- this.defaultOptions = defaultOptions;
193
+ // src/tasks/custom/streamingRequest.ts
194
+ async function* streamingRequest(args, options) {
195
+ const { url, info } = makeRequestOptions({ ...args, stream: true }, options);
196
+ const response = await fetch(url, info);
197
+ if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
198
+ return streamingRequest(args, {
199
+ ...options,
200
+ wait_for_model: true
201
+ });
122
202
  }
123
- /**
124
- * Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
125
- */
126
- async fillMask(args, options) {
127
- const res = await this.request(args, options);
128
- const isValidOutput = Array.isArray(res) && res.every(
129
- (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
130
- );
131
- if (!isValidOutput) {
132
- throw new TypeError(
133
- "Invalid inference output: output must be of type Array<score: number, sequence:string, token:number, token_str:string>"
134
- );
203
+ if (!response.ok) {
204
+ if (response.headers.get("Content-Type")?.startsWith("application/json")) {
205
+ const output = await response.json();
206
+ if (output.error) {
207
+ throw new Error(output.error);
208
+ }
135
209
  }
136
- return res;
210
+ throw new Error(`Server response contains error: ${response.status}`);
137
211
  }
138
- /**
139
- * This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.
140
- */
141
- async summarization(args, options) {
142
- const res = await this.request(args, options);
143
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.summary_text === "string");
144
- if (!isValidOutput) {
145
- throw new TypeError("Invalid inference output: output must be of type Array<summary_text: string>");
146
- }
147
- return res?.[0];
212
+ if (response.headers.get("content-type") !== "text/event-stream") {
213
+ throw new Error(
214
+ `Server does not support event stream content type, it returned ` + response.headers.get("content-type")
215
+ );
148
216
  }
149
- /**
150
- * Want to have a nice know-it-all bot that can answer any question?. Recommended model: deepset/roberta-base-squad2
151
- */
152
- async questionAnswer(args, options) {
153
- const res = await this.request(args, options);
154
- const isValidOutput = typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";
155
- if (!isValidOutput) {
156
- throw new TypeError(
157
- "Invalid inference output: output must be of type <answer: string, end: number, score: number, start: number>"
158
- );
159
- }
160
- return res;
217
+ if (!response.body) {
218
+ return;
161
219
  }
162
- /**
163
- * Don’t know SQL? Don’t want to dive into a large spreadsheet? Ask questions in plain english! Recommended model: google/tapas-base-finetuned-wtq.
164
- */
165
- async tableQuestionAnswer(args, options) {
166
- const res = await this.request(args, options);
167
- const isValidOutput = typeof res.aggregator === "string" && typeof res.answer === "string" && Array.isArray(res.cells) && res.cells.every((x) => typeof x === "string") && Array.isArray(res.coordinates) && res.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number"));
168
- if (!isValidOutput) {
169
- throw new TypeError(
170
- "Invalid inference output: output must be of type <aggregator: string, answer: string, cells: string[], coordinates: number[][]>"
171
- );
220
+ const reader = response.body.getReader();
221
+ let events = [];
222
+ const onEvent = (event) => {
223
+ events.push(event);
224
+ };
225
+ const onChunk = getLines(
226
+ getMessages(
227
+ () => {
228
+ },
229
+ () => {
230
+ },
231
+ onEvent
232
+ )
233
+ );
234
+ try {
235
+ while (true) {
236
+ const { done, value } = await reader.read();
237
+ if (done)
238
+ return;
239
+ onChunk(value);
240
+ for (const event of events) {
241
+ if (event.data.length > 0) {
242
+ yield JSON.parse(event.data);
243
+ }
244
+ }
245
+ events = [];
172
246
  }
173
- return res;
247
+ } finally {
248
+ reader.releaseLock();
174
249
  }
175
- /**
176
- * Usually used for sentiment-analysis this will output the likelihood of classes of an input. Recommended model: distilbert-base-uncased-finetuned-sst-2-english
177
- */
178
- async textClassification(args, options) {
179
- const res = (await this.request(args, options))?.[0];
180
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
181
- if (!isValidOutput) {
182
- throw new TypeError("Invalid inference output: output must be of type Array<label: string, score: number>");
183
- }
184
- return res;
250
+ }
251
+
252
+ // src/lib/InferenceOutputError.ts
253
+ var InferenceOutputError = class extends TypeError {
254
+ constructor(message) {
255
+ super(
256
+ `Invalid inference output: ${message}. Use the 'request' method with the same parameters to do a custom call with no type checking.`
257
+ );
258
+ this.name = "InferenceOutputError";
185
259
  }
186
- /**
187
- * Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
188
- */
189
- async textGeneration(args, options) {
190
- const res = await this.request(args, options);
191
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.generated_text === "string");
192
- if (!isValidOutput) {
193
- throw new TypeError("Invalid inference output: output must be of type Array<generated_text: string>");
194
- }
195
- return res?.[0];
260
+ };
261
+
262
+ // src/tasks/audio/audioClassification.ts
263
+ async function audioClassification(args, options) {
264
+ const res = await request(args, options);
265
+ const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
266
+ if (!isValidOutput) {
267
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
196
268
  }
197
- /**
198
- * Use to continue text from a prompt. Same as `textGeneration` but returns generator that can be read one token at a time
199
- */
200
- async *textGenerationStream(args, options) {
201
- yield* this.streamingRequest(args, options);
269
+ return res;
270
+ }
271
+
272
+ // src/tasks/audio/automaticSpeechRecognition.ts
273
+ async function automaticSpeechRecognition(args, options) {
274
+ const res = await request(args, options);
275
+ const isValidOutput = typeof res?.text === "string";
276
+ if (!isValidOutput) {
277
+ throw new InferenceOutputError("Expected {text: string}");
202
278
  }
203
- /**
204
- * Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text. Recommended model: dbmdz/bert-large-cased-finetuned-conll03-english
205
- */
206
- async tokenClassification(args, options) {
207
- const res = toArray(await this.request(args, options));
208
- const isValidOutput = Array.isArray(res) && res.every(
209
- (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
210
- );
211
- if (!isValidOutput) {
212
- throw new TypeError(
213
- "Invalid inference output: output must be of type Array<end: number, entity_group: string, score: number, start: number, word: string>"
214
- );
215
- }
216
- return res;
279
+ return res;
280
+ }
281
+
282
+ // src/tasks/cv/imageClassification.ts
283
+ async function imageClassification(args, options) {
284
+ const res = await request(args, options);
285
+ const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
286
+ if (!isValidOutput) {
287
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
217
288
  }
218
- /**
219
- * This task is well known to translate text from one language to another. Recommended model: Helsinki-NLP/opus-mt-ru-en.
220
- */
221
- async translation(args, options) {
222
- const res = await this.request(args, options);
223
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.translation_text === "string");
224
- if (!isValidOutput) {
225
- throw new TypeError("Invalid inference output: output must be of type Array<translation_text: string>");
226
- }
227
- return res?.[0];
289
+ return res;
290
+ }
291
+
292
+ // src/tasks/cv/imageSegmentation.ts
293
+ async function imageSegmentation(args, options) {
294
+ const res = await request(args, options);
295
+ const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
296
+ if (!isValidOutput) {
297
+ throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
228
298
  }
229
- /**
230
- * This task is super useful to try out classification with zero code, you simply pass a sentence/paragraph and the possible labels for that sentence, and you get a result. Recommended model: facebook/bart-large-mnli.
231
- */
232
- async zeroShotClassification(args, options) {
233
- const res = toArray(
234
- await this.request(args, options)
299
+ return res;
300
+ }
301
+
302
+ // src/tasks/cv/imageToText.ts
303
+ async function imageToText(args, options) {
304
+ const res = (await request(args, options))?.[0];
305
+ if (typeof res?.generated_text !== "string") {
306
+ throw new InferenceOutputError("Expected {generated_text: string}");
307
+ }
308
+ return res;
309
+ }
310
+
311
+ // src/tasks/cv/objectDetection.ts
312
+ async function objectDetection(args, options) {
313
+ const res = await request(args, options);
314
+ const isValidOutput = Array.isArray(res) && res.every(
315
+ (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
316
+ );
317
+ if (!isValidOutput) {
318
+ throw new InferenceOutputError(
319
+ "Expected Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>"
235
320
  );
236
- const isValidOutput = Array.isArray(res) && res.every(
237
- (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
321
+ }
322
+ return res;
323
+ }
324
+
325
+ // src/tasks/cv/textToImage.ts
326
+ async function textToImage(args, options) {
327
+ const res = await request(args, options);
328
+ const isValidOutput = res && res instanceof Blob;
329
+ if (!isValidOutput) {
330
+ throw new InferenceOutputError("Expected Blob");
331
+ }
332
+ return res;
333
+ }
334
+
335
+ // src/tasks/nlp/conversational.ts
336
+ async function conversational(args, options) {
337
+ const res = await request(args, options);
338
+ const isValidOutput = Array.isArray(res.conversation.generated_responses) && res.conversation.generated_responses.every((x) => typeof x === "string") && Array.isArray(res.conversation.past_user_inputs) && res.conversation.past_user_inputs.every((x) => typeof x === "string") && typeof res.generated_text === "string" && Array.isArray(res.warnings) && res.warnings.every((x) => typeof x === "string");
339
+ if (!isValidOutput) {
340
+ throw new InferenceOutputError(
341
+ "Expected {conversation: {generated_responses: string[], past_user_inputs: string[]}, generated_text: string, warnings: string[]}"
238
342
  );
239
- if (!isValidOutput) {
240
- throw new TypeError(
241
- "Invalid inference output: output must be of type Array<labels: string[], scores: number[], sequence: string>"
242
- );
243
- }
244
- return res;
245
343
  }
246
- /**
247
- * This task corresponds to any chatbot like structure. Models tend to have shorter max_length, so please check with caution when using a given model if you need long range dependency or not. Recommended model: microsoft/DialoGPT-large.
248
- *
249
- */
250
- async conversational(args, options) {
251
- const res = await this.request(args, options);
252
- const isValidOutput = Array.isArray(res.conversation.generated_responses) && res.conversation.generated_responses.every((x) => typeof x === "string") && Array.isArray(res.conversation.past_user_inputs) && res.conversation.past_user_inputs.every((x) => typeof x === "string") && typeof res.generated_text === "string" && Array.isArray(res.warnings) && res.warnings.every((x) => typeof x === "string");
253
- if (!isValidOutput) {
254
- throw new TypeError(
255
- "Invalid inference output: output must be of type <conversation: {generated_responses: string[], past_user_inputs: string[]}, generated_text: string, warnings: string[]>"
256
- );
344
+ return res;
345
+ }
346
+
347
+ // src/tasks/nlp/featureExtraction.ts
348
+ async function featureExtraction(args, options) {
349
+ const res = await request(args, options);
350
+ let isValidOutput = true;
351
+ if (Array.isArray(res)) {
352
+ for (const e of res) {
353
+ if (Array.isArray(e)) {
354
+ isValidOutput = e.every((x) => typeof x === "number");
355
+ if (!isValidOutput) {
356
+ break;
357
+ }
358
+ } else if (typeof e !== "number") {
359
+ isValidOutput = false;
360
+ break;
361
+ }
257
362
  }
258
- return res;
363
+ } else {
364
+ isValidOutput = false;
259
365
  }
260
- /**
261
- * This task reads some text and outputs raw float values, that are usually consumed as part of a semantic database/semantic search.
262
- */
263
- async featureExtraction(args, options) {
264
- const res = await this.request(args, options);
265
- return res;
366
+ if (!isValidOutput) {
367
+ throw new InferenceOutputError("Expected Array<number[] | number>");
266
368
  }
267
- /**
268
- * This task reads some audio input and outputs the said words within the audio files.
269
- * Recommended model (english language): facebook/wav2vec2-large-960h-lv60-self
270
- */
271
- async automaticSpeechRecognition(args, options) {
272
- const res = await this.request(args, {
273
- ...options,
274
- binary: true
275
- });
276
- const isValidOutput = typeof res.text === "string";
277
- if (!isValidOutput) {
278
- throw new TypeError("Invalid inference output: output must be of type <text: string>");
279
- }
280
- return res;
369
+ return res;
370
+ }
371
+
372
+ // src/tasks/nlp/fillMask.ts
373
+ async function fillMask(args, options) {
374
+ const res = await request(args, options);
375
+ const isValidOutput = Array.isArray(res) && res.every(
376
+ (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
377
+ );
378
+ if (!isValidOutput) {
379
+ throw new InferenceOutputError(
380
+ "Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
381
+ );
281
382
  }
282
- /**
283
- * This task reads some audio input and outputs the likelihood of classes.
284
- * Recommended model: superb/hubert-large-superb-er
285
- */
286
- async audioClassification(args, options) {
287
- const res = await this.request(args, {
288
- ...options,
289
- binary: true
290
- });
291
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
292
- if (!isValidOutput) {
293
- throw new TypeError("Invalid inference output: output must be of type Array<label: string, score: number>");
294
- }
295
- return res;
383
+ return res;
384
+ }
385
+
386
+ // src/tasks/nlp/questionAnswering.ts
387
+ async function questionAnswering(args, options) {
388
+ const res = await request(args, options);
389
+ const isValidOutput = typeof res?.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";
390
+ if (!isValidOutput) {
391
+ throw new InferenceOutputError("Expected {answer: string, end: number, score: number, start: number}");
296
392
  }
297
- /**
298
- * This task reads some image input and outputs the likelihood of classes.
299
- * Recommended model: google/vit-base-patch16-224
300
- */
301
- async imageClassification(args, options) {
302
- const res = await this.request(args, {
303
- ...options,
304
- binary: true
305
- });
306
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
307
- if (!isValidOutput) {
308
- throw new TypeError("Invalid inference output: output must be of type Array<label: string, score: number>");
309
- }
310
- return res;
393
+ return res;
394
+ }
395
+
396
+ // src/tasks/nlp/sentenceSimilarity.ts
397
+ async function sentenceSimilarity(args, options) {
398
+ const res = await request(args, options);
399
+ const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
400
+ if (!isValidOutput) {
401
+ throw new InferenceOutputError("Expected number[]");
311
402
  }
312
- /**
313
- * This task reads some image input and outputs the likelihood of classes & bounding boxes of detected objects.
314
- * Recommended model: facebook/detr-resnet-50
315
- */
316
- async objectDetection(args, options) {
317
- const res = await this.request(args, {
318
- ...options,
319
- binary: true
320
- });
321
- const isValidOutput = Array.isArray(res) && res.every(
322
- (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
403
+ return res;
404
+ }
405
+
406
+ // src/tasks/nlp/summarization.ts
407
+ async function summarization(args, options) {
408
+ const res = await request(args, options);
409
+ const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string");
410
+ if (!isValidOutput) {
411
+ throw new InferenceOutputError("Expected Array<{summary_text: string}>");
412
+ }
413
+ return res?.[0];
414
+ }
415
+
416
+ // src/tasks/nlp/tableQuestionAnswering.ts
417
+ async function tableQuestionAnswering(args, options) {
418
+ const res = await request(args, options);
419
+ const isValidOutput = typeof res?.aggregator === "string" && typeof res.answer === "string" && Array.isArray(res.cells) && res.cells.every((x) => typeof x === "string") && Array.isArray(res.coordinates) && res.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number"));
420
+ if (!isValidOutput) {
421
+ throw new InferenceOutputError(
422
+ "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
323
423
  );
324
- if (!isValidOutput) {
325
- throw new TypeError(
326
- "Invalid inference output: output must be of type Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>"
327
- );
328
- }
329
- return res;
330
424
  }
331
- /**
332
- * This task reads some image input and outputs the likelihood of classes & bounding boxes of detected objects.
333
- * Recommended model: facebook/detr-resnet-50-panoptic
334
- */
335
- async imageSegmentation(args, options) {
336
- const res = await this.request(args, {
337
- ...options,
338
- binary: true
339
- });
340
- const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
341
- if (!isValidOutput) {
342
- throw new TypeError(
343
- "Invalid inference output: output must be of type Array<label: string, mask: string, score: number>"
344
- );
345
- }
346
- return res;
425
+ return res;
426
+ }
427
+
428
+ // src/tasks/nlp/textClassification.ts
429
+ async function textClassification(args, options) {
430
+ const res = (await request(args, options))?.[0];
431
+ const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number");
432
+ if (!isValidOutput) {
433
+ throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
347
434
  }
348
- /**
349
- * This task reads some text input and outputs an image.
350
- * Recommended model: stabilityai/stable-diffusion-2
351
- */
352
- async textToImage(args, options) {
353
- const res = await this.request(args, {
354
- ...options,
355
- blob: true
356
- });
357
- const isValidOutput = res && res instanceof Blob;
358
- if (!isValidOutput) {
359
- throw new TypeError("Invalid inference output: output must be of type object & of instance Blob");
360
- }
361
- return res;
435
+ return res;
436
+ }
437
+
438
+ // src/tasks/nlp/textGeneration.ts
439
+ async function textGeneration(args, options) {
440
+ const res = await request(args, options);
441
+ const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string");
442
+ if (!isValidOutput) {
443
+ throw new InferenceOutputError("Expected Array<{generated_text: string}>");
362
444
  }
363
- /**
364
- * This task reads some image input and outputs the text caption.
365
- */
366
- async imageToText(args, options) {
367
- return (await this.request(args, {
368
- ...options,
369
- binary: true
370
- }))?.[0];
445
+ return res?.[0];
446
+ }
447
+
448
+ // src/tasks/nlp/textGenerationStream.ts
449
+ async function* textGenerationStream(args, options) {
450
+ yield* streamingRequest(args, options);
451
+ }
452
+
453
+ // src/utils/toArray.ts
454
+ function toArray(obj) {
455
+ if (Array.isArray(obj)) {
456
+ return obj;
371
457
  }
372
- /**
373
- * Helper that prepares request arguments
374
- */
375
- makeRequestOptions(args, options) {
376
- const mergedOptions = { ...this.defaultOptions, ...options };
377
- const { model, ...otherArgs } = args;
378
- const headers = {};
379
- if (this.apiKey) {
380
- headers["Authorization"] = `Bearer ${this.apiKey}`;
381
- }
382
- if (!options?.binary) {
383
- headers["Content-Type"] = "application/json";
384
- }
385
- if (options?.binary) {
386
- if (mergedOptions.wait_for_model) {
387
- headers["X-Wait-For-Model"] = "true";
388
- }
389
- if (mergedOptions.use_cache === false) {
390
- headers["X-Use-Cache"] = "false";
391
- }
392
- if (mergedOptions.dont_load_model) {
393
- headers["X-Load-Model"] = "0";
394
- }
395
- }
396
- const url = `${HF_INFERENCE_API_BASE_URL}${model}`;
397
- const info = {
398
- headers,
399
- method: "POST",
400
- body: options?.binary ? args.data : JSON.stringify({
401
- ...otherArgs,
402
- options: mergedOptions
403
- }),
404
- credentials: options?.includeCredentials ? "include" : "same-origin"
405
- };
406
- return { url, info, mergedOptions };
407
- }
408
- async request(args, options) {
409
- const { url, info, mergedOptions } = this.makeRequestOptions(args, options);
410
- const response = await fetch(url, info);
411
- if (mergedOptions.retry_on_error !== false && response.status === 503 && !mergedOptions.wait_for_model) {
412
- return this.request(args, {
413
- ...mergedOptions,
414
- wait_for_model: true
458
+ return [obj];
459
+ }
460
+
461
+ // src/tasks/nlp/tokenClassification.ts
462
+ async function tokenClassification(args, options) {
463
+ const res = toArray(await request(args, options));
464
+ const isValidOutput = Array.isArray(res) && res.every(
465
+ (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
466
+ );
467
+ if (!isValidOutput) {
468
+ throw new InferenceOutputError(
469
+ "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
470
+ );
471
+ }
472
+ return res;
473
+ }
474
+
475
+ // src/tasks/nlp/translation.ts
476
+ async function translation(args, options) {
477
+ const res = await request(args, options);
478
+ const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string");
479
+ if (!isValidOutput) {
480
+ throw new InferenceOutputError("Expected type Array<{translation_text: string}>");
481
+ }
482
+ return res?.[0];
483
+ }
484
+
485
+ // src/tasks/nlp/zeroShotClassification.ts
486
+ async function zeroShotClassification(args, options) {
487
+ const res = toArray(
488
+ await request(args, options)
489
+ );
490
+ const isValidOutput = Array.isArray(res) && res.every(
491
+ (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
492
+ );
493
+ if (!isValidOutput) {
494
+ throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
495
+ }
496
+ return res;
497
+ }
498
+
499
+ // src/HfInference.ts
500
+ var HfInference = class {
501
+ accessToken;
502
+ defaultOptions;
503
+ constructor(accessToken = "", defaultOptions = {}) {
504
+ this.accessToken = accessToken;
505
+ this.defaultOptions = defaultOptions;
506
+ for (const [name, fn] of Object.entries(tasks_exports)) {
507
+ Object.defineProperty(this, name, {
508
+ enumerable: false,
509
+ value: (params, options) => (
510
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
511
+ fn({ ...params, accessToken }, { ...defaultOptions, ...options })
512
+ )
415
513
  });
416
514
  }
417
- if (options?.blob) {
418
- if (!response.ok) {
419
- throw new Error("An error occurred while fetching the blob");
420
- }
421
- return await response.blob();
422
- }
423
- const output = await response.json();
424
- if (output.error) {
425
- throw new Error(output.error);
426
- }
427
- return output;
428
515
  }
429
516
  /**
430
- * Make request that uses server-sent events and returns response as a generator
517
+ * Returns copy of HfInference tied to a specified endpoint.
431
518
  */
432
- async *streamingRequest(args, options) {
433
- const { url, info, mergedOptions } = this.makeRequestOptions({ ...args, stream: true }, options);
434
- const response = await fetch(url, info);
435
- if (mergedOptions.retry_on_error !== false && response.status === 503 && !mergedOptions.wait_for_model) {
436
- return this.streamingRequest(args, {
437
- ...mergedOptions,
438
- wait_for_model: true
519
+ endpoint(endpointUrl) {
520
+ return new HfInferenceEndpoint(endpointUrl, this.accessToken, this.defaultOptions);
521
+ }
522
+ };
523
+ var HfInferenceEndpoint = class {
524
+ constructor(endpointUrl, accessToken = "", defaultOptions = {}) {
525
+ accessToken;
526
+ defaultOptions;
527
+ for (const [name, fn] of Object.entries(tasks_exports)) {
528
+ Object.defineProperty(this, name, {
529
+ enumerable: false,
530
+ value: (params, options) => (
531
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
532
+ fn({ ...params, accessToken, model: endpointUrl }, { ...defaultOptions, ...options })
533
+ )
439
534
  });
440
535
  }
441
- if (!response.ok) {
442
- if (response.headers.get("Content-Type")?.startsWith("application/json")) {
443
- const output = await response.json();
444
- if (output.error) {
445
- throw new Error(output.error);
446
- }
447
- }
448
- throw new Error(`Server response contains error: ${response.status}`);
449
- }
450
- if (response.headers.get("content-type") !== "text/event-stream") {
451
- throw new Error(
452
- `Server does not support event stream content type, it returned ` + response.headers.get("content-type")
453
- );
454
- }
455
- if (!response.body) {
456
- return;
457
- }
458
- const reader = response.body.getReader();
459
- let events = [];
460
- const onEvent = (event) => {
461
- events.push(event);
462
- };
463
- const onChunk = getLines(
464
- getMessages(
465
- () => {
466
- },
467
- () => {
468
- },
469
- onEvent
470
- )
471
- );
472
- try {
473
- while (true) {
474
- const { done, value } = await reader.read();
475
- if (done)
476
- return;
477
- onChunk(value);
478
- for (const event of events) {
479
- if (event.data.length > 0) {
480
- yield JSON.parse(event.data);
481
- }
482
- }
483
- events = [];
484
- }
485
- } finally {
486
- reader.releaseLock();
487
- }
488
536
  }
489
537
  };
490
538
  export {
491
539
  HfInference,
492
- TextGenerationStreamFinishReason
540
+ HfInferenceEndpoint,
541
+ audioClassification,
542
+ automaticSpeechRecognition,
543
+ conversational,
544
+ featureExtraction,
545
+ fillMask,
546
+ imageClassification,
547
+ imageSegmentation,
548
+ imageToText,
549
+ objectDetection,
550
+ questionAnswering,
551
+ request,
552
+ sentenceSimilarity,
553
+ streamingRequest,
554
+ summarization,
555
+ tableQuestionAnswering,
556
+ textClassification,
557
+ textGeneration,
558
+ textGenerationStream,
559
+ textToImage,
560
+ tokenClassification,
561
+ translation,
562
+ zeroShotClassification
493
563
  };