@huggingface/transformers 3.0.0-alpha.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +202 -0
- package/README.md +376 -0
- package/dist/ort-wasm-simd-threaded.jsep.wasm +0 -0
- package/dist/transformers.cjs +30741 -0
- package/dist/transformers.cjs.map +1 -0
- package/dist/transformers.js +33858 -0
- package/dist/transformers.js.map +1 -0
- package/dist/transformers.min.cjs +173 -0
- package/dist/transformers.min.cjs.map +1 -0
- package/dist/transformers.min.js +231 -0
- package/dist/transformers.min.js.map +1 -0
- package/package.json +92 -0
- package/src/backends/onnx.js +151 -0
- package/src/configs.js +360 -0
- package/src/env.js +152 -0
- package/src/generation/configuration_utils.js +381 -0
- package/src/generation/logits_process.js +716 -0
- package/src/generation/logits_sampler.js +204 -0
- package/src/generation/parameters.js +35 -0
- package/src/generation/stopping_criteria.js +156 -0
- package/src/generation/streamers.js +212 -0
- package/src/models/whisper/common_whisper.js +151 -0
- package/src/models/whisper/generation_whisper.js +89 -0
- package/src/models.js +7028 -0
- package/src/ops/registry.js +92 -0
- package/src/pipelines.js +3341 -0
- package/src/processors.js +2614 -0
- package/src/tokenizers.js +4395 -0
- package/src/transformers.js +28 -0
- package/src/utils/audio.js +704 -0
- package/src/utils/constants.js +2 -0
- package/src/utils/core.js +149 -0
- package/src/utils/data-structures.js +445 -0
- package/src/utils/devices.js +11 -0
- package/src/utils/dtypes.js +62 -0
- package/src/utils/generic.js +35 -0
- package/src/utils/hub.js +671 -0
- package/src/utils/image.js +745 -0
- package/src/utils/maths.js +1050 -0
- package/src/utils/tensor.js +1378 -0
- package/types/backends/onnx.d.ts +26 -0
- package/types/backends/onnx.d.ts.map +1 -0
- package/types/configs.d.ts +59 -0
- package/types/configs.d.ts.map +1 -0
- package/types/env.d.ts +106 -0
- package/types/env.d.ts.map +1 -0
- package/types/generation/configuration_utils.d.ts +320 -0
- package/types/generation/configuration_utils.d.ts.map +1 -0
- package/types/generation/logits_process.d.ts +354 -0
- package/types/generation/logits_process.d.ts.map +1 -0
- package/types/generation/logits_sampler.d.ts +51 -0
- package/types/generation/logits_sampler.d.ts.map +1 -0
- package/types/generation/parameters.d.ts +47 -0
- package/types/generation/parameters.d.ts.map +1 -0
- package/types/generation/stopping_criteria.d.ts +81 -0
- package/types/generation/stopping_criteria.d.ts.map +1 -0
- package/types/generation/streamers.d.ts +81 -0
- package/types/generation/streamers.d.ts.map +1 -0
- package/types/models/whisper/common_whisper.d.ts +8 -0
- package/types/models/whisper/common_whisper.d.ts.map +1 -0
- package/types/models/whisper/generation_whisper.d.ts +76 -0
- package/types/models/whisper/generation_whisper.d.ts.map +1 -0
- package/types/models.d.ts +3845 -0
- package/types/models.d.ts.map +1 -0
- package/types/ops/registry.d.ts +11 -0
- package/types/ops/registry.d.ts.map +1 -0
- package/types/pipelines.d.ts +2403 -0
- package/types/pipelines.d.ts.map +1 -0
- package/types/processors.d.ts +917 -0
- package/types/processors.d.ts.map +1 -0
- package/types/tokenizers.d.ts +999 -0
- package/types/tokenizers.d.ts.map +1 -0
- package/types/transformers.d.ts +13 -0
- package/types/transformers.d.ts.map +1 -0
- package/types/utils/audio.d.ts +130 -0
- package/types/utils/audio.d.ts.map +1 -0
- package/types/utils/constants.d.ts +2 -0
- package/types/utils/constants.d.ts.map +1 -0
- package/types/utils/core.d.ts +91 -0
- package/types/utils/core.d.ts.map +1 -0
- package/types/utils/data-structures.d.ts +236 -0
- package/types/utils/data-structures.d.ts.map +1 -0
- package/types/utils/devices.d.ts +8 -0
- package/types/utils/devices.d.ts.map +1 -0
- package/types/utils/dtypes.d.ts +22 -0
- package/types/utils/dtypes.d.ts.map +1 -0
- package/types/utils/generic.d.ts +11 -0
- package/types/utils/generic.d.ts.map +1 -0
- package/types/utils/hub.d.ts +191 -0
- package/types/utils/hub.d.ts.map +1 -0
- package/types/utils/image.d.ts +119 -0
- package/types/utils/image.d.ts.map +1 -0
- package/types/utils/maths.d.ts +280 -0
- package/types/utils/maths.d.ts.map +1 -0
- package/types/utils/tensor.d.ts +392 -0
- package/types/utils/tensor.d.ts.map +1 -0
|
@@ -0,0 +1,716 @@
|
|
|
1
|
+
|
|
2
|
+
/**
|
|
3
|
+
* @module generation/logits_process
|
|
4
|
+
*/
|
|
5
|
+
|
|
6
|
+
import { Callable } from "../utils/generic.js";
|
|
7
|
+
import { Tensor } from "../utils/tensor.js";
|
|
8
|
+
|
|
9
|
+
import { max, log_softmax } from "../utils/maths.js";
|
|
10
|
+
|
|
11
|
+
/**
 * Base class for logits processors. A processor mutates (or replaces) the logits
 * produced by the model at each generation step.
 */
