@huggingface/transformers 3.0.0-alpha.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96) hide show
  1. package/LICENSE +202 -0
  2. package/README.md +376 -0
  3. package/dist/ort-wasm-simd-threaded.jsep.wasm +0 -0
  4. package/dist/transformers.cjs +30741 -0
  5. package/dist/transformers.cjs.map +1 -0
  6. package/dist/transformers.js +33858 -0
  7. package/dist/transformers.js.map +1 -0
  8. package/dist/transformers.min.cjs +173 -0
  9. package/dist/transformers.min.cjs.map +1 -0
  10. package/dist/transformers.min.js +231 -0
  11. package/dist/transformers.min.js.map +1 -0
  12. package/package.json +92 -0
  13. package/src/backends/onnx.js +151 -0
  14. package/src/configs.js +360 -0
  15. package/src/env.js +152 -0
  16. package/src/generation/configuration_utils.js +381 -0
  17. package/src/generation/logits_process.js +716 -0
  18. package/src/generation/logits_sampler.js +204 -0
  19. package/src/generation/parameters.js +35 -0
  20. package/src/generation/stopping_criteria.js +156 -0
  21. package/src/generation/streamers.js +212 -0
  22. package/src/models/whisper/common_whisper.js +151 -0
  23. package/src/models/whisper/generation_whisper.js +89 -0
  24. package/src/models.js +7028 -0
  25. package/src/ops/registry.js +92 -0
  26. package/src/pipelines.js +3341 -0
  27. package/src/processors.js +2614 -0
  28. package/src/tokenizers.js +4395 -0
  29. package/src/transformers.js +28 -0
  30. package/src/utils/audio.js +704 -0
  31. package/src/utils/constants.js +2 -0
  32. package/src/utils/core.js +149 -0
  33. package/src/utils/data-structures.js +445 -0
  34. package/src/utils/devices.js +11 -0
  35. package/src/utils/dtypes.js +62 -0
  36. package/src/utils/generic.js +35 -0
  37. package/src/utils/hub.js +671 -0
  38. package/src/utils/image.js +745 -0
  39. package/src/utils/maths.js +1050 -0
  40. package/src/utils/tensor.js +1378 -0
  41. package/types/backends/onnx.d.ts +26 -0
  42. package/types/backends/onnx.d.ts.map +1 -0
  43. package/types/configs.d.ts +59 -0
  44. package/types/configs.d.ts.map +1 -0
  45. package/types/env.d.ts +106 -0
  46. package/types/env.d.ts.map +1 -0
  47. package/types/generation/configuration_utils.d.ts +320 -0
  48. package/types/generation/configuration_utils.d.ts.map +1 -0
  49. package/types/generation/logits_process.d.ts +354 -0
  50. package/types/generation/logits_process.d.ts.map +1 -0
  51. package/types/generation/logits_sampler.d.ts +51 -0
  52. package/types/generation/logits_sampler.d.ts.map +1 -0
  53. package/types/generation/parameters.d.ts +47 -0
  54. package/types/generation/parameters.d.ts.map +1 -0
  55. package/types/generation/stopping_criteria.d.ts +81 -0
  56. package/types/generation/stopping_criteria.d.ts.map +1 -0
  57. package/types/generation/streamers.d.ts +81 -0
  58. package/types/generation/streamers.d.ts.map +1 -0
  59. package/types/models/whisper/common_whisper.d.ts +8 -0
  60. package/types/models/whisper/common_whisper.d.ts.map +1 -0
  61. package/types/models/whisper/generation_whisper.d.ts +76 -0
  62. package/types/models/whisper/generation_whisper.d.ts.map +1 -0
  63. package/types/models.d.ts +3845 -0
  64. package/types/models.d.ts.map +1 -0
  65. package/types/ops/registry.d.ts +11 -0
  66. package/types/ops/registry.d.ts.map +1 -0
  67. package/types/pipelines.d.ts +2403 -0
  68. package/types/pipelines.d.ts.map +1 -0
  69. package/types/processors.d.ts +917 -0
  70. package/types/processors.d.ts.map +1 -0
  71. package/types/tokenizers.d.ts +999 -0
  72. package/types/tokenizers.d.ts.map +1 -0
  73. package/types/transformers.d.ts +13 -0
  74. package/types/transformers.d.ts.map +1 -0
  75. package/types/utils/audio.d.ts +130 -0
  76. package/types/utils/audio.d.ts.map +1 -0
  77. package/types/utils/constants.d.ts +2 -0
  78. package/types/utils/constants.d.ts.map +1 -0
  79. package/types/utils/core.d.ts +91 -0
  80. package/types/utils/core.d.ts.map +1 -0
  81. package/types/utils/data-structures.d.ts +236 -0
  82. package/types/utils/data-structures.d.ts.map +1 -0
  83. package/types/utils/devices.d.ts +8 -0
  84. package/types/utils/devices.d.ts.map +1 -0
  85. package/types/utils/dtypes.d.ts +22 -0
  86. package/types/utils/dtypes.d.ts.map +1 -0
  87. package/types/utils/generic.d.ts +11 -0
  88. package/types/utils/generic.d.ts.map +1 -0
  89. package/types/utils/hub.d.ts +191 -0
  90. package/types/utils/hub.d.ts.map +1 -0
  91. package/types/utils/image.d.ts +119 -0
  92. package/types/utils/image.d.ts.map +1 -0
  93. package/types/utils/maths.d.ts +280 -0
  94. package/types/utils/maths.d.ts.map +1 -0
  95. package/types/utils/tensor.d.ts +392 -0
  96. package/types/utils/tensor.d.ts.map +1 -0
