@huggingface/transformers 3.0.0-alpha.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +202 -0
- package/README.md +376 -0
- package/dist/ort-wasm-simd-threaded.jsep.wasm +0 -0
- package/dist/transformers.cjs +30741 -0
- package/dist/transformers.cjs.map +1 -0
- package/dist/transformers.js +33858 -0
- package/dist/transformers.js.map +1 -0
- package/dist/transformers.min.cjs +173 -0
- package/dist/transformers.min.cjs.map +1 -0
- package/dist/transformers.min.js +231 -0
- package/dist/transformers.min.js.map +1 -0
- package/package.json +92 -0
- package/src/backends/onnx.js +151 -0
- package/src/configs.js +360 -0
- package/src/env.js +152 -0
- package/src/generation/configuration_utils.js +381 -0
- package/src/generation/logits_process.js +716 -0
- package/src/generation/logits_sampler.js +204 -0
- package/src/generation/parameters.js +35 -0
- package/src/generation/stopping_criteria.js +156 -0
- package/src/generation/streamers.js +212 -0
- package/src/models/whisper/common_whisper.js +151 -0
- package/src/models/whisper/generation_whisper.js +89 -0
- package/src/models.js +7028 -0
- package/src/ops/registry.js +92 -0
- package/src/pipelines.js +3341 -0
- package/src/processors.js +2614 -0
- package/src/tokenizers.js +4395 -0
- package/src/transformers.js +28 -0
- package/src/utils/audio.js +704 -0
- package/src/utils/constants.js +2 -0
- package/src/utils/core.js +149 -0
- package/src/utils/data-structures.js +445 -0
- package/src/utils/devices.js +11 -0
- package/src/utils/dtypes.js +62 -0
- package/src/utils/generic.js +35 -0
- package/src/utils/hub.js +671 -0
- package/src/utils/image.js +745 -0
- package/src/utils/maths.js +1050 -0
- package/src/utils/tensor.js +1378 -0
- package/types/backends/onnx.d.ts +26 -0
- package/types/backends/onnx.d.ts.map +1 -0
- package/types/configs.d.ts +59 -0
- package/types/configs.d.ts.map +1 -0
- package/types/env.d.ts +106 -0
- package/types/env.d.ts.map +1 -0
- package/types/generation/configuration_utils.d.ts +320 -0
- package/types/generation/configuration_utils.d.ts.map +1 -0
- package/types/generation/logits_process.d.ts +354 -0
- package/types/generation/logits_process.d.ts.map +1 -0
- package/types/generation/logits_sampler.d.ts +51 -0
- package/types/generation/logits_sampler.d.ts.map +1 -0
- package/types/generation/parameters.d.ts +47 -0
- package/types/generation/parameters.d.ts.map +1 -0
- package/types/generation/stopping_criteria.d.ts +81 -0
- package/types/generation/stopping_criteria.d.ts.map +1 -0
- package/types/generation/streamers.d.ts +81 -0
- package/types/generation/streamers.d.ts.map +1 -0
- package/types/models/whisper/common_whisper.d.ts +8 -0
- package/types/models/whisper/common_whisper.d.ts.map +1 -0
- package/types/models/whisper/generation_whisper.d.ts +76 -0
- package/types/models/whisper/generation_whisper.d.ts.map +1 -0
- package/types/models.d.ts +3845 -0
- package/types/models.d.ts.map +1 -0
- package/types/ops/registry.d.ts +11 -0
- package/types/ops/registry.d.ts.map +1 -0
- package/types/pipelines.d.ts +2403 -0
- package/types/pipelines.d.ts.map +1 -0
- package/types/processors.d.ts +917 -0
- package/types/processors.d.ts.map +1 -0
- package/types/tokenizers.d.ts +999 -0
- package/types/tokenizers.d.ts.map +1 -0
- package/types/transformers.d.ts +13 -0
- package/types/transformers.d.ts.map +1 -0
- package/types/utils/audio.d.ts +130 -0
- package/types/utils/audio.d.ts.map +1 -0
- package/types/utils/constants.d.ts +2 -0
- package/types/utils/constants.d.ts.map +1 -0
- package/types/utils/core.d.ts +91 -0
- package/types/utils/core.d.ts.map +1 -0
- package/types/utils/data-structures.d.ts +236 -0
- package/types/utils/data-structures.d.ts.map +1 -0
- package/types/utils/devices.d.ts +8 -0
- package/types/utils/devices.d.ts.map +1 -0
- package/types/utils/dtypes.d.ts +22 -0
- package/types/utils/dtypes.d.ts.map +1 -0
- package/types/utils/generic.d.ts +11 -0
- package/types/utils/generic.d.ts.map +1 -0
- package/types/utils/hub.d.ts +191 -0
- package/types/utils/hub.d.ts.map +1 -0
- package/types/utils/image.d.ts +119 -0
- package/types/utils/image.d.ts.map +1 -0
- package/types/utils/maths.d.ts +280 -0
- package/types/utils/maths.d.ts.map +1 -0
- package/types/utils/tensor.d.ts +392 -0
- package/types/utils/tensor.d.ts.map +1 -0
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
|
|
2
|
+
/**
|
|
3
|
+
* @module generation/logits_sampler
|
|
4
|
+
*/
|
|
5
|
+
|
|
6
|
+
import { Callable } from "../utils/generic.js";
|
|
7
|
+
import { Tensor, topk } from "../utils/tensor.js";
|
|
8
|
+
|
|
9
|
+
import {
|
|
10
|
+
max,
|
|
11
|
+
softmax,
|
|
12
|
+
} from '../utils/maths.js';
|
|
13
|
+
import { GenerationConfig } from '../generation/configuration_utils.js';
|
|
14
|
+
|
|
15
|
+
/**
 * Sampler is a base class for all sampling methods used for text generation.
 */