export class LogitsProcessor extends Callable {
    /**
     * Process the given logits. Must be overridden by subclasses.
     *
     * @abstract
     * @param {bigint[][]} input_ids The input ids.
     * @param {Tensor} logits The logits to process.
     * @throws {Error} Always, unless a subclass provides an implementation.
     */
    _call(input_ids, logits) {
        throw new Error("`_call` should be implemented in a subclass");
    }
}
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
/**
 * Base class for logits warpers, which reshape the logits distribution during
 * generation with multinomial sampling.
 */
export class LogitsWarper extends Callable {
    /**
     * Warp the given logits. Must be overridden by subclasses.
     *
     * @abstract
     * @param {bigint[][]} input_ids The input ids.
     * @param {Tensor} logits The logits to process.
     * @throws {Error} Always, unless a subclass provides an implementation.
     */
    _call(input_ids, logits) {
        throw new Error("`_call` should be implemented in a subclass");
    }
}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
/**
 * An ordered collection of logits processors. Invoking the list applies every
 * contained processor, in insertion order, to a batch of logits.
 */
export class LogitsProcessorList extends Callable {
    /**
     * Creates an empty `LogitsProcessorList`.
     */
    constructor() {
        super();
        this.processors = [];
    }

    /**
     * Appends a single processor to the list.
     *
     * @param {LogitsProcessor} item The logits processor function to add.
     */
    push(item) {
        this.processors.push(item);
    }

    /**
     * Appends several processors at once.
     *
     * @param {LogitsProcessor[]} items The logits processor functions to add.
     */
    extend(items) {
        this.processors.push(...items);
    }

    /**
     * Runs every processor over the batch of logits, chaining each processor's
     * output into the next one's input.
     *
     * @param {bigint[][]} input_ids The input IDs for the language model.
     * @param {Tensor} logits
     */
    _call(input_ids, logits) {
        // NOTE: Most processors modify logits in place and simply return the
        // tensor they were given, so this fold usually passes the same object.
        return this.processors.reduce(
            (current, processor) => processor(input_ids, current),
            logits,
        );
    }