@@ -0,0 +1,716 @@
1
+
2
+ /**
3
+ * @module generation/logits_process
4
+ */
5
+
6
+ import { Callable } from "../utils/generic.js";
7
+ import { Tensor } from "../utils/tensor.js";
8
+
9
+ import { max, log_softmax } from "../utils/maths.js";
10
+
11
/**
 * Abstract base class for all logit processors that can be applied during generation.
 */
export class LogitsProcessor extends Callable {
    /**
     * Process the given logits. Concrete processors must override this method.
     *
     * @abstract
     * @param {bigint[][]} input_ids The input ids.
     * @param {Tensor} logits The logits to process.
     * @throws {Error} Always, since subclasses are required to override `_call`.
     */
    _call(input_ids, logits) {
        throw new Error("`_call` should be implemented in a subclass");
    }
}
27
+
28
+
29
/**
 * Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.
 */
export class LogitsWarper extends Callable {
    /**
     * Warp the given logits. Concrete warpers must override this method.
     *
     * @abstract
     * @param {bigint[][]} input_ids The input ids.
     * @param {Tensor} logits The logits to process.
     * @throws {Error} Always, since subclasses are required to override `_call`.
     */
    _call(input_ids, logits) {
        throw new Error("`_call` should be implemented in a subclass");
    }
}
45
+
46
+
47
/**
 * An ordered collection of logits processors. A logits processor is a function that
 * modifies the logits output of a language model. Applying the list runs every
 * registered processor, in insertion order, over a batch of logits.
 */
export class LogitsProcessorList extends Callable {
    /**
     * Constructs a new, empty `LogitsProcessorList`.
     */
    constructor() {
        super();
        this.processors = [];
    }

    /**
     * Append a single logits processor to the list.
     *
     * @param {LogitsProcessor} item The logits processor function to add.
     */
    push(item) {
        this.processors.push(item);
    }

    /**
     * Append several logits processors to the list at once.
     *
     * @param {LogitsProcessor[]} items The logits processor functions to add.
     */
    extend(items) {
        this.processors.push(...items);
    }

    /**
     * Applies every processor in the list to a batch of logits.
     *
     * @param {bigint[][]} input_ids The input IDs for the language model.
     * @param {Tensor} logits
     */
    _call(input_ids, logits) {
        // NOTE: most processors mutate the logits in-place, but each call's return
        // value is still threaded through as the input of the next processor.
        return this.processors.reduce(
            (current, processor) => processor(input_ids, current),
            logits,
        );
    }