export class LogitsSampler extends Callable {
    /**
     * Creates a new Sampler object with the specified generation config.
     * @param {GenerationConfig} generation_config The generation config.
     */
    constructor(generation_config) {
        super();
        this.generation_config = generation_config;
    }

    /**
     * Executes the sampler, using the specified logits.
     * @param {Tensor} logits Logits of dims [batch, sequence_length, vocab_size].
     * @returns {Promise<[bigint, number][]>} Pairs of [token id, score].
     */
    async _call(logits) {
        // Delegate to the subclass-specific sampling strategy.
        return this.sample(logits);
    }

    /**
     * Abstract method for sampling the logits.
     * @param {Tensor} logits
     * @throws {Error} If not implemented in subclass.
     * @returns {Promise<[bigint, number][]>}
     */
    async sample(logits) {
        throw Error("sample should be implemented in subclasses.")
    }

    /**
     * Extracts the logits for a single sequence position as a flat array.
     * @param {Tensor} logits
     * @param {number} index Position to read; -1 selects the last position.
     * @returns {Float32Array}
     */
    getLogits(logits, index) {
        const vocabSize = logits.dims.at(-1);
        const flat = /** @type {Float32Array} */(logits.data);

        if (index === -1) {
            // Last position: the final `vocabSize` entries of the flattened data.
            return flat.slice(-vocabSize);
        }
        const begin = index * vocabSize;
        return flat.slice(begin, begin + vocabSize);
    }

    /**
     * Selects an item randomly based on the specified probabilities.
     * @param {import("../transformers.js").DataArray} probabilities An array of probabilities to use for selection.
     * @returns {number} The index of the selected item.
     */
    randomSelect(probabilities) {
        // The probabilities need not be normalized: draw within their total mass.
        let total = 0;
        for (const p of probabilities) {
            total += p;
        }

        let remaining = Math.random() * total;
        for (let idx = 0; idx < probabilities.length; ++idx) {
            remaining -= probabilities[idx];
            if (remaining <= 0) {
                return idx;
            }
        }
        return 0; // return first (most probable) as a fallback
    }