    [Symbol.iterator]() {
        return this.processors.values();
    }
}
|
|
98
|
+
|
|
99
|
+
// DEPRECATED: https://github.com/huggingface/transformers/pull/29485
|
|
100
|
+
// /**
|
|
101
|
+
// * A logits processor that forces a specific token to be generated by the decoder.
|
|
102
|
+
// */
|
|
103
|
+
// export class ForceTokensLogitsProcessor extends LogitsProcessor {
|
|
104
|
+
// /**
|
|
105
|
+
// * Constructs a new instance of `ForceTokensLogitsProcessor`.
|
|
106
|
+
// *
|
|
107
|
+
// * @param {[number, number][]} forced_decoder_ids The ids of tokens that should be forced.
|
|
108
|
+
// */
|
|
109
|
+
// constructor(forced_decoder_ids) {
|
|
110
|
+
// super();
|
|
111
|
+
// // TODO: convert to `new Map(forced_decoder_ids)`
|
|
112
|
+
// this.force_token_map = Object.fromEntries(forced_decoder_ids ?? []);
|
|
113
|
+
// }
|
|
114
|
+
|
|
115
|
+
// /**
|
|
116
|
+
// * Apply the processor to the input logits.
|
|
117
|
+
// *
|
|
118
|
+
// * @param {bigint[][]} input_ids The input ids.
|
|
119
|
+
// * @param {Tensor} logits The logits to process.
|
|
120
|
+
// * @returns {Tensor} The processed logits.
|
|
121
|
+
// */
|
|
122
|
+
// _call(input_ids, logits) {
|
|
123
|
+
// console.log('this.force_token_map', this.force_token_map)
|
|
124
|
+
// console.log('call ForceTokensLogitsProcessor', input_ids, logits)
|
|
125
|
+
// console.log('input_ids.length', input_ids.length)
|
|
126
|
+
// let map = this.force_token_map[input_ids.length];
|
|
127
|
+
// if (map) { // There exists a mapping
|
|
128
|
+
// logits.data.fill(-Infinity)
|
|
129
|
+
// logits.data[map] = 0;
|
|
130
|
+
// }
|
|
131
|
+
// console.log('map', map)
|
|
132
|
+
// // throw Error("Not implemented")
|
|
133
|
+
// return logits;
|
|
134
|
+
// }
|
|
135
|
+
// }
|
|
136
|
+
|
|
137
|
+
/**
 * A LogitsProcessor that forces a BOS token at the beginning of the generated sequence.
 */
export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
    /**
     * Create a ForcedBOSTokenLogitsProcessor.
     * @param {number} bos_token_id The ID of the beginning-of-sequence token to be forced.
     */
    constructor(bos_token_id) {
        super();
        this.bos_token_id = bos_token_id;
    }

    /**
     * Force the BOS token: whenever a sequence contains exactly one token, every
     * logit except the BOS token's is set to -Infinity.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Object} The logits with BOS token forcing.
     */
    _call(input_ids, logits) {
        input_ids.forEach((ids, i) => {
            if (ids.length === 1) {
                const data = logits[i].data;
                data.fill(-Infinity);
                data[this.bos_token_id] = 0;
            }
        });
        return logits;
    }
}
|
|
167
|
+
|
|
168
|
+
/**
 * A logits processor that enforces the specified token as the last generated token when `max_length` is reached.
 */
export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
    /**
     * Create a ForcedEOSTokenLogitsProcessor.
     * @param {number} max_length The maximum length of the sequence to be generated.
     * @param {number|number[]} eos_token_id The id(s) of the *end-of-sequence* token.
     */
    constructor(max_length, eos_token_id) {
        super();
        this.max_length = max_length;
        // Normalize to an array so `_call` can always iterate.
        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
    }

    /**
     * Force EOS on the final step: when a sequence is one token short of
     * `max_length`, all logits except the EOS token(s) are set to -Infinity.
     *
     * @param {bigint[][]} input_ids The input ids.
     * @param {Tensor} logits The logits tensor.
     */
    _call(input_ids, logits) {
        for (const [i, ids] of input_ids.entries()) {
            if (ids.length === this.max_length - 1) {
                const data = logits[i].data;
                data.fill(-Infinity);
                for (const eos_token of this.eos_token_id) {
                    data[eos_token] = 0;
                }
            }
        }
        return logits;
    }
}
|
|
203
|
+
|
|
204
|
+
/**
 * A LogitsProcessor that suppresses a list of tokens as soon as the `generate` function starts
 * generating using `begin_index` tokens. This ensures that the tokens defined by
 * `begin_suppress_tokens` are not sampled at the beginning of the generation.
 */
