@huggingface/inference 1.6.2 → 1.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +7 -0
- package/dist/index.d.ts +94 -1
- package/dist/index.js +299 -22
- package/dist/index.mjs +297 -21
- package/package.json +1 -1
- package/src/HfInference.ts +355 -23
- package/src/vendor/fetch-event-source/parse.spec.ts +389 -0
- package/src/vendor/fetch-event-source/parse.ts +216 -0
package/README.md
CHANGED
```diff
@@ -76,6 +76,13 @@ await hf.textGeneration({
   inputs: 'The answer to the universe is'
 })
 
+for await (const output of hf.textGenerationStream({
+  model: "google/flan-t5-xxl",
+  inputs: 'repeat "one two three four"'
+})) {
+  console.log(output.token.text, output.generated_text);
+}
+
 await hf.tokenClassification({
   model: 'dbmdz/bert-large-cased-finetuned-conll03-english',
   inputs: 'My name is Sarah Jessica Parker but you can call me Jessica'
```
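The added README example prints each token as it arrives. As a small companion sketch built on the same call (the model and prompt are the ones from the example; `HF_ACCESS_TOKEN` is a placeholder for a real token), the stream can instead be accumulated into the final string:

```js
import { HfInference } from '@huggingface/inference'

// HF_ACCESS_TOKEN is a placeholder for your Hugging Face API token
const hf = new HfInference(HF_ACCESS_TOKEN)

let text = ''
for await (const output of hf.textGenerationStream({
  model: 'google/flan-t5-xxl',
  inputs: 'repeat "one two three four"'
})) {
  // skip special tokens (e.g. end-of-sequence markers) when concatenating
  if (!output.token.special) {
    text += output.token.text
  }
}
console.log(text)
```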
package/dist/index.d.ts
CHANGED
```diff
@@ -206,6 +206,80 @@ interface TextGenerationReturn {
      */
     generated_text: string;
 }
+interface TextGenerationStreamToken {
+    /** Token ID from the model tokenizer */
+    id: number;
+    /** Token text */
+    text: string;
+    /** Logprob */
+    logprob: number;
+    /**
+     * Is the token a special token
+     * Can be used to ignore tokens when concatenating
+     */
+    special: boolean;
+}
+interface TextGenerationStreamPrefillToken {
+    /** Token ID from the model tokenizer */
+    id: number;
+    /** Token text */
+    text: string;
+    /**
+     * Logprob
+     * Optional since the logprob of the first token cannot be computed
+     */
+    logprob?: number;
+}
+interface TextGenerationStreamBestOfSequence {
+    /** Generated text */
+    generated_text: string;
+    /** Generation finish reason */
+    finish_reason: TextGenerationStreamFinishReason;
+    /** Number of generated tokens */
+    generated_tokens: number;
+    /** Sampling seed if sampling was activated */
+    seed?: number;
+    /** Prompt tokens */
+    prefill: TextGenerationStreamPrefillToken[];
+    /** Generated tokens */
+    tokens: TextGenerationStreamToken[];
+}
+declare enum TextGenerationStreamFinishReason {
+    /** number of generated tokens == `max_new_tokens` */
+    Length = "length",
+    /** the model generated its end of sequence token */
+    EndOfSequenceToken = "eos_token",
+    /** the model generated a text included in `stop_sequences` */
+    StopSequence = "stop_sequence"
+}
+interface TextGenerationStreamDetails {
+    /** Generation finish reason */
+    finish_reason: TextGenerationStreamFinishReason;
+    /** Number of generated tokens */
+    generated_tokens: number;
+    /** Sampling seed if sampling was activated */
+    seed?: number;
+    /** Prompt tokens */
+    prefill: TextGenerationStreamPrefillToken[];
+    /** Generated tokens */
+    tokens: TextGenerationStreamToken[];
+    /** Additional sequences when using the `best_of` parameter */
+    best_of_sequences?: TextGenerationStreamBestOfSequence[];
+}
+interface TextGenerationStreamReturn {
+    /** Generated token, one at a time */
+    token: TextGenerationStreamToken;
+    /**
+     * Complete generated text
+     * Only available when the generation is finished
+     */
+    generated_text?: string;
+    /**
+     * Generation details
+     * Only available when the generation is finished
+     */
+    details?: TextGenerationStreamDetails;
+}
 type TokenClassificationArgs = Args & {
     /**
      * A string to be classified
@@ -486,6 +560,10 @@ declare class HfInference {
      * Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
      */
     textGeneration(args: TextGenerationArgs, options?: Options): Promise<TextGenerationReturn>;
+    /**
+     * Use to continue text from a prompt. Same as `textGeneration` but returns a generator that can be read one token at a time
+     */
+    textGenerationStream(args: TextGenerationArgs, options?: Options): AsyncGenerator<TextGenerationStreamReturn>;
     /**
      * Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text. Recommended model: dbmdz/bert-large-cased-finetuned-conll03-english
      */
@@ -537,6 +615,10 @@ declare class HfInference {
      * Recommended model: stabilityai/stable-diffusion-2
      */
     textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageReturn>;
+    /**
+     * Helper that prepares request arguments
+     */
+    private makeRequestOptions;
     request<T>(args: Args & {
         data?: Blob | ArrayBuffer;
     }, options?: Options & {
@@ -545,6 +627,17 @@ declare class HfInference {
         /** For internal HF use, which is why it's not exposed in {@link Options} */
         includeCredentials?: boolean;
     }): Promise<T>;
+    /**
+     * Make a request that uses server-sent events and returns the response as a generator
+     */
+    streamingRequest<T>(args: Args & {
+        data?: Blob | ArrayBuffer;
+    }, options?: Options & {
+        binary?: boolean;
+        blob?: boolean;
+        /** For internal HF use, which is why it's not exposed in {@link Options} */
+        includeCredentials?: boolean;
+    }): AsyncGenerator<T>;
 }
 
-export { Args, AudioClassificationArgs, AudioClassificationReturn, AudioClassificationReturnValue, AutomaticSpeechRecognitionArgs, AutomaticSpeechRecognitionReturn, ConversationalArgs, ConversationalReturn, FeatureExtractionArgs, FeatureExtractionReturn, FillMaskArgs, FillMaskReturn, HfInference, ImageClassificationArgs, ImageClassificationReturn, ImageClassificationReturnValue, ImageSegmentationArgs, ImageSegmentationReturn, ImageSegmentationReturnValue, ObjectDetectionArgs, ObjectDetectionReturn, ObjectDetectionReturnValue, Options, QuestionAnswerArgs, QuestionAnswerReturn, SummarizationArgs, SummarizationReturn, TableQuestionAnswerArgs, TableQuestionAnswerReturn, TextClassificationArgs, TextClassificationReturn, TextGenerationArgs, TextGenerationReturn, TextToImageArgs, TextToImageReturn, TokenClassificationArgs, TokenClassificationReturn, TokenClassificationReturnValue, TranslationArgs, TranslationReturn, ZeroShotClassificationArgs, ZeroShotClassificationReturn, ZeroShotClassificationReturnValue };
+export { Args, AudioClassificationArgs, AudioClassificationReturn, AudioClassificationReturnValue, AutomaticSpeechRecognitionArgs, AutomaticSpeechRecognitionReturn, ConversationalArgs, ConversationalReturn, FeatureExtractionArgs, FeatureExtractionReturn, FillMaskArgs, FillMaskReturn, HfInference, ImageClassificationArgs, ImageClassificationReturn, ImageClassificationReturnValue, ImageSegmentationArgs, ImageSegmentationReturn, ImageSegmentationReturnValue, ObjectDetectionArgs, ObjectDetectionReturn, ObjectDetectionReturnValue, Options, QuestionAnswerArgs, QuestionAnswerReturn, SummarizationArgs, SummarizationReturn, TableQuestionAnswerArgs, TableQuestionAnswerReturn, TextClassificationArgs, TextClassificationReturn, TextGenerationArgs, TextGenerationReturn, TextGenerationStreamBestOfSequence, TextGenerationStreamDetails, TextGenerationStreamFinishReason, TextGenerationStreamPrefillToken, TextGenerationStreamReturn, TextGenerationStreamToken, TextToImageArgs, TextToImageReturn, TokenClassificationArgs, TokenClassificationReturn, TokenClassificationReturnValue, TranslationArgs, TranslationReturn, ZeroShotClassificationArgs, ZeroShotClassificationReturn, ZeroShotClassificationReturnValue };
```
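Read together, these declarations spell out the streaming contract: every yielded `TextGenerationStreamReturn` carries a `token`, while `generated_text` and `details` are only populated on the final event. A short TypeScript sketch of branching on the exported finish-reason enum (model, prompt, and `HF_ACCESS_TOKEN` are placeholders):

```ts
import { HfInference, TextGenerationStreamFinishReason } from '@huggingface/inference'

declare const HF_ACCESS_TOKEN: string // placeholder for a real token

const hf = new HfInference(HF_ACCESS_TOKEN)

for await (const output of hf.textGenerationStream({
  model: 'google/flan-t5-xxl',
  inputs: 'The answer to the universe is'
})) {
  // `details` is only set on the last event of the stream
  if (output.details) {
    if (output.details.finish_reason === TextGenerationStreamFinishReason.Length) {
      console.warn(`stopped after ${output.details.generated_tokens} tokens (max_new_tokens reached)`)
    } else {
      console.log(output.generated_text)
    }
  }
}
```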
package/dist/index.js
CHANGED
```diff
@@ -19,7 +19,8 @@ var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: true }), mod);
 // src/index.ts
 var src_exports = {};
 __export(src_exports, {
-  HfInference: () => HfInference
+  HfInference: () => HfInference,
+  TextGenerationStreamFinishReason: () => TextGenerationStreamFinishReason
 });
 module.exports = __toCommonJS(src_exports);
 
@@ -31,7 +32,113 @@ function toArray(obj) {
   return [obj];
 }
 
+// src/vendor/fetch-event-source/parse.ts
+function getLines(onLine) {
+  let buffer;
+  let position;
+  let fieldLength;
+  let discardTrailingNewline = false;
+  return function onChunk(arr) {
+    if (buffer === void 0) {
+      buffer = arr;
+      position = 0;
+      fieldLength = -1;
+    } else {
+      buffer = concat(buffer, arr);
+    }
+    const bufLength = buffer.length;
+    let lineStart = 0;
+    while (position < bufLength) {
+      if (discardTrailingNewline) {
+        if (buffer[position] === 10 /* NewLine */) {
+          lineStart = ++position;
+        }
+        discardTrailingNewline = false;
+      }
+      let lineEnd = -1;
+      for (; position < bufLength && lineEnd === -1; ++position) {
+        switch (buffer[position]) {
+          case 58 /* Colon */:
+            if (fieldLength === -1) {
+              fieldLength = position - lineStart;
+            }
+            break;
+          case 13 /* CarriageReturn */:
+            discardTrailingNewline = true;
+          case 10 /* NewLine */:
+            lineEnd = position;
+            break;
+        }
+      }
+      if (lineEnd === -1) {
+        break;
+      }
+      onLine(buffer.subarray(lineStart, lineEnd), fieldLength);
+      lineStart = position;
+      fieldLength = -1;
+    }
+    if (lineStart === bufLength) {
+      buffer = void 0;
+    } else if (lineStart !== 0) {
+      buffer = buffer.subarray(lineStart);
+      position -= lineStart;
+    }
+  };
+}
+function getMessages(onId, onRetry, onMessage) {
+  let message = newMessage();
+  const decoder = new TextDecoder();
+  return function onLine(line, fieldLength) {
+    if (line.length === 0) {
+      onMessage?.(message);
+      message = newMessage();
+    } else if (fieldLength > 0) {
+      const field = decoder.decode(line.subarray(0, fieldLength));
+      const valueOffset = fieldLength + (line[fieldLength + 1] === 32 /* Space */ ? 2 : 1);
+      const value = decoder.decode(line.subarray(valueOffset));
+      switch (field) {
+        case "data":
+          message.data = message.data ? message.data + "\n" + value : value;
+          break;
+        case "event":
+          message.event = value;
+          break;
+        case "id":
+          onId(message.id = value);
+          break;
+        case "retry":
+          const retry = parseInt(value, 10);
+          if (!isNaN(retry)) {
+            onRetry(message.retry = retry);
+          }
+          break;
+      }
+    }
+  };
+}
+function concat(a, b) {
+  const res = new Uint8Array(a.length + b.length);
+  res.set(a);
+  res.set(b, a.length);
+  return res;
+}
+function newMessage() {
+  return {
+    data: "",
+    event: "",
+    id: "",
+    retry: void 0
+  };
+}
+
 // src/HfInference.ts
+var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co/models/";
+var TextGenerationStreamFinishReason = /* @__PURE__ */ ((TextGenerationStreamFinishReason2) => {
+  TextGenerationStreamFinishReason2["Length"] = "length";
+  TextGenerationStreamFinishReason2["EndOfSequenceToken"] = "eos_token";
+  TextGenerationStreamFinishReason2["StopSequence"] = "stop_sequence";
+  return TextGenerationStreamFinishReason2;
+})(TextGenerationStreamFinishReason || {});
 var HfInference = class {
   apiKey;
   defaultOptions;
@@ -43,132 +150,246 @@ var HfInference = class {
    * Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
    */
   async fillMask(args, options) {
-    return await this.request(args, options);
+    const res = await this.request(args, options);
+    const isValidOutput = Array.isArray(res) && res.every(
+      (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
+    );
+    if (!isValidOutput) {
+      throw new TypeError(
+        "Invalid inference output: output must be of type Array<score: number, sequence:string, token:number, token_str:string>"
+      );
+    }
+    return res;
   }
   /**
    * This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.
    */
   async summarization(args, options) {
-    return (await this.request(args, options))?.[0];
+    const res = await this.request(args, options);
+    const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.summary_text === "string");
+    if (!isValidOutput) {
+      throw new TypeError("Invalid inference output: output must be of type Array<summary_text: string>");
+    }
+    return res?.[0];
   }
   /**
    * Want to have a nice know-it-all bot that can answer any question?. Recommended model: deepset/roberta-base-squad2
    */
   async questionAnswer(args, options) {
-    return await this.request(args, options);
+    const res = await this.request(args, options);
+    const isValidOutput = typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";
+    if (!isValidOutput) {
+      throw new TypeError(
+        "Invalid inference output: output must be of type <answer: string, end: number, score: number, start: number>"
+      );
+    }
+    return res;
   }
   /**
    * Don’t know SQL? Don’t want to dive into a large spreadsheet? Ask questions in plain english! Recommended model: google/tapas-base-finetuned-wtq.
    */
   async tableQuestionAnswer(args, options) {
-    return await this.request(args, options);
+    const res = await this.request(args, options);
+    const isValidOutput = typeof res.aggregator === "string" && typeof res.answer === "string" && Array.isArray(res.cells) && res.cells.every((x) => typeof x === "string") && Array.isArray(res.coordinates) && res.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number"));
+    if (!isValidOutput) {
+      throw new TypeError(
+        "Invalid inference output: output must be of type <aggregator: string, answer: string, cells: string[], coordinates: number[][]>"
+      );
+    }
+    return res;
   }
   /**
    * Usually used for sentiment-analysis this will output the likelihood of classes of an input. Recommended model: distilbert-base-uncased-finetuned-sst-2-english
    */
   async textClassification(args, options) {
-    return (await this.request(args, options))?.[0];
+    const res = (await this.request(args, options))?.[0];
+    const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
+    if (!isValidOutput) {
+      throw new TypeError("Invalid inference output: output must be of type Array<label: string, score: number>");
+    }
+    return res;
   }
   /**
    * Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
    */
   async textGeneration(args, options) {
-    return (await this.request(args, options))?.[0];
+    const res = await this.request(args, options);
+    const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.generated_text === "string");
+    if (!isValidOutput) {
+      throw new TypeError("Invalid inference output: output must be of type Array<generated_text: string>");
+    }
+    return res?.[0];
+  }
+  /**
+   * Use to continue text from a prompt. Same as `textGeneration` but returns a generator that can be read one token at a time
+   */
+  async *textGenerationStream(args, options) {
+    yield* this.streamingRequest(args, options);
   }
   /**
    * Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text. Recommended model: dbmdz/bert-large-cased-finetuned-conll03-english
    */
   async tokenClassification(args, options) {
-    return toArray(await this.request(args, options));
+    const res = toArray(await this.request(args, options));
+    const isValidOutput = Array.isArray(res) && res.every(
+      (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
+    );
+    if (!isValidOutput) {
+      throw new TypeError(
+        "Invalid inference output: output must be of type Array<end: number, entity_group: string, score: number, start: number, word: string>"
+      );
+    }
+    return res;
   }
   /**
    * This task is well known to translate text from one language to another. Recommended model: Helsinki-NLP/opus-mt-ru-en.
    */
   async translation(args, options) {
-    return (await this.request(args, options))?.[0];
+    const res = await this.request(args, options);
+    const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.translation_text === "string");
+    if (!isValidOutput) {
+      throw new TypeError("Invalid inference output: output must be of type Array<translation_text: string>");
+    }
+    return res?.[0];
   }
   /**
    * This task is super useful to try out classification with zero code, you simply pass a sentence/paragraph and the possible labels for that sentence, and you get a result. Recommended model: facebook/bart-large-mnli.
    */
   async zeroShotClassification(args, options) {
-    return toArray(
+    const res = toArray(
       await this.request(args, options)
     );
+    const isValidOutput = Array.isArray(res) && res.every(
+      (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
+    );
+    if (!isValidOutput) {
+      throw new TypeError(
+        "Invalid inference output: output must be of type Array<labels: string[], scores: number[], sequence: string>"
+      );
+    }
+    return res;
   }
   /**
    * This task corresponds to any chatbot like structure. Models tend to have shorter max_length, so please check with caution when using a given model if you need long range dependency or not. Recommended model: microsoft/DialoGPT-large.
    *
    */
   async conversational(args, options) {
-    return await this.request(args, options);
+    const res = await this.request(args, options);
+    const isValidOutput = Array.isArray(res.conversation.generated_responses) && res.conversation.generated_responses.every((x) => typeof x === "string") && Array.isArray(res.conversation.past_user_inputs) && res.conversation.past_user_inputs.every((x) => typeof x === "string") && typeof res.generated_text === "string" && Array.isArray(res.warnings) && res.warnings.every((x) => typeof x === "string");
+    if (!isValidOutput) {
+      throw new TypeError(
+        "Invalid inference output: output must be of type <conversation: {generated_responses: string[], past_user_inputs: string[]}, generated_text: string, warnings: string[]>"
+      );
+    }
+    return res;
   }
   /**
    * This task reads some text and outputs raw float values, that are usually consumed as part of a semantic database/semantic search.
    */
   async featureExtraction(args, options) {
-    return await this.request(args, options);
+    const res = await this.request(args, options);
+    return res;
   }
   /**
    * This task reads some audio input and outputs the said words within the audio files.
    * Recommended model (english language): facebook/wav2vec2-large-960h-lv60-self
    */
   async automaticSpeechRecognition(args, options) {
-    return await this.request(args, {
+    const res = await this.request(args, {
       ...options,
       binary: true
     });
+    const isValidOutput = typeof res.text === "string";
+    if (!isValidOutput) {
+      throw new TypeError("Invalid inference output: output must be of type <text: string>");
+    }
+    return res;
   }
   /**
    * This task reads some audio input and outputs the likelihood of classes.
    * Recommended model: superb/hubert-large-superb-er
    */
   async audioClassification(args, options) {
-    return await this.request(args, {
+    const res = await this.request(args, {
       ...options,
       binary: true
     });
+    const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
+    if (!isValidOutput) {
+      throw new TypeError("Invalid inference output: output must be of type Array<label: string, score: number>");
+    }
+    return res;
   }
   /**
    * This task reads some image input and outputs the likelihood of classes.
    * Recommended model: google/vit-base-patch16-224
    */
   async imageClassification(args, options) {
-    return await this.request(args, {
+    const res = await this.request(args, {
       ...options,
       binary: true
     });
+    const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
+    if (!isValidOutput) {
+      throw new TypeError("Invalid inference output: output must be of type Array<label: string, score: number>");
+    }
+    return res;
   }
   /**
    * This task reads some image input and outputs the likelihood of classes & bounding boxes of detected objects.
    * Recommended model: facebook/detr-resnet-50
    */
   async objectDetection(args, options) {
-    return await this.request(args, {
+    const res = await this.request(args, {
       ...options,
       binary: true
     });
+    const isValidOutput = Array.isArray(res) && res.every(
+      (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
+    );
+    if (!isValidOutput) {
+      throw new TypeError(
+        "Invalid inference output: output must be of type Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>"
+      );
+    }
+    return res;
   }
   /**
    * This task reads some image input and outputs the likelihood of classes & bounding boxes of detected objects.
    * Recommended model: facebook/detr-resnet-50-panoptic
    */
   async imageSegmentation(args, options) {
-    return await this.request(args, {
+    const res = await this.request(args, {
       ...options,
       binary: true
     });
+    const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
+    if (!isValidOutput) {
+      throw new TypeError(
+        "Invalid inference output: output must be of type Array<label: string, mask: string, score: number>"
+      );
+    }
+    return res;
   }
   /**
    * This task reads some text input and outputs an image.
    * Recommended model: stabilityai/stable-diffusion-2
    */
   async textToImage(args, options) {
-    return await this.request(args, {
+    const res = await this.request(args, {
       ...options,
       blob: true
     });
+    const isValidOutput = res && res instanceof Blob;
+    if (!isValidOutput) {
+      throw new TypeError("Invalid inference output: output must be of type object & of instance Blob");
+    }
+    return res;
   }
-  async request(args, options) {
+  /**
+   * Helper that prepares request arguments
+   */
+  makeRequestOptions(args, options) {
     const mergedOptions = { ...this.defaultOptions, ...options };
     const { model, ...otherArgs } = args;
     const headers = {};
@@ -189,7 +410,8 @@
         headers["X-Load-Model"] = "0";
       }
     }
-    const response = await fetch(`https://api-inference.huggingface.co/models/${model}`, {
+    const url = `${HF_INFERENCE_API_BASE_URL}${model}`;
+    const info = {
       headers,
       method: "POST",
       body: options?.binary ? args.data : JSON.stringify({
@@ -197,7 +419,12 @@
         options: mergedOptions
       }),
       credentials: options?.includeCredentials ? "include" : "same-origin"
-    });
+    };
+    return { url, info, mergedOptions };
+  }
+  async request(args, options) {
+    const { url, info, mergedOptions } = this.makeRequestOptions(args, options);
+    const response = await fetch(url, info);
     if (mergedOptions.retry_on_error !== false && response.status === 503 && !mergedOptions.wait_for_model) {
       return this.request(args, {
         ...mergedOptions,
@@ -216,8 +443,58 @@
     }
     return output;
   }
+  /**
+   * Make a request that uses server-sent events and returns the response as a generator
+   */
+  async *streamingRequest(args, options) {
+    const { url, info, mergedOptions } = this.makeRequestOptions({ ...args, stream: true }, options);
+    const response = await fetch(url, info);
+    if (mergedOptions.retry_on_error !== false && response.status === 503 && !mergedOptions.wait_for_model) {
+      return this.streamingRequest(args, {
+        ...mergedOptions,
+        wait_for_model: true
+      });
+    }
+    if (!response.ok) {
+      throw new Error(`Server response contains error: ${response.status}`);
+    }
+    if (response.headers.get("content-type") !== "text/event-stream") {
+      throw new Error(`Server does not support event stream content type`);
+    }
+    const reader = response.body.getReader();
+    const events = [];
+    const onEvent = (event) => {
+      events.push(event);
+    };
+    const onChunk = getLines(
+      getMessages(
+        () => {
+        },
+        () => {
+        },
+        onEvent
+      )
+    );
+    try {
+      while (true) {
+        const { done, value } = await reader.read();
+        if (done)
+          return;
+        onChunk(value);
+        while (events.length > 0) {
+          const event = events.shift();
+          if (event.data.length > 0) {
+            yield JSON.parse(event.data);
+          }
+        }
+      }
+    } finally {
+      reader.releaseLock();
+    }
+  }
 };
 // Annotate the CommonJS export names for ESM import in node:
 0 && (module.exports = {
-  HfInference
+  HfInference,
+  TextGenerationStreamFinishReason
 });
```