    [Symbol.iterator]() {
        return this.processors[Symbol.iterator]();
    }
}
98
+
99
+ // DEPRECATED: https://github.com/huggingface/transformers/pull/29485
100
+ // /**
101
+ // * A logits processor that forces a specific token to be generated by the decoder.
102
+ // */
103
+ // export class ForceTokensLogitsProcessor extends LogitsProcessor {
104
+ // /**
105
+ // * Constructs a new instance of `ForceTokensLogitsProcessor`.
106
+ // *
107
+ // * @param {[number, number][]} forced_decoder_ids The ids of tokens that should be forced.
108
+ // */
109
+ // constructor(forced_decoder_ids) {
110
+ // super();
111
+ // // TODO: convert to `new Map(forced_decoder_ids)`
112
+ // this.force_token_map = Object.fromEntries(forced_decoder_ids ?? []);
113
+ // }
114
+
115
+ // /**
116
+ // * Apply the processor to the input logits.
117
+ // *
118
+ // * @param {bigint[][]} input_ids The input ids.
119
+ // * @param {Tensor} logits The logits to process.
120
+ // * @returns {Tensor} The processed logits.
121
+ // */
122
+ // _call(input_ids, logits) {
123
+ // console.log('this.force_token_map', this.force_token_map)
124
+ // console.log('call ForceTokensLogitsProcessor', input_ids, logits)
125
+ // console.log('input_ids.length', input_ids.length)
126
+ // let map = this.force_token_map[input_ids.length];
127
+ // if (map) { // There exists a mapping
128
+ // logits.data.fill(-Infinity)
129
+ // logits.data[map] = 0;
130
+ // }
131
+ // console.log('map', map)
132
+ // // throw Error("Not implemented")
133
+ // return logits;
134
+ // }
135
+ // }
136
+
137
/**
 * A LogitsProcessor that forces a BOS token at the beginning of the generated sequence.
 */
export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
    /**
     * Create a ForcedBOSTokenLogitsProcessor.
     * @param {number} bos_token_id The ID of the beginning-of-sequence token to be forced.
     */
    constructor(bos_token_id) {
        super();
        this.bos_token_id = bos_token_id;
    }

    /**
     * Force the BOS token for every sequence that has seen exactly one token so far.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Object} The (in-place modified) logits.
     */
    _call(input_ids, logits) {
        for (let batch = 0; batch < input_ids.length; ++batch) {
            // Only act on sequences that are about to generate their first token.
            if (input_ids[batch].length !== 1) continue;
            const batch_logits = logits[batch];
            batch_logits.data.fill(-Infinity);
            batch_logits.data[this.bos_token_id] = 0;
        }
        return logits;
    }
}
167
+
168
/**
 * A logits processor that enforces the specified token as the last generated token when `max_length` is reached.
 */
export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
    /**
     * Create a ForcedEOSTokenLogitsProcessor.
     * @param {number} max_length The maximum length of the sequence to be generated.
     * @param {number|number[]} eos_token_id The id(s) of the *end-of-sequence* token.
     */
    constructor(max_length, eos_token_id) {
        super();
        this.max_length = max_length;
        // Normalise to an array so `_call` can always iterate.
        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
    }

    /**
     * Force an EOS token for every sequence one step away from `max_length`.
     *
     * @param {bigint[][]} input_ids The input ids.
     * @param {Tensor} logits The logits tensor.
     */
    _call(input_ids, logits) {
        for (let batch = 0; batch < input_ids.length; ++batch) {
            // Only act when the next token is the final one allowed.
            if (input_ids[batch].length !== this.max_length - 1) continue;
            const batch_logits = logits[batch];
            batch_logits.data.fill(-Infinity);
            for (const eos_token of this.eos_token_id) {
                batch_logits.data[eos_token] = 0;
            }
        }
        return logits;
    }
}
203
+
204
/**
 * A LogitsProcessor that suppresses a list of tokens as soon as the `generate` function
 * has produced `begin_index` tokens. This ensures that the tokens defined by
 * `begin_suppress_tokens` are not sampled at the beginning of the generation.
 */
export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor {
    /**
     * Create a SuppressTokensAtBeginLogitsProcessor.
     * @param {number[]} begin_suppress_tokens The IDs of the tokens to suppress.
     * @param {number} begin_index The number of tokens to generate before suppressing tokens.
     */
    constructor(begin_suppress_tokens, begin_index) {
        super();
        this.begin_suppress_tokens = begin_suppress_tokens;
        this.begin_index = begin_index;
    }

    /**
     * Mask the suppressed tokens for every sequence currently at `begin_index`.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Object} The (in-place modified) logits.
     */
    _call(input_ids, logits) {
        for (let batch = 0; batch < input_ids.length; ++batch) {
            if (input_ids[batch].length !== this.begin_index) continue;
            const data = logits[batch].data;
            for (const token_id of this.begin_suppress_tokens) {
                data[token_id] = -Infinity;
            }
        }
        return logits;
    }
}
239
+
240
/**
 * A LogitsProcessor that handles adding timestamps to generated text.
 */