    /**
     * Returns a Sampler object based on the specified options.
     * @param {GenerationConfig} generation_config An object containing options for the sampler.
     * @returns {LogitsSampler} A Sampler object.
     */
    static getSampler(generation_config) {
        // - *greedy decoding*: `num_beams=1` and `do_sample=False`
        // - *contrastive search*: `penalty_alpha>0` and `top_k>1`
        // - *multinomial sampling*: `num_beams=1` and `do_sample=True`
        // - *beam-search decoding*: `num_beams>1` and `do_sample=False`
        // - *beam-search multinomial sampling*: `num_beams>1` and `do_sample=True`
        // - *diverse beam-search decoding*: `num_beams>1` and `num_beam_groups>1`
        // - *constrained beam-search decoding*: `constraints!=None` or `force_words_ids!=None`

        // NOTE: beam search is implemented directly into the generation function
        if (generation_config.do_sample) {
            return new MultinomialSampler(generation_config);
        }
        if (generation_config.num_beams > 1) {
            return new BeamSearchSampler(generation_config);
        }
        if (generation_config.num_return_sequences > 1) {
            throw Error(`num_return_sequences has to be 1 when doing greedy search, but is ${generation_config.num_return_sequences}.`)
        }
        return new GreedySampler(generation_config);
    }
}
|
|
120
|
+
|
|
121
|
+
/**
 * Class representing a Greedy Sampler.
 */
class GreedySampler extends LogitsSampler {
    /**
     * Picks the single highest-probability token.
     * @param {Tensor} logits
     * @returns {Promise<[bigint, number][]>} An array with a single tuple, containing the index of the maximum value and a meaningless score (since this is a greedy search).
     */
    async sample(logits) {
        // No need for log_softmax here: argmax is unchanged by monotonic transforms.
        const [, bestIndex] = max(logits.data);

        // Score is meaningless for greedy search (p = 1 => log(p) = 0).
        return [
            [BigInt(bestIndex), 0]
        ];
    }
}
|
|
141
|
+
|
|
142
|
+
/**
 * Class representing a MultinomialSampler.
 */
class MultinomialSampler extends LogitsSampler {

    /**
     * Draws `num_beams` samples from the (optionally top-k restricted) distribution.
     * @param {Tensor} logits
     * @returns {Promise<[bigint, number][]>} Pairs of [token id, log-probability].
     */
    async sample(logits) {
        const vocabSize = logits.dims.at(-1);
        const topK = this.generation_config.top_k;
        // Restrict to the top-k tokens when requested; otherwise use the full vocabulary.
        const k = topK > 0 ? Math.min(topK, vocabSize) : vocabSize;

        // Get top k tokens
        const [topValues, topIndices] = await topk(logits, k);

        // Convert the selected logits into a probability distribution.
        const probs = softmax(/** @type {Float32Array} */(topValues.data));

        const samples = [];
        for (let beam = 0; beam < this.generation_config.num_beams; ++beam) {
            const choice = this.randomSelect(probs);
            samples.push([
                topIndices.data[choice], // token id
                Math.log(probs[choice]), // score
            ]);
        }
        return samples;
    }
}
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
/**
 * Class representing a BeamSearchSampler.
 */
class BeamSearchSampler extends LogitsSampler {