export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor {
    /**
     * Create a SuppressTokensAtBeginLogitsProcessor.
     * @param {number[]} begin_suppress_tokens The IDs of the tokens to suppress.
     * @param {number} begin_index The number of tokens to generate before suppressing tokens.
     */
    constructor(begin_suppress_tokens, begin_index) {
        super();
        this.begin_suppress_tokens = begin_suppress_tokens;
        this.begin_index = begin_index;
    }

    /**
     * Suppress the configured tokens on the first generation step: when a
     * sequence's length equals `begin_index`, the suppressed tokens' logits
     * are set to -Infinity.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Object} The logits with the begin-tokens suppressed.
     */
    _call(input_ids, logits) {
        for (const [i, ids] of input_ids.entries()) {
            if (ids.length === this.begin_index) {
                const data = logits[i].data;
                for (const token_id of this.begin_suppress_tokens) {
                    data[token_id] = -Infinity;
                }
            }
        }
        return logits;
    }
}
|
|
239
|
+
|
|
240
|
+
/**
 * A LogitsProcessor that handles adding timestamps to generated text.
 *
 * Enforces that timestamp tokens appear in pairs (except directly before EOS)
 * and forces a timestamp whenever the aggregate timestamp probability exceeds
 * that of every individual text token.
 */
export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
    /**
     * Constructs a new WhisperTimeStampLogitsProcessor.
     * @param {import('../models/whisper/generation_whisper.js').WhisperGenerationConfig} generate_config The config object passed to the `generate()` method of a transformer model.
     * @param {number[]} init_tokens The initial tokens of the input sequence.
     */
    constructor(generate_config, init_tokens) {
        super();
        // If multiple EOS ids are configured, only the first is used for masking.
        this.eos_token_id =
            Array.isArray(generate_config.eos_token_id)
                ? generate_config.eos_token_id[0]
                : generate_config.eos_token_id;

        // Timestamp token ids start immediately after <|notimestamps|>.
        this.no_timestamps_token_id = generate_config.no_timestamps_token_id;
        this.timestamp_begin = this.no_timestamps_token_id + 1;

        // Index of the first generated (non-prompt) token; a trailing
        // <|notimestamps|> in the prompt is not counted.
        this.begin_index = init_tokens.length;
        if (init_tokens.at(-1) === this.no_timestamps_token_id) {
            this.begin_index -= 1;
        }
        this.max_initial_timestamp_index = generate_config.max_initial_timestamp_index;
    }

    /**
     * Modify the logits to handle timestamp tokens.
     * @param {bigint[][]} input_ids The input sequence of tokens.
     * @param {Tensor} logits The logits output by the model.
     * @returns {Tensor} The modified logits (modified in place).
     */
    _call(input_ids, logits) {
        for (let i = 0; i < input_ids.length; ++i) {
            const batch_logits = logits[i];
            const logitsData = /** @type {Float32Array} */(batch_logits.data);

            // suppress <|notimestamps|> which is handled by without_timestamps
            logitsData[this.no_timestamps_token_id] = -Infinity;

            // Force the very first generated token to be a timestamp.
            if (input_ids[i].length === this.begin_index - 1) {
                logitsData.fill(-Infinity);
                logitsData[this.timestamp_begin] = 0;
                continue;
            }

            // timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
            // NOTE: `seq` holds bigints; comparisons against the numeric
            // `timestamp_begin` rely on JS's mixed bigint/number comparison.
            const seq = input_ids[i].slice(this.begin_index);
            const last_was_timestamp = seq.length >= 1 && seq[seq.length - 1] >= this.timestamp_begin;
            const penultimate_was_timestamp = seq.length < 2 || seq[seq.length - 2] >= this.timestamp_begin;

            if (last_was_timestamp) {
                if (penultimate_was_timestamp) { // has to be non-timestamp
                    logitsData.subarray(this.timestamp_begin).fill(-Infinity);
                } else { // cannot be normal text tokens
                    logitsData.subarray(0, this.eos_token_id).fill(-Infinity);
                }
            }

            // apply the `max_initial_timestamp` option: cap how far into the clip
            // the first timestamp may point.
            if (input_ids[i].length === this.begin_index && this.max_initial_timestamp_index !== null) {
                const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index;
                logitsData.subarray(last_allowed + 1).fill(-Infinity);
            }

            // if sum of probability over timestamps is above any other token, sample timestamp
            const logprobs = log_softmax(logitsData);
            const timestamp_logprob = Math.log(logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b));
            const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0];

            if (timestamp_logprob > max_text_token_logprob) {
                logitsData.subarray(0, this.timestamp_begin).fill(-Infinity);
            }
        }

        return logits;
    }
}
|
|
318
|
+
|
|
319
|
+
/**
 * A logits processor that disallows ngrams of a certain size to be repeated.
 */