export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
    /**
     * Constructs a new WhisperTimeStampLogitsProcessor.
     * @param {import('../models/whisper/generation_whisper.js').WhisperGenerationConfig} generate_config The config object passed to the `generate()` method of a transformer model.
     * @param {number[]} init_tokens The initial tokens of the input sequence.
     */
    constructor(generate_config, init_tokens) {
        super();
        // Only the first EOS id is used when a list is provided.
        this.eos_token_id =
            Array.isArray(generate_config.eos_token_id)
                ? generate_config.eos_token_id[0]
                : generate_config.eos_token_id;

        // Timestamp tokens occupy the id range directly after <|notimestamps|>.
        this.no_timestamps_token_id = generate_config.no_timestamps_token_id;
        this.timestamp_begin = this.no_timestamps_token_id + 1;

        // Index at which "real" generation starts; if the prompt already ends with
        // <|notimestamps|>, that token is not counted.
        this.begin_index = init_tokens.length;
        if (init_tokens.at(-1) === this.no_timestamps_token_id) {
            this.begin_index -= 1;
        }
        // NOTE(review): assumed nullable — `_call` checks `!== null` before using it. Confirm against config.
        this.max_initial_timestamp_index = generate_config.max_initial_timestamp_index;
    }

    /**
     * Modify the logits to handle timestamp tokens.
     * @param {bigint[][]} input_ids The input sequence of tokens.
     * @param {Tensor} logits The logits output by the model.
     * @returns {Tensor} The modified logits.
     */
    _call(input_ids, logits) {
        for (let i = 0; i < input_ids.length; ++i) {
            const batch_logits = logits[i];
            const logitsData = /** @type {Float32Array} */(batch_logits.data);

            // suppress <|notimestamps|> which is handled by without_timestamps
            logitsData[this.no_timestamps_token_id] = -Infinity;

            // Force the very first generated token to be a timestamp.
            if (input_ids[i].length === this.begin_index - 1) {
                logitsData.fill(-Infinity);
                logitsData[this.timestamp_begin] = 0;
                continue;
            }

            // timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
            const seq = input_ids[i].slice(this.begin_index);
            const last_was_timestamp = seq.length >= 1 && seq[seq.length - 1] >= this.timestamp_begin;
            const penultimate_was_timestamp = seq.length < 2 || seq[seq.length - 2] >= this.timestamp_begin;

            if (last_was_timestamp) {
                if (penultimate_was_timestamp) { // has to be non-timestamp
                    logitsData.subarray(this.timestamp_begin).fill(-Infinity);
                } else { // cannot be normal text tokens
                    // Masks only ids below `eos_token_id`; ids in [eos_token_id, timestamp_begin)
                    // remain allowed here.
                    logitsData.subarray(0, this.eos_token_id).fill(-Infinity);
                }
            }

            // apply the `max_initial_timestamp` option
            if (input_ids[i].length === this.begin_index && this.max_initial_timestamp_index !== null) {
                const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index;
                logitsData.subarray(last_allowed + 1).fill(-Infinity);
            }

            // if sum of probability over timestamps is above any other token, sample timestamp
            const logprobs = log_softmax(logitsData);
            const timestamp_logprob = Math.log(logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b));
            const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0];

            if (timestamp_logprob > max_text_token_logprob) {
                logitsData.subarray(0, this.timestamp_begin).fill(-Infinity);
            }
        }

        return logits;
    }
}
318
+
319
/**
 * A logits processor that disallows ngrams of a certain size to be repeated.
 */