    /**
     * Deterministically selects the `num_beams` highest-probability tokens.
     * @param {Tensor} logits
     * @returns {Promise<[bigint, number][]>} Pairs of [token id, log-probability], best first.
     */
    async sample(logits) {
        const vocabSize = logits.dims.at(-1);
        const topK = this.generation_config.top_k;
        // Restrict to the top-k tokens when requested; otherwise use the full vocabulary.
        const k = topK > 0 ? Math.min(topK, vocabSize) : vocabSize;

        // Get top k tokens
        const [topValues, topIndices] = await topk(logits, k);

        // Convert the selected logits into a probability distribution.
        const probs = softmax(/** @type {Float32Array} */(topValues.data));

        const beams = [];
        for (let b = 0; b < this.generation_config.num_beams; ++b) {
            beams.push([
                topIndices.data[b], // token id
                Math.log(probs[b]), // score
            ]);
        }
        return beams;
    }
}
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
|
|
2
|
+
/**
|
|
3
|
+
* @module generation/parameters
|
|
4
|
+
*/
|
|
5
|
+
|
|
6
|
+
/**
|
|
7
|
+
* @typedef {Object} GenerationFunctionParameters
|
|
8
|
+
* @property {import('../utils/tensor.js').Tensor} [inputs=null] (`Tensor` of varying shape depending on the modality, *optional*):
|
|
9
|
+
* The sequence used as a prompt for the generation or as model inputs to the encoder. If `null` the
|
|
10
|
+
* method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
|
|
11
|
+
* should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
|
|
12
|
+
* `input_ids`, `input_values`, `input_features`, or `pixel_values`.
|
|
13
|
+
* @property {import('./configuration_utils.js').GenerationConfig} [generation_config=null] (`GenerationConfig`, *optional*):
|
|
14
|
+
* The generation configuration to be used as base parametrization for the generation call.
|
|
15
|
+
* `**kwargs` passed to generate matching the attributes of `generation_config` will override them.
|
|
16
|
+
* If `generation_config` is not provided, the default will be used, which has the following loading
|
|
17
|
+
* priority:
|
|
18
|
+
* - (1) from the `generation_config.json` model file, if it exists;
|
|
19
|
+
* - (2) from the model configuration. Please note that unspecified parameters will inherit [`GenerationConfig`]'s
|
|
20
|
+
* default values, whose documentation should be checked to parameterize generation.
|
|
21
|
+
* @property {import('./logits_process.js').LogitsProcessorList} [logits_processor=null] (`LogitsProcessorList`, *optional*):
|
|
22
|
+
* Custom logits processors that complement the default logits processors built from arguments and
|
|
23
|
+
* generation config. If a logit processor is passed that is already created with the arguments or a
|
|
24
|
+
* generation config an error is thrown. This feature is intended for advanced users.
|
|
25
|
+
* @property {import('./stopping_criteria.js').StoppingCriteriaList} [stopping_criteria=null] (`StoppingCriteriaList`, *optional*):
|
|
26
|
+
* Custom stopping criteria that complements the default stopping criteria built from arguments and a
|
|
27
|
+
* generation config. If a stopping criteria is passed that is already created with the arguments or a
|
|
28
|
+
* generation config an error is thrown. This feature is intended for advanced users.
|
|
29
|
+
* @property {import('./streamers.js').BaseStreamer} [streamer=null] (`BaseStreamer`, *optional*):
|
|
30
|
+
* Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
|
31
|
+
* through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
|
32
|
+
* @property {number[]} [decoder_input_ids=null] (`number[]`, *optional*):
|
|
33
|
+
* If the model is an encoder-decoder model, this argument is used to pass the `decoder_input_ids`.
|
|
34
|
+
 * @property {any} [kwargs] (`Dict[str, any]`, *optional*):
|
|
35
|
+
*/
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
|
|
2
|
+
/**
|
|
3
|
+
* @module generation/stopping_criteria
|
|
4
|
+
*/
|
|
5
|
+
|
|
6
|
+
import { Callable } from "../utils/generic.js";
|
|
7
|
+
|
|
8
|
+
// NOTE:
|
|
9
|
+
// Stopping Criteria returns a list of `batch_size` booleans, indicating whether each sequence in the batch should be stopped.
|
|
10
|
+
|
|
11
|
+
/**
 * Abstract base class for all stopping criteria that can be applied during generation.
 * Subclasses must override `_call`; the base implementation always throws.
 */
export class StoppingCriteria extends Callable {
    /**
     * Determines, for each sequence in the batch, whether generation should stop.
     *
     * @param {number[][]} input_ids (`number[][]` of shape `(batch_size, sequence_length)`):
     * Indices of input sequence tokens in the vocabulary.
     * @param {number[][]} scores scores (`number[][]` of shape `(batch_size, config.vocab_size)`):
     * Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
     * or scores for each vocabulary token after SoftMax.
     * @returns {boolean[]} A list of booleans indicating whether each sequence should be stopped.
     * @throws {Error} Always, unless overridden in a subclass.
     */
    _call(input_ids, scores) {
        throw Error("StoppingCriteria needs to be subclassed");
    }
}
|
|
28
|
+
/**
 * A collection of stopping criteria, applied together: a sequence stops as soon
 * as ANY contained criterion reports it as done.
 */
export class StoppingCriteriaList extends Callable {
    /**
     * Constructs a new instance of `StoppingCriteriaList`.
     */
    constructor() {
        super();
        this.criteria = [];
    }

    /**
     * Adds a new stopping criterion to the list.
     *
     * @param {StoppingCriteria} item The stopping criterion to add.
     */
    push(item) {
        this.criteria.push(item);
    }

    /**
     * Adds multiple stopping criteria to the list.
     *
     * @param {StoppingCriteria|StoppingCriteriaList|StoppingCriteria[]} items The stopping criteria to add.
     */
    extend(items) {
        // Normalize the input to a plain array of criteria without mutating the argument.
        let toAdd;
        if (items instanceof StoppingCriteriaList) {
            toAdd = items.criteria;
        } else if (items instanceof StoppingCriteria) {
            toAdd = [items];
        } else {
            toAdd = items;
        }
        for (const criterion of toAdd) {
            this.criteria.push(criterion);
        }
    }