export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
    /**
     * Create a NoRepeatNGramLogitsProcessor.
     * @param {number} no_repeat_ngram_size The no-repeat-ngram size. All ngrams of this size can only occur once.
     */
    constructor(no_repeat_ngram_size) {
        super();
        this.no_repeat_ngram_size = no_repeat_ngram_size;
    }

    /**
     * Collect every n-gram of size `no_repeat_ngram_size` occurring in the
     * sequence, keyed by the JSON-encoded (n-1)-token prefix.
     * @param {bigint[]} prevInputIds List of previous input ids
     * @returns {Map<string, number[]>} Map from prefix key to the tokens that followed it
     */
    getNgrams(prevInputIds) {
        const curLen = prevInputIds.length;

        /** @type {Map<string, number[]>} */
        const generatedNgram = new Map();
        for (let start = 0; start + this.no_repeat_ngram_size <= curLen; ++start) {
            const ngram = prevInputIds
                .slice(start, start + this.no_repeat_ngram_size)
                .map(Number);
            const key = JSON.stringify(ngram.slice(0, -1));
            const followers = generatedNgram.get(key);
            if (followers === undefined) {
                generatedNgram.set(key, [ngram.at(-1)]);
            } else {
                followers.push(ngram.at(-1));
            }
        }
        return generatedNgram;
    }

    /**
     * Look up which tokens are banned given the current trailing (n-1)-gram.
     * @param {Map<string, number[]>} bannedNgrams Map of banned n-grams
     * @param {bigint[]} prevInputIds List of previous input ids
     * @returns {number[]} Tokens that would complete an already-seen n-gram
     */
    getGeneratedNgrams(bannedNgrams, prevInputIds) {
        const suffix = prevInputIds
            .slice(prevInputIds.length + 1 - this.no_repeat_ngram_size)
            .map(Number);
        return bannedNgrams.get(JSON.stringify(suffix)) ?? [];
    }

    /**
     * Calculate banned n-gram tokens for the given sequence.
     * @param {bigint[]} prevInputIds List of previous input ids
     * @returns {number[]} The banned tokens (empty until enough tokens exist)
     */
    calcBannedNgramTokens(prevInputIds) {
        // No n-gram can repeat before at least `no_repeat_ngram_size` tokens exist.
        if (prevInputIds.length + 1 < this.no_repeat_ngram_size) {
            return [];
        }
        return this.getGeneratedNgrams(this.getNgrams(prevInputIds), prevInputIds);
    }

    /**
     * Apply the no-repeat-ngram processor to the logits.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Object} The logits with no-repeat-ngram processing.
     */
    _call(input_ids, logits) {
        for (let i = 0; i < input_ids.length; ++i) {
            const data = logits[i].data;
            for (const token of this.calcBannedNgramTokens(input_ids[i])) {
                data[token] = -Infinity;
            }
        }
        return logits;
    }
}
|
|
409
|
+
|
|
410
|
+
/**
 * A logits processor that penalises repeated output tokens.
 */