export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
    /**
     * Create a NoRepeatNGramLogitsProcessor.
     * @param {number} no_repeat_ngram_size The no-repeat-ngram size. All ngrams of this size can only occur once.
     */
    constructor(no_repeat_ngram_size) {
        super();
        this.no_repeat_ngram_size = no_repeat_ngram_size;
    }

    /**
     * Build a lookup from each n-gram prefix (as a JSON key) to the list of tokens
     * that have followed that prefix in the sequence so far.
     * @param {bigint[]} prevInputIds List of previous input ids
     * @returns {Map<string, number[]>} Map of generated n-grams
     */
    getNgrams(prevInputIds) {
        const windowCount = prevInputIds.length + 1 - this.no_repeat_ngram_size;

        /** @type {Map<string, number[]>} */
        const lookup = new Map();
        for (let start = 0; start < windowCount; ++start) {
            // Sliding window of `no_repeat_ngram_size` tokens, converted to numbers.
            const window = prevInputIds
                .slice(start, start + this.no_repeat_ngram_size)
                .map(Number);
            const key = JSON.stringify(window.slice(0, -1));
            const successors = lookup.get(key) ?? [];
            successors.push(window[window.length - 1]);
            lookup.set(key, successors);
        }
        return lookup;
    }

    /**
     * Look up which tokens are banned given the current n-gram prefix.
     * @param {Map<string, number[]>} bannedNgrams Map of banned n-grams
     * @param {bigint[]} prevInputIds List of previous input ids
     * @returns {number[]} The banned continuation tokens (empty if none).
     */
    getGeneratedNgrams(bannedNgrams, prevInputIds) {
        const start = prevInputIds.length + 1 - this.no_repeat_ngram_size;
        const prefix = prevInputIds.slice(start, prevInputIds.length);
        return bannedNgrams.get(JSON.stringify(prefix.map(Number))) ?? [];
    }

    /**
     * Calculate banned n-gram tokens for the current sequence.
     * @param {bigint[]} prevInputIds List of previous input ids
     * @returns {number[]} The banned tokens (empty if the sequence is still too short).
     */
    calcBannedNgramTokens(prevInputIds) {
        if (prevInputIds.length + 1 < this.no_repeat_ngram_size) {
            // Not enough tokens generated yet for any ngram of this size to repeat.
            return [];
        }
        const generatedNgrams = this.getNgrams(prevInputIds);
        return this.getGeneratedNgrams(generatedNgrams, prevInputIds);
    }

    /**
     * Apply the no-repeat-ngram processor to the logits.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Object} The logits with no-repeat-ngram processing.
     */
    _call(input_ids, logits) {
        for (let batch = 0; batch < input_ids.length; ++batch) {
            const batch_logits = logits[batch];
            for (const token of this.calcBannedNgramTokens(input_ids[batch])) {
                batch_logits.data[token] = -Infinity;
            }
        }
        return logits;
    }
}
409
+
410
/**
 * A logits processor that penalises repeated output tokens.
 */
export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
    /**
     * Create a RepetitionPenaltyLogitsProcessor.
     * @param {number} penalty The penalty to apply for repeated tokens.
     */
    constructor(penalty) {
        super();
        this.penalty = penalty;
    }

    /**
     * Apply the repetition penalty to the logits.
     *
     * Each occurrence of a token in `input_ids` is penalised once, so tokens that
     * appear many times in the output are penalised more.
     *
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Object} The logits with repetition penalty processing.
     */
    _call(input_ids, logits) {
        for (let batch = 0; batch < input_ids.length; ++batch) {
            const data = logits[batch].data;

            for (const token of input_ids[batch]) {
                // Negative scores are multiplied (more negative); non-negative
                // scores are divided (smaller) — both directions penalise the token.
                data[token] = data[token] < 0
                    ? data[token] * this.penalty
                    : data[token] / this.penalty;
            }
        }

        return logits;
    }
}
449
+
450
/**
 * A logits processor that enforces a minimum number of tokens.
 */
export class MinLengthLogitsProcessor extends LogitsProcessor {
    /**
     * Create a MinLengthLogitsProcessor.
     * @param {number} min_length The minimum length below which the score of `eos_token_id` is set to negative infinity.
     * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token.
     */
    constructor(min_length, eos_token_id) {
        super();
        this.min_length = min_length;
        // Normalise to an array so `_call` can always iterate.
        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
    }

    /**
     * Suppress all EOS tokens while the sequence is shorter than `min_length`.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Object} The processed logits.
     */
    _call(input_ids, logits) {
        for (let batch = 0; batch < input_ids.length; ++batch) {
            if (input_ids[batch].length >= this.min_length) continue;
            const data = logits[batch].data;
            for (const eos_token of this.eos_token_id) {
                data[eos_token] = -Infinity;
            }
        }

        return logits;
    }
}
484
+
485
/**
 * A logits processor that enforces a minimum number of new tokens.
 */