    /**
     * Evaluates every criterion and ORs their per-sequence results together.
     * @param {number[][]} input_ids
     * @param {number[][]} scores
     * @returns {boolean[]} One flag per sequence in the batch.
     */
    _call(input_ids, scores) {
        const batchSize = input_ids.length;
        const stopped = new Array(batchSize).fill(false);
        for (const criterion of this.criteria) {
            const result = criterion(input_ids, scores);
            for (let b = 0; b < batchSize; ++b) {
                stopped[b] = stopped[b] || result[b];
            }
        }
        return stopped;
    }

    [Symbol.iterator]() {
        return this.criteria.values();
    }
}
|
|
77
|
+
|
|
78
|
+
/**
 * This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`.
 * Keep in mind for decoder-only type of transformers, this will include the initial prompted tokens.
 */
export class MaxLengthCriteria extends StoppingCriteria {

    /**
     * @param {number} max_length The maximum length that the output sequence can have in number of tokens.
     * @param {number} [max_position_embeddings=null] The maximum model length, as defined by the model's `config.max_position_embeddings` attribute.
     */
    constructor(max_length, max_position_embeddings = null) {
        super();
        this.max_length = max_length;
        this.max_position_embeddings = max_position_embeddings;
    }

    /**
     * Flags each sequence whose token count has reached `max_length`.
     * @param {number[][]} input_ids
     * @returns {boolean[]}
     */
    _call(input_ids) {
        const done = [];
        for (const sequence of input_ids) {
            done.push(sequence.length >= this.max_length);
        }
        return done;
    }
}
|
|
99
|
+
|
|
100
|
+
// TODO: add MaxTimeCriteria
|
|
101
|
+
|
|
102
|
+
/**
 * This class can be used to stop generation whenever the "end-of-sequence" token is generated.
 * By default, it uses the `model.generation_config.eos_token_id`.
 */
export class EosTokenCriteria extends StoppingCriteria {

    /**
     * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token.
     * Optionally, use a list to set multiple *end-of-sequence* tokens.
     */
    constructor(eos_token_id) {
        super();
        // Normalize to an array so `_call` can always iterate.
        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
    }

    /**
     * Flags each sequence whose most recent token is one of the EOS ids.
     * @param {number[][]} input_ids
     * @param {number[][]} scores
     * @returns {boolean[]}
     */
    _call(input_ids, scores) {
        const done = [];
        for (const ids of input_ids) {
            const lastToken = ids.at(-1);
            // NOTE: loose equality (==) is intentional, to allow number/bigint comparison.
            let isEos = false;
            for (const eos of this.eos_token_id) {
                if (lastToken == eos) {
                    isEos = true;
                    break;
                }
            }
            done.push(isEos);
        }
        return done;
    }
}
|
|
135
|
+
|
|
136
|
+
/**
 * This class can be used to stop generation whenever the user interrupts the process.
 */
export class InterruptableStoppingCriteria extends StoppingCriteria {
    constructor() {
        super();
        // Flag flipped by `interrupt()`; reported for every sequence in the batch.
        this.interrupted = false;
    }

    /** Request that generation stop at the next check. */
    interrupt() {
        this.interrupted = true;
    }

    /** Clear a previous interrupt so generation can run again. */
    reset() {
        this.interrupted = false;
    }

    /**
     * @param {number[][]} input_ids
     * @param {number[][]} scores
     * @returns {boolean[]} The interrupt flag, repeated once per sequence.
     */
    _call(input_ids, scores) {
        return input_ids.map(() => this.interrupted);
    }
}
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
|
|
2
|
+
/**
|
|
3
|
+
* @module generation/streamers
|
|
4
|
+
*/
|
|
5
|
+
|
|
6
|
+
import { mergeArrays } from '../utils/core.js';
|
|
7
|
+
import { is_chinese_char } from '../tokenizers.js';
|
|
8
|
+
import { apis } from '../env.js';
|
|
9
|
+
|
|
10
|
+
/**
 * Base class for all streamers. Defines the interface `.generate()` uses to
 * push tokens out as they are produced; both methods must be overridden.
 */