export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
    /**
     * Create a RepetitionPenaltyLogitsProcessor.
     * @param {number} penalty The penalty to apply for repeated tokens.
     */
    constructor(penalty) {
        super();
        this.penalty = penalty;
    }

    /**
     * Apply the repetition penalty to the logits.
     *
     * Each occurrence of a token in `input_ids` is penalised, so tokens that
     * appear many times in the output are penalised more heavily.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Object} The logits with repetition penalty processing.
     */
    _call(input_ids, logits) {
        for (const [i, ids] of input_ids.entries()) {
            const data = logits[i].data;
            for (const token of ids) {
                // Negative logits are scaled up, positive ones scaled down,
                // so the penalty always lowers the token's score.
                data[token] = data[token] < 0
                    ? data[token] * this.penalty
                    : data[token] / this.penalty;
            }
        }
        return logits;
    }
}
|
|
449
|
+
|
|
450
|
+
/**
 * A logits processor that enforces a minimum number of tokens.
 */
export class MinLengthLogitsProcessor extends LogitsProcessor {
    /**
     * Create a MinLengthLogitsProcessor.
     * @param {number} min_length The minimum length below which the score of `eos_token_id` is set to negative infinity.
     * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token.
     */
    constructor(min_length, eos_token_id) {
        super();
        this.min_length = min_length;
        // Normalize to an array so `_call` can always iterate.
        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
    }

    /**
     * Suppress the EOS token(s) while a sequence is shorter than `min_length`.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Object} The processed logits.
     */
    _call(input_ids, logits) {
        for (const [i, ids] of input_ids.entries()) {
            if (ids.length < this.min_length) {
                const data = logits[i].data;
                for (const eos_token of this.eos_token_id) {
                    data[eos_token] = -Infinity;
                }
            }
        }
        return logits;
    }
}
|
|
484
|
+
|
|
485
|
+
/**
 * A logits processor that enforces a minimum number of new (non-prompt) tokens,
 * by suppressing the end-of-sequence token(s) until enough tokens have been generated.
 */
export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
    /**
     * Create a MinNewTokensLengthLogitsProcessor.
     * @param {number} prompt_length_to_skip The input tokens length.
     * @param {number} min_new_tokens The minimum *new* tokens length below which the score of `eos_token_id` is set to negative infinity.
     * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token.
     */
    constructor(prompt_length_to_skip, min_new_tokens, eos_token_id) {
        super();
        this.prompt_length_to_skip = prompt_length_to_skip;
        this.min_new_tokens = min_new_tokens;
        // Normalize to an array so `_call` can always iterate.
        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
    }

    /**
     * Apply logit processor.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Object} The processed logits.
     */
    _call(input_ids, logits) {
        for (let i = 0; i < input_ids.length; ++i) {
            const new_tokens_length = input_ids[i].length - this.prompt_length_to_skip;
            if (new_tokens_length < this.min_new_tokens) {
                const batch_logits = logits[i];
                for (const eos_token of this.eos_token_id) {
                    // FIX: previously wrote to `batch_logits[eos_token]`, which sets a
                    // property on the tensor object instead of updating its data, so the
                    // EOS token was never actually suppressed. Write through `.data`,
                    // matching MinLengthLogitsProcessor.
                    batch_logits.data[eos_token] = -Infinity;
                }
            }
        }
        return logits;
    }
}
|
|
521
|
+
|
|
522
|
+
/**
 * A logits processor that prevents the generation of configured "bad word" token
 * sequences: when the generated sequence ends with all-but-the-last ids of a bad
 * word sequence, the final id of that sequence is suppressed.
 */