export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
    /**
     * Create a MinNewTokensLengthLogitsProcessor.
     * @param {number} prompt_length_to_skip The input tokens length.
     * @param {number} min_new_tokens The minimum *new* tokens length below which the score of `eos_token_id` is set to negative infinity.
     * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token.
     */
    constructor(prompt_length_to_skip, min_new_tokens, eos_token_id) {
        super();
        this.prompt_length_to_skip = prompt_length_to_skip;
        this.min_new_tokens = min_new_tokens;
        // Normalise to an array so `_call` can always iterate.
        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
    }

    /**
     * Suppress all EOS tokens while fewer than `min_new_tokens` tokens have been
     * generated beyond the prompt.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Object} The processed logits.
     */
    _call(input_ids, logits) {
        for (let i = 0; i < input_ids.length; ++i) {
            const new_tokens_length = input_ids[i].length - this.prompt_length_to_skip;
            if (new_tokens_length < this.min_new_tokens) {
                const batch_logits = logits[i];
                for (const eos_token of this.eos_token_id) {
                    // FIX: write into the underlying data buffer. The original used
                    // `batch_logits[eos_token] = -Infinity`, which set a plain property
                    // on the Tensor object instead of masking the logit (every sibling
                    // processor in this module writes through `.data`).
                    batch_logits.data[eos_token] = -Infinity;
                }
            }
        }
        return logits
    }
}
521
+
522
/**
 * A logits processor that prevents configured "bad word" token sequences from being generated.
 */
export class NoBadWordsLogitsProcessor extends LogitsProcessor {
    /**
     * Create a `NoBadWordsLogitsProcessor`.
     * @param {number[][]} bad_words_ids List of list of token ids that are not allowed to be generated.
     * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
     */
    constructor(bad_words_ids, eos_token_id) {
        super();
        this.bad_words_ids = bad_words_ids;
        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
    }

    /**
     * Apply logit processor.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Object} The processed logits.
     */
    _call(input_ids, logits) {
        for (let i = 0; i < input_ids.length; ++i) {
            const batch_logits = logits[i];
            const ids = input_ids[i];
            for (const bad_word_ids of this.bad_words_ids) {
                // Whether to modify the logits of the last token in the bad word id sequence
                let mark = true;

                // The last bad-word token is banned iff the generated sequence ends with
                // the bad word's prefix (all of its ids except the last one).
                // FIX: the inner loop previously reused `i`, shadowing the batch index
                // and comparing against the wrong rows of `input_ids`.
                for (let j = 1; j <= bad_word_ids.length - 1 && bad_word_ids.length < ids.length; ++j) {

                    if (bad_word_ids.at(-j - 1) !== Number(ids.at(-j))) {
                        // We have found a mismatch
                        mark = false;
                        break;
                    }
                }
                if (mark) {
                    // FIX: write into the underlying data buffer. The original used
                    // `batch_logits[...] = -Infinity`, which set a plain property on the
                    // Tensor object instead of masking the logit.
                    batch_logits.data[bad_word_ids.at(-1)] = -Infinity;
                }
            }
        }
        return logits
    }
}
565
+
566
/**
 * [`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension,
 * where the first half correspond to the conditional logits (predicted from the input prompt) and the second half
 * correspond to the unconditional logits (predicted from an empty or 'null' prompt). The processor computes a
 * weighted average across the conditional and unconditional logits, parameterised by the `guidance_scale`.
 *
 * See [the paper](https://arxiv.org/abs/2306.05284) for more information.
 */
export class ClassifierFreeGuidanceLogitsProcessor extends LogitsProcessor {

    /**
     * Create a `ClassifierFreeGuidanceLogitsProcessor`.
     * @param {number} guidance_scale The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
     * Higher guidance scale encourages the model to generate samples that are more closely linked to the input
     * prompt, usually at the expense of poorer quality.
     */
    constructor(guidance_scale) {
        super();
        if (guidance_scale <= 1) {
            throw new Error(
                `Require guidance scale >1 to use the classifier free guidance processor, got guidance scale ${guidance_scale}.`
            )
        }
        this.guidance_scale = guidance_scale;
    }