export class BaseStreamer {
    /**
     * Called by `.generate()` whenever new tokens are produced.
     * @param {bigint[][]} value
     * @throws {Error} Always; subclasses must override.
     */
    put(value) {
        throw new Error('Not implemented');
    }

    /**
     * Called by `.generate()` to signal that generation has finished.
     * @throws {Error} Always; subclasses must override.
     */
    end() {
        throw new Error('Not implemented');
    }
}
|
|
26
|
+
|
|
27
|
+
// Default sink for streamed text: write raw chunks to stdout when a `process`
// object is available (Node.js); otherwise fall back to `console.log`.
// NOTE(review): `console.log` appends a newline per call, so the fallback
// prints chunk-per-line rather than a continuous stream.
const stdout_write = apis.IS_PROCESS_AVAILABLE
    ? x => process.stdout.write(x)
    : x => console.log(x);
|
|
30
|
+
|
|
31
|
+
/**
 * Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.
 */
export class TextStreamer extends BaseStreamer {
    /**
     * @param {import('../tokenizers.js').PreTrainedTokenizer} tokenizer The tokenizer used to decode the streamed token ids.
     * @param {Object} [options] Streaming options.
     * @param {boolean} [options.skip_prompt=false] Whether to suppress the first batch of tokens (the prompt).
     * @param {function(string): void} [options.callback_function=null] Called with each chunk of finalized text; defaults to writing to stdout.
     * @param {function(bigint[]): void} [options.token_callback_function=null] Called with every batch of raw token ids as it arrives.
     * @param {Object} [options.decode_kwargs={}] Keyword arguments forwarded to `tokenizer.decode` (any extra top-level options are merged in as well).
     */
    constructor(tokenizer, {
        skip_prompt = false,
        callback_function = null,
        token_callback_function = null,
        decode_kwargs = {},
        ...kwargs
    } = {}) {
        super();
        this.tokenizer = tokenizer;
        this.skip_prompt = skip_prompt;
        // Fall back to stdout when no text callback is supplied.
        this.callback_function = callback_function ?? stdout_write;
        this.token_callback_function = token_callback_function;
        this.decode_kwargs = { ...decode_kwargs, ...kwargs };

        // variables used in the streaming process
        this.token_cache = []; // token ids accumulated since the last flush
        this.print_len = 0; // number of characters of the decoded cache already emitted
        this.next_tokens_are_prompt = true; // the first `put` call carries the prompt
    }

    /**
     * Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
     * @param {bigint[][]} value
     */
    put(value) {
        if (value.length > 1) {
            throw Error('TextStreamer only supports batch size of 1');
        }

        const tokens = value[0];
        this.token_callback_function?.(tokens)

        // Optionally swallow the first batch of tokens (the prompt itself).
        if (this.skip_prompt && this.next_tokens_are_prompt) {
            this.next_tokens_are_prompt = false;
            return;
        }

        // Add the new token to the cache and decodes the entire thing.
        this.token_cache = mergeArrays(this.token_cache, tokens);
        const text = this.tokenizer.decode(this.token_cache, this.decode_kwargs);

        let printable_text;
        if (text.endsWith('\n')) {
            // After the symbol for a new line, we flush the cache.
            printable_text = text.slice(this.print_len);
            this.token_cache = [];
            this.print_len = 0;
        } else if (text.length > 0 && is_chinese_char(text.charCodeAt(text.length - 1))) {
            // If the last token is a CJK character, we print the characters.
            printable_text = text.slice(this.print_len);
            this.print_len += printable_text.length;
        } else {
            // Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
            // which may change with the subsequent token -- there are probably smarter ways to do this!)
            // NOTE: `lastIndexOf(' ')` returns -1 when there is no space, yielding an
            // empty slice, i.e. nothing is emitted until a word boundary appears.
            printable_text = text.slice(this.print_len, text.lastIndexOf(' ') + 1);
            this.print_len += printable_text.length;
        }

        this.on_finalized_text(printable_text, false);
    }