export class NoBadWordsLogitsProcessor extends LogitsProcessor {
    /**
     * Create a `NoBadWordsLogitsProcessor`.
     * @param {number[][]} bad_words_ids List of list of token ids that are not allowed to be generated.
     * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
     */
    constructor(bad_words_ids, eos_token_id) {
        super();
        this.bad_words_ids = bad_words_ids;
        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
    }

    /**
     * Apply logit processor.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Object} The processed logits.
     */
    _call(input_ids, logits) {
        for (let i = 0; i < input_ids.length; ++i) {
            const batch_logits = logits[i];
            const ids = input_ids[i];
            for (const bad_word_ids of this.bad_words_ids) {
                // Whether to modify the logits of the last token in the bad word id sequence
                let mark = true;

                // For each bad word in the list, if the current sequence of input ids ends with this sequence (excluding the last),
                // then we set the logits of the last bad word id to -Infinity.
                // FIX: the inner loop previously reused `i` as its counter, shadowing the
                // batch index so the comparison read from the wrong batch row.
                for (let j = 1; j <= bad_word_ids.length - 1 && bad_word_ids.length < ids.length; ++j) {
                    // NOTE: `ids` contains bigints, so convert before the strict comparison.
                    if (bad_word_ids.at(-j - 1) !== Number(ids.at(-j))) {
                        // We have found a mismatch
                        mark = false;
                        break;
                    }
                }
                if (mark) {
                    // FIX: write through `.data` — assigning to `batch_logits[...]` set a
                    // plain property on the tensor object rather than updating the logit.
                    batch_logits.data[bad_word_ids.at(-1)] = -Infinity;
                }
            }
        }
        return logits;
    }
}
|
|
565
|
+
|
|
566
|
+
/**
 * [`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension,
 * where the first half correspond to the conditional logits (predicted from the input prompt) and the second half
 * correspond to the unconditional logits (predicted from an empty or 'null' prompt). The processor computes a
 * weighted average across the conditional and unconditional logits, parameterised by the `guidance_scale`.
 *
 * See [the paper](https://arxiv.org/abs/2306.05284) for more information.
 */
export class ClassifierFreeGuidanceLogitsProcessor extends LogitsProcessor {

    /**
     * Create a `ClassifierFreeGuidanceLogitsProcessor`.
     * @param {number} guidance_scale The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
     * Higher guidance scale encourages the model to generate samples that are more closely linked to the input
     * prompt, usually at the expense of poorer quality.
     */
    constructor(guidance_scale) {
        super();
        if (guidance_scale <= 1) {
            throw new Error(
                `Require guidance scale >1 to use the classifier free guidance processor, got guidance scale ${guidance_scale}.`
            );
        }
        this.guidance_scale = guidance_scale;
    }

    /**
     * Apply logit processor.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Object} The processed logits.
     */
    _call(input_ids, logits) {
        if (logits.dims[0] !== 2 * input_ids.length) {
            throw new Error(
                `Logits should have twice the batch size of the input ids, the first half of batches corresponding to ` +
                `the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got ` +
                `batch size ${logits.dims[0]} for the logits and ${input_ids.length} for the input ids.`
            );
        }

        const unguided_bsz = input_ids.length;
        const cond_logits = logits.slice([0, unguided_bsz], null);
        const uncond_logits = logits.slice([unguided_bsz, logits.dims[0]], null);

        // Blend the two halves into `uncond_logits` (saves memory); equivalent to:
        // scores = uncond_logits + (cond_logits - uncond_logits) * guidance_scale
        const out = uncond_logits.data;
        const cond = cond_logits.data;
        for (let j = 0; j < out.length; ++j) {
            out[j] += (cond[j] - out[j]) * this.guidance_scale;
        }

        return uncond_logits;
    }
}
|
|
620
|
+
|
|
621
|
+
/**
 * [`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means
 * that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and [`TopKLogitsWarper`].
 */