    /**
     * Apply logit processor.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Object} The processed logits.
     */
    _call(input_ids, logits) {
        if (logits.dims[0] !== 2 * input_ids.length) {
            throw new Error(
                `Logits should have twice the batch size of the input ids, the first half of batches corresponding to ` +
                `the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got ` +
                `batch size ${logits.dims[0]} for the logits and ${input_ids.length} for the input ids.`
            )
        }

        const unguided_bsz = input_ids.length;
        const cond_logits = logits.slice([0, unguided_bsz], null);
        const uncond_logits = logits.slice([unguided_bsz, logits.dims[0]], null);

        // Write the guided scores into uncond_logits to save memory. This is
        // equivalent to: scores = uncond + guidance_scale * (cond - uncond)
        const out = uncond_logits.data;
        const cond = cond_logits.data;
        for (let idx = 0; idx < out.length; ++idx) {
            out[idx] += (cond[idx] - out[idx]) * this.guidance_scale;
        }

        return uncond_logits;
    }
}
620
+
621
/**
 * [`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means
 * that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and [`TopKLogitsWarper`].
 */
export class TemperatureLogitsWarper extends LogitsWarper {
    /**
     * Create a `TemperatureLogitsWarper`.
     * @param {number} temperature Strictly positive float value used to modulate the logits distribution.
     * A value smaller than `1` decreases randomness (and vice versa), with `0` being equivalent to shifting
     * all probability mass to the most likely token.
     * @throws {Error} If `temperature` is not a strictly positive number.
     */
    constructor(temperature) {
        super();

        if (typeof temperature !== 'number' || temperature <= 0) {
            let errorMessage =
                `\`temperature\` (=${temperature}) must be a strictly positive float, otherwise your next token scores will be invalid.`;

            if (temperature === 0) {
                errorMessage += " If you're looking for greedy decoding strategies, set `do_sample=false`."
            }
            // FIX: the message was constructed but never thrown, so invalid temperatures
            // slipped through and produced NaN/Infinity scores in `_call`.
            throw new Error(errorMessage);
        }
        this.temperature = temperature;
    }

    /**
     * Divide every logit by the temperature (in-place).
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Object} The processed logits.
     */
    _call(input_ids, logits) {
        const logitsData = /** @type {Float32Array} */(logits.data);
        for (let i = 0; i < logitsData.length; ++i) {
            logitsData[i] /= this.temperature;
        }
        return logits;
    }
}
660
+
661
/**
 * [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
 * Often used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`].
 */
export class TopPLogitsWarper extends LogitsWarper {
    /**
     * Create a `TopPLogitsWarper`.
     * @param {number} top_p If set to < 1, only the smallest set of most probable tokens with
     * probabilities that add up to `top_p` or higher are kept for generation.
     * @param {Object} options Additional options for the top-p sampling.
     * @param {number} [options.filter_value=-Infinity] All filtered values will be set to this float value.
     * @param {number} [options.min_tokens_to_keep=1] Minimum number of tokens that cannot be filtered.
     * @throws {Error} If `top_p` is out of range or `min_tokens_to_keep` is not a positive integer.
     */
    constructor(top_p, { filter_value = -Infinity, min_tokens_to_keep = 1 } = {}) {
        super();
        // Validation order is significant: `top_p` is checked first.
        if (top_p < 0 || top_p > 1.0) {
            throw new Error(`\`top_p\` must be a float > 0 and < 1, but is ${top_p}`);
        }
        if (!Number.isInteger(min_tokens_to_keep) || min_tokens_to_keep < 1) {
            throw new Error(`\`min_tokens_to_keep\` must be a positive integer, but is ${min_tokens_to_keep}`);
        }

        this.top_p = top_p;
        this.filter_value = filter_value;
        this.min_tokens_to_keep = min_tokens_to_keep;
    }
}
691
+
692
/**
 * [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
 * Often used together with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`].
 */
export class TopKLogitsWarper extends LogitsWarper {
    /**
     * Create a `TopKLogitsWarper`.
     * @param {number} top_k If set to > 0, only the top `top_k` tokens are kept for generation.
     * @param {Object} options Additional options for the top-k sampling.
     * @param {number} [options.filter_value=-Infinity] All filtered values will be set to this float value.
     * @param {number} [options.min_tokens_to_keep=1] Minimum number of tokens that cannot be filtered.
     * @throws {Error} If `top_k` is not a non-negative integer.
     */
    constructor(top_k, { filter_value = -Infinity, min_tokens_to_keep = 1 } = {}) {
        super();
        if (!Number.isInteger(top_k) || top_k < 0) {
            throw new Error(`\`top_k\` must be a positive integer, but is ${top_k}`);
        }

        // Never keep fewer than `min_tokens_to_keep` candidates.
        this.top_k = Math.max(top_k, min_tokens_to_keep);
        this.filter_value = filter_value;
    }
}