@huggingface/inference 1.6.3 → 1.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,8 @@
1
1
  import { toArray } from "./utils/to-array";
2
+ import type { EventSourceMessage } from "./vendor/fetch-event-source/parse";
3
+ import { getLines, getMessages } from "./vendor/fetch-event-source/parse";
4
+
5
/** Base URL of the Inference API; request URLs are built by appending the model id to it. */
const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co/models/";
2
6
 
3
7
  export interface Options {
4
8
  /**
@@ -223,6 +227,86 @@ export interface TextGenerationReturn {
223
227
  generated_text: string;
224
228
  }
225
229
 
230
/** A single token produced during streamed text generation. */
export interface TextGenerationStreamToken {
	/** Token ID from the model tokenizer */
	id: number;
	/** Token text */
	text: string;
	/** Log-probability of the token */
	logprob: number;
	/**
	 * Is the token a special token
	 * Can be used to ignore tokens when concatenating
	 */
	special: boolean;
}

/** A prompt (prefill) token, as reported in the generation details. */
export interface TextGenerationStreamPrefillToken {
	/** Token ID from the model tokenizer */
	id: number;
	/** Token text */
	text: string;
	/**
	 * Log-probability of the token
	 * Optional since the logprob of the first token cannot be computed
	 */
	logprob?: number;
}

/** One of the additional sequences returned when the `best_of` parameter is used. */
export interface TextGenerationStreamBestOfSequence {
	/** Generated text */
	generated_text: string;
	/** Generation finish reason */
	finish_reason: TextGenerationStreamFinishReason;
	/** Number of generated tokens */
	generated_tokens: number;
	/** Sampling seed if sampling was activated */
	seed?: number;
	/** Prompt tokens */
	prefill: TextGenerationStreamPrefillToken[];
	/** Generated tokens */
	tokens: TextGenerationStreamToken[];
}

/**
 * Why text generation stopped.
 * Kept as a string enum (not a union) because it is part of the exported runtime API.
 */
export enum TextGenerationStreamFinishReason {
	/** number of generated tokens == `max_new_tokens` */
	Length = "length",
	/** the model generated its end of sequence token */
	EndOfSequenceToken = "eos_token",
	/** the model generated a text included in `stop_sequences` */
	StopSequence = "stop_sequence",
}

/** Details of a generation; only present once the generation is finished. */
export interface TextGenerationStreamDetails {
	/** Generation finish reason */
	finish_reason: TextGenerationStreamFinishReason;
	/** Number of generated tokens */
	generated_tokens: number;
	/** Sampling seed if sampling was activated */
	seed?: number;
	/** Prompt tokens */
	prefill: TextGenerationStreamPrefillToken[];
	/** Generated tokens */
	tokens: TextGenerationStreamToken[];
	/** Additional sequences when using the `best_of` parameter */
	best_of_sequences?: TextGenerationStreamBestOfSequence[];
}

/** One server-sent message yielded by `textGenerationStream`. */
export interface TextGenerationStreamReturn {
	/** Generated token, one at a time */
	token: TextGenerationStreamToken;
	/**
	 * Complete generated text
	 * Only available when the generation is finished
	 */
	generated_text?: string;
	/**
	 * Generation details
	 * Only available when the generation is finished
	 */
	details?: TextGenerationStreamDetails;
}
309
+
226
310
  export type TokenClassificationArgs = Args & {
227
311
  /**
228
312
  * A string to be classified
@@ -615,6 +699,16 @@ export class HfInference {
615
699
  return res?.[0];
616
700
  }
617
701
 
702
/**
 * Streaming counterpart of `textGeneration`: continues text from a prompt and
 * yields the server-sent results one token at a time as they arrive.
 */
public async *textGenerationStream(
	args: TextGenerationArgs,
	options?: Options
): AsyncGenerator<TextGenerationStreamReturn> {
	// Delegate to the generic server-sent-events helper; every item it yields is forwarded unchanged.
	const tokenStream = this.streamingRequest<TextGenerationStreamReturn>(args, options);
	yield* tokenStream;
}
711
+
618
712
  /**
619
713
  * Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text. Recommended model: dbmdz/bert-large-cased-finetuned-conll03-english
620
714
  */
@@ -834,15 +928,21 @@ export class HfInference {
834
928
  return res;
835
929
  }
836
930
 
837
- public async request<T>(
838
- args: Args & { data?: Blob | ArrayBuffer },
931
+ /**
932
+ * Helper that prepares request arguments
933
+ */
934
+ private makeRequestOptions(
935
+ args: Args & {
936
+ data?: Blob | ArrayBuffer;
937
+ stream?: boolean;
938
+ },
839
939
  options?: Options & {
840
940
  binary?: boolean;
841
941
  blob?: boolean;
842
942
  /** For internal HF use, which is why it's not exposed in {@link Options} */
843
943
  includeCredentials?: boolean;
844
944
  }
845
- ): Promise<T> {
945
+ ) {
846
946
  const mergedOptions = { ...this.defaultOptions, ...options };
847
947
  const { model, ...otherArgs } = args;
848
948
 
@@ -867,7 +967,8 @@ export class HfInference {
867
967
  }
868
968
  }
869
969
 
870
- const response = await fetch(`https://api-inference.huggingface.co/models/${model}`, {
970
+ const url = `${HF_INFERENCE_API_BASE_URL}${model}`;
971
+ const info: RequestInit = {
871
972
  headers,
872
973
  method: "POST",
873
974
  body: options?.binary
@@ -877,7 +978,22 @@ export class HfInference {
877
978
  options: mergedOptions,
878
979
  }),
879
980
  credentials: options?.includeCredentials ? "include" : "same-origin",
880
- });
981
+ };
982
+
983
+ return { url, info, mergedOptions };
984
+ }
985
+
986
+ public async request<T>(
987
+ args: Args & { data?: Blob | ArrayBuffer },
988
+ options?: Options & {
989
+ binary?: boolean;
990
+ blob?: boolean;
991
+ /** For internal HF use, which is why it's not exposed in {@link Options} */
992
+ includeCredentials?: boolean;
993
+ }
994
+ ): Promise<T> {
995
+ const { url, info, mergedOptions } = this.makeRequestOptions(args, options);
996
+ const response = await fetch(url, info);
881
997
 
882
998
  if (mergedOptions.retry_on_error !== false && response.status === 503 && !mergedOptions.wait_for_model) {
883
999
  return this.request(args, {
@@ -899,4 +1015,65 @@ export class HfInference {
899
1015
  }
900
1016
  return output;
901
1017
  }
1018
+
1019
/**
 * Make request that uses server-sent events and returns response as a generator.
 *
 * The request is sent with `stream: true`; every non-empty SSE `data` payload is
 * parsed as JSON and yielded as `T`, in arrival order.
 *
 * @throws Error if the server answers with a non-2xx status, does not use the
 *   `text/event-stream` media type, or returns no body; JSON parse errors propagate.
 */
public async *streamingRequest<T>(
	args: Args & { data?: Blob | ArrayBuffer },
	options?: Options & {
		binary?: boolean;
		blob?: boolean;
		/** For internal HF use, which is why it's not exposed in {@link Options} */
		includeCredentials?: boolean;
	}
): AsyncGenerator<T> {
	const { url, info, mergedOptions } = this.makeRequestOptions({ ...args, stream: true }, options);
	const response = await fetch(url, info);

	if (mergedOptions.retry_on_error !== false && response.status === 503 && !mergedOptions.wait_for_model) {
		// Retry once, waiting for the model this time. `yield*` is required: a bare
		// `return` of the inner generator would end this stream without yielding anything.
		yield* this.streamingRequest<T>(args, {
			...mergedOptions,
			wait_for_model: true,
		});
		return;
	}
	if (!response.ok) {
		throw new Error(`Server response contains error: ${response.status}`);
	}
	// Content-Type may carry parameters (e.g. `text/event-stream; charset=utf-8`),
	// so compare only the media type, case-insensitively.
	const mediaType = (response.headers.get("content-type") ?? "").split(";")[0]?.trim().toLowerCase();
	if (mediaType !== "text/event-stream") {
		throw new Error(`Server does not support event stream content type`);
	}
	if (!response.body) {
		throw new Error(`Server response contains no body`);
	}

	const reader = response.body.getReader();
	// Messages parsed from the current chunk, waiting to be yielded.
	const events: EventSourceMessage[] = [];

	const onChunk = getLines(
		getMessages(
			() => {}, // `id` updates are not used
			() => {}, // `retry` updates are not used
			(event: EventSourceMessage) => {
				events.push(event);
			}
		)
	);

	let streamEnded = false;
	try {
		while (true) {
			const { done, value } = await reader.read();
			if (done) {
				streamEnded = true;
				return;
			}
			onChunk(value);
			// Drain every message completed by this chunk, preserving order.
			for (const event of events.splice(0)) {
				if (event.data.length > 0) {
					yield JSON.parse(event.data) as T;
				}
			}
		}
	} finally {
		if (!streamEnded) {
			// The consumer stopped early, or reading/parsing failed: cancel the body so the
			// underlying connection is not left open. A rejection here (errored stream) is ignored
			// so it does not mask the original error.
			await reader.cancel().catch(() => undefined);
		}
		reader.releaseLock();
	}
}
902
1079
  }
@@ -0,0 +1,389 @@
1
+ import { expect, it, describe } from "vitest";
2
/** Test helper: unconditionally throws an `Error` carrying `msg` (vitest has no global `fail`). */
function fail(msg: string): never {
	throw new Error(msg);
}
3
+
4
+ /**
5
+ This file is a part of fetch-event-source package (as of v2.0.1)
6
+ https://github.com/Azure/fetch-event-source/blob/v2.0.1/src/parse.spec.ts
7
+
8
+ Full package can be used after it is made compatible with nodejs:
9
+ https://github.com/Azure/fetch-event-source/issues/20
10
+
11
+ Below is the fetch-event-source package license:
12
+
13
+ MIT License
14
+
15
+ Copyright (c) Microsoft Corporation.
16
+
17
+ Permission is hereby granted, free of charge, to any person obtaining a copy
18
+ of this software and associated documentation files (the "Software"), to deal
19
+ in the Software without restriction, including without limitation the rights
20
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
21
+ copies of the Software, and to permit persons to whom the Software is
22
+ furnished to do so, subject to the following conditions:
23
+
24
+ The above copyright notice and this permission notice shall be included in all
25
+ copies or substantial portions of the Software.
26
+
27
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
28
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
29
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
30
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
31
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
32
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
33
+ SOFTWARE
34
+
35
+ */
36
+
37
+ import * as parse from './parse';
38
+
39
describe('parse', () => {
	const encoder = new TextEncoder();
	const decoder = new TextDecoder();

	describe('getLines', () => {
		/** A `getLines` scenario: chunks fed in order, and the (line, fieldLength) pairs expected back. */
		interface LinesCase {
			name: string;
			chunks: string[];
			expected: Array<[string, number]>;
		}

		const idAbc: [string, number] = ['id: abc', 2];
		const dataDef: [string, number] = ['data: def', 4];

		const cases: LinesCase[] = [
			{ name: 'single line', chunks: ['id: abc\n'], expected: [idAbc] },
			{ name: 'multiple lines', chunks: ['id: abc\n', 'data: def\n'], expected: [idAbc, dataDef] },
			{ name: 'single line split across multiple arrays', chunks: ['id: a', 'bc\n'], expected: [idAbc] },
			{
				name: 'multiple lines split across multiple arrays',
				chunks: ['id: ab', 'c\nda', 'ta: def\n'],
				expected: [idAbc, dataDef],
			},
			{ name: 'new line', chunks: ['\n'], expected: [['', -1]] },
			{ name: 'comment line', chunks: [': this is a comment\n'], expected: [[': this is a comment', 0]] },
			{ name: 'line with no field', chunks: ['this is an invalid line\n'], expected: [['this is an invalid line', -1]] },
			{ name: 'line with multiple colons', chunks: ['id: abc: def\n'], expected: [['id: abc: def', 2]] },
			{
				name: 'single byte array with multiple lines separated by \\n',
				chunks: ['id: abc\ndata: def\n'],
				expected: [idAbc, dataDef],
			},
			{
				name: 'single byte array with multiple lines separated by \\r',
				chunks: ['id: abc\rdata: def\r'],
				expected: [idAbc, dataDef],
			},
			{
				name: 'single byte array with multiple lines separated by \\r\\n',
				chunks: ['id: abc\r\ndata: def\r\n'],
				expected: [idAbc, dataDef],
			},
		];

		for (const { name, chunks, expected } of cases) {
			it(name, () => {
				// arrange:
				let seen = 0;
				const next = parse.getLines((line, fieldLength) => {
					const [expectedLine, expectedFieldLength] = expected[seen] ?? [];
					seen += 1;
					expect(decoder.decode(line)).toEqual(expectedLine);
					expect(fieldLength).toEqual(expectedFieldLength);
				});

				// act:
				for (const chunk of chunks) {
					next(encoder.encode(chunk));
				}

				// assert:
				expect(seen).toBe(expected.length);
			});
		}
	});

	describe('getMessages', () => {
		/** Feeds (text, fieldLength) pairs to `next`, the way `getLines` would. */
		const feed = (next: ReturnType<typeof parse.getMessages>, lines: Array<[string, number]>) => {
			for (const [text, fieldLength] of lines) {
				next(encoder.encode(text), fieldLength);
			}
		};

		it('happy path', () => {
			// arrange:
			let messages = 0;
			const next = parse.getMessages(
				id => {
					expect(id).toEqual('abc');
				},
				retry => {
					expect(retry).toEqual(42);
				},
				msg => {
					messages += 1;
					expect(msg).toEqual({ retry: 42, id: 'abc', event: 'def', data: 'ghi' });
				},
			);

			// act:
			feed(next, [['retry: 42', 5], ['id: abc', 2], ['event:def', 5], ['data:ghi', 4], ['', -1]]);

			// assert:
			expect(messages).toBe(1);
		});

		it('skip unknown fields', () => {
			let messages = 0;
			const next = parse.getMessages(
				id => {
					expect(id).toEqual('abc');
				},
				_retry => {
					fail('retry should not be called');
				},
				msg => {
					messages += 1;
					expect(msg).toEqual({ id: 'abc', data: '', event: '', retry: undefined });
				},
			);

			// act:
			feed(next, [['id: abc', 2], ['foo: null', 3], ['', -1]]);

			// assert:
			expect(messages).toBe(1);
		});

		it('ignore non-integer retry', () => {
			let messages = 0;
			const next = parse.getMessages(
				_id => {
					fail('id should not be called');
				},
				_retry => {
					fail('retry should not be called');
				},
				msg => {
					messages += 1;
					expect(msg).toEqual({ id: '', data: '', event: '', retry: undefined });
				},
			);

			// act:
			feed(next, [['retry: def', 5], ['', -1]]);

			// assert:
			expect(messages).toBe(1);
		});

		it('skip comment-only messages', () => {
			// arrange:
			let messages = 0;
			const next = parse.getMessages(
				id => {
					expect(id).toEqual('123');
				},
				_retry => {
					fail('retry should not be called');
				},
				msg => {
					messages += 1;
					expect(msg).toEqual({ retry: undefined, id: '123', event: 'foo ', data: '' });
				},
			);

			// act:
			feed(next, [['id:123', 2], [':', 0], [': ', 0], ['event: foo ', 5], ['', -1]]);

			// assert:
			expect(messages).toBe(1);
		});

		it('should append data split across multiple lines', () => {
			// arrange:
			let messages = 0;
			const next = parse.getMessages(
				_id => {
					fail('id should not be called');
				},
				_retry => {
					fail('retry should not be called');
				},
				msg => {
					messages += 1;
					expect(msg).toEqual({ data: 'YHOO\n+2\n\n10', id: '', event: '', retry: undefined });
				},
			);

			// act:
			feed(next, [['data:YHOO', 4], ['data: +2', 4], ['data', 4], ['data: 10', 4], ['', -1]]);

			// assert:
			expect(messages).toBe(1);
		});

		it('should reset id if sent multiple times', () => {
			// arrange:
			const expectedIds = ['foo', ''];
			let idCalls = 0;
			let messages = 0;
			const next = parse.getMessages(
				id => {
					expect(id).toEqual(expectedIds[idCalls]);
					idCalls += 1;
				},
				_retry => {
					fail('retry should not be called');
				},
				msg => {
					messages += 1;
					expect(msg).toEqual({ data: '', id: '', event: '', retry: undefined });
				},
			);

			// act:
			feed(next, [['id: foo', 2], ['id', 2], ['', -1]]);

			// assert:
			expect(idCalls).toBe(2);
			expect(messages).toBe(1);
		});
	});
});