export class TemperatureLogitsWarper extends LogitsWarper {
    /**
     * Create a `TemperatureLogitsWarper`.
     * @param {number} temperature Strictly positive float value used to modulate the logits distribution.
     * A value smaller than `1` decreases randomness (and vice versa), with `0` being equivalent to shifting
     * all probability mass to the most likely token.
     * @throws {Error} If `temperature` is not a strictly positive number.
     */
    constructor(temperature) {
        super();

        if (typeof temperature !== 'number' || temperature <= 0) {
            let errorMessage =
                `\`temperature\` (=${temperature}) must be a strictly positive float, otherwise your next token scores will be invalid.`;

            if (temperature === 0) {
                errorMessage += " If you're looking for greedy decoding strategies, set `do_sample=false`."
            }
            // FIX: the error message was previously constructed but never thrown,
            // so invalid temperatures were silently accepted.
            throw new Error(errorMessage);
        }
        this.temperature = temperature;
    }

    /**
     * Apply logit warper: divides every logit by the temperature (in place).
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Object} The processed logits.
     */
    _call(input_ids, logits) {
        const logitsData = /** @type {Float32Array} */(logits.data);
        for (let i = 0; i < logitsData.length; ++i) {
            logitsData[i] /= this.temperature;
        }
        return logits;
    }
}
|
|
660
|
+
|
|
661
|
+
/**
 * [`LogitsWarper`] that performs top-p (nucleus) filtering: only the smallest set of most
 * probable tokens whose probabilities sum to at least `top_p` is kept for generation.
 * Often used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`].
 */
export class TopPLogitsWarper extends LogitsWarper {
    /**
     * Create a `TopPLogitsWarper`.
     * @param {number} top_p If set to < 1, only the smallest set of most probable tokens with
     * probabilities that add up to `top_p` or higher are kept for generation.
     * @param {Object} options Additional options for the top-p sampling.
     * @param {number} [options.filter_value=-Infinity] All filtered values will be set to this float value.
     * @param {number} [options.min_tokens_to_keep=1] Minimum number of tokens that cannot be filtered.
     */
    constructor(top_p, options = {}) {
        super();
        const {
            filter_value = -Infinity,
            min_tokens_to_keep = 1,
        } = options;

        if (top_p < 0 || top_p > 1.0) {
            throw new Error(`\`top_p\` must be a float > 0 and < 1, but is ${top_p}`);
        }
        if (!Number.isInteger(min_tokens_to_keep) || min_tokens_to_keep < 1) {
            throw new Error(`\`min_tokens_to_keep\` must be a positive integer, but is ${min_tokens_to_keep}`);
        }

        this.top_p = top_p;
        this.filter_value = filter_value;
        this.min_tokens_to_keep = min_tokens_to_keep;
    }
}
|
|
691
|
+
|
|
692
|
+
/**
 * [`LogitsWarper`] that performs top-k filtering, i.e. restricting sampling to the k
 * highest-probability tokens. Often used together with [`TemperatureLogitsWarper`]
 * and [`TopPLogitsWarper`].
 */
export class TopKLogitsWarper extends LogitsWarper {
    /**
     * Create a `TopKLogitsWarper`.
     * @param {number} top_k If set to > 0, only the top `top_k` tokens are kept for generation.
     * @param {Object} options Additional options for the top-k sampling.
     * @param {number} [options.filter_value=-Infinity] All filtered values will be set to this float value.
     * @param {number} [options.min_tokens_to_keep=1] Minimum number of tokens that cannot be filtered.
     */
    constructor(top_k, options = {}) {
        super();
        const {
            filter_value = -Infinity,
            min_tokens_to_keep = 1,
        } = options;

        if (!Number.isInteger(top_k) || top_k < 0) {
            throw new Error(`\`top_k\` must be a positive integer, but is ${top_k}`);
        }

        // Never keep fewer than `min_tokens_to_keep` tokens.
        this.top_k = Math.max(top_k, min_tokens_to_keep);
        this.filter_value = filter_value;
    }
}
|