    /**
     * Flushes any remaining cache and prints a newline to stdout.
     */
    end() {
        let printable_text;
        if (this.token_cache.length > 0) {
            // Emit whatever decoded text has not been printed yet, then reset state.
            const text = this.tokenizer.decode(this.token_cache, this.decode_kwargs);
            printable_text = text.slice(this.print_len);
            this.token_cache = [];
            this.print_len = 0;
        } else {
            printable_text = '';
        }
        // The next `put` call starts a fresh generation (prompt first).
        this.next_tokens_are_prompt = true;
        this.on_finalized_text(printable_text, true);
    }

    /**
     * Prints the new text to stdout. If the stream is ending, also prints a newline.
     * @param {string} text
     * @param {boolean} stream_end
     */
    on_finalized_text(text, stream_end) {
        if (text.length > 0) {
            this.callback_function?.(text);
        }
        // Only append a trailing newline when writing directly to stdout;
        // custom callbacks are responsible for their own output handling.
        if (stream_end && this.callback_function === stdout_write && apis.IS_PROCESS_AVAILABLE) {
            this.callback_function?.('\n');
        }
    }
}
|
|
131
|
+
|
|
132
|
+
/**
 * Utility class to handle streaming of tokens generated by whisper speech-to-text models.
 * Callback functions are invoked when each of the following events occur:
 *  - A new chunk starts (on_chunk_start)
 *  - A new token is generated (callback_function)
 *  - A chunk ends (on_chunk_end)
 *  - The stream is finalized (on_finalize)
 */
export class WhisperTextStreamer extends TextStreamer {
    /**
     * @param {import('../tokenizers.js').WhisperTokenizer} tokenizer
     * @param {Object} options
     * @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens
     * @param {function(string): void} [options.callback_function=null] Function to call when a piece of text is ready to display
     * @param {function(bigint[]): void} [options.token_callback_function=null] Function to call when a new batch of tokens is generated
     * @param {function(number): void} [options.on_chunk_start=null] Function to call when a new chunk starts
     * @param {function(number): void} [options.on_chunk_end=null] Function to call when a chunk ends
     * @param {function(): void} [options.on_finalize=null] Function to call when the stream is finalized
     * @param {number} [options.time_precision=0.02] Precision of the timestamps (seconds per timestamp token)
     * @param {boolean} [options.skip_special_tokens=true] Whether to skip special tokens when decoding
     * @param {Object} [options.decode_kwargs={}] Additional keyword arguments to pass to the tokenizer's decode method
     */
    constructor(tokenizer, {
        skip_prompt = false,
        callback_function = null,
        token_callback_function = null,
        on_chunk_start = null,
        on_chunk_end = null,
        on_finalize = null,
        time_precision = 0.02,
        skip_special_tokens = true,
        decode_kwargs = {},
    } = {}) {
        super(tokenizer, {
            skip_prompt,
            callback_function,
            token_callback_function,
            decode_kwargs: { skip_special_tokens, ...decode_kwargs },
        });

        // First token id that encodes a timestamp rather than text.
        this.timestamp_begin = tokenizer.timestamp_begin;

        this.on_chunk_start = on_chunk_start;
        this.on_chunk_end = on_chunk_end;
        this.on_finalize = on_finalize;

        this.time_precision = time_precision;

        // Alternates: false => next timestamp opens a chunk, true => it closes one.
        this.waiting_for_timestamp = false;
    }

    /**
     * Intercepts timestamp tokens (reporting chunk boundaries via the callbacks)
     * and forwards everything else to the underlying text streamer.
     * @param {bigint[][]} value
     */
    put(value) {
        if (value.length > 1) {
            throw Error('WhisperTextStreamer only supports batch size of 1');
        }

        let batch = value;
        const tokens = batch[0];

        // A lone token at or beyond `timestamp_begin` is a timestamp marker, not text.
        if (tokens.length === 1) {
            const offset = Number(tokens[0]) - this.timestamp_begin;
            if (offset >= 0) {
                const seconds = offset * this.time_precision;
                if (this.waiting_for_timestamp) {
                    this.on_chunk_end?.(seconds);
                } else {
                    this.on_chunk_start?.(seconds);
                }
                this.waiting_for_timestamp = !this.waiting_for_timestamp; // Toggle
                batch = [[]]; // Swallow the timestamp token so it is not decoded as text.
            }
        }
        return super.put(batch);
    }

    /**
     * Finalizes the underlying text stream, then notifies `on_finalize`.
     */
    end() {
        super.end();
        this.on_finalize?.();
    }
}
|