@huggingface/transformers 3.0.0-alpha.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96) hide show
  1. package/LICENSE +202 -0
  2. package/README.md +376 -0
  3. package/dist/ort-wasm-simd-threaded.jsep.wasm +0 -0
  4. package/dist/transformers.cjs +30741 -0
  5. package/dist/transformers.cjs.map +1 -0
  6. package/dist/transformers.js +33858 -0
  7. package/dist/transformers.js.map +1 -0
  8. package/dist/transformers.min.cjs +173 -0
  9. package/dist/transformers.min.cjs.map +1 -0
  10. package/dist/transformers.min.js +231 -0
  11. package/dist/transformers.min.js.map +1 -0
  12. package/package.json +92 -0
  13. package/src/backends/onnx.js +151 -0
  14. package/src/configs.js +360 -0
  15. package/src/env.js +152 -0
  16. package/src/generation/configuration_utils.js +381 -0
  17. package/src/generation/logits_process.js +716 -0
  18. package/src/generation/logits_sampler.js +204 -0
  19. package/src/generation/parameters.js +35 -0
  20. package/src/generation/stopping_criteria.js +156 -0
  21. package/src/generation/streamers.js +212 -0
  22. package/src/models/whisper/common_whisper.js +151 -0
  23. package/src/models/whisper/generation_whisper.js +89 -0
  24. package/src/models.js +7028 -0
  25. package/src/ops/registry.js +92 -0
  26. package/src/pipelines.js +3341 -0
  27. package/src/processors.js +2614 -0
  28. package/src/tokenizers.js +4395 -0
  29. package/src/transformers.js +28 -0
  30. package/src/utils/audio.js +704 -0
  31. package/src/utils/constants.js +2 -0
  32. package/src/utils/core.js +149 -0
  33. package/src/utils/data-structures.js +445 -0
  34. package/src/utils/devices.js +11 -0
  35. package/src/utils/dtypes.js +62 -0
  36. package/src/utils/generic.js +35 -0
  37. package/src/utils/hub.js +671 -0
  38. package/src/utils/image.js +745 -0
  39. package/src/utils/maths.js +1050 -0
  40. package/src/utils/tensor.js +1378 -0
  41. package/types/backends/onnx.d.ts +26 -0
  42. package/types/backends/onnx.d.ts.map +1 -0
  43. package/types/configs.d.ts +59 -0
  44. package/types/configs.d.ts.map +1 -0
  45. package/types/env.d.ts +106 -0
  46. package/types/env.d.ts.map +1 -0
  47. package/types/generation/configuration_utils.d.ts +320 -0
  48. package/types/generation/configuration_utils.d.ts.map +1 -0
  49. package/types/generation/logits_process.d.ts +354 -0
  50. package/types/generation/logits_process.d.ts.map +1 -0
  51. package/types/generation/logits_sampler.d.ts +51 -0
  52. package/types/generation/logits_sampler.d.ts.map +1 -0
  53. package/types/generation/parameters.d.ts +47 -0
  54. package/types/generation/parameters.d.ts.map +1 -0
  55. package/types/generation/stopping_criteria.d.ts +81 -0
  56. package/types/generation/stopping_criteria.d.ts.map +1 -0
  57. package/types/generation/streamers.d.ts +81 -0
  58. package/types/generation/streamers.d.ts.map +1 -0
  59. package/types/models/whisper/common_whisper.d.ts +8 -0
  60. package/types/models/whisper/common_whisper.d.ts.map +1 -0
  61. package/types/models/whisper/generation_whisper.d.ts +76 -0
  62. package/types/models/whisper/generation_whisper.d.ts.map +1 -0
  63. package/types/models.d.ts +3845 -0
  64. package/types/models.d.ts.map +1 -0
  65. package/types/ops/registry.d.ts +11 -0
  66. package/types/ops/registry.d.ts.map +1 -0
  67. package/types/pipelines.d.ts +2403 -0
  68. package/types/pipelines.d.ts.map +1 -0
  69. package/types/processors.d.ts +917 -0
  70. package/types/processors.d.ts.map +1 -0
  71. package/types/tokenizers.d.ts +999 -0
  72. package/types/tokenizers.d.ts.map +1 -0
  73. package/types/transformers.d.ts +13 -0
  74. package/types/transformers.d.ts.map +1 -0
  75. package/types/utils/audio.d.ts +130 -0
  76. package/types/utils/audio.d.ts.map +1 -0
  77. package/types/utils/constants.d.ts +2 -0
  78. package/types/utils/constants.d.ts.map +1 -0
  79. package/types/utils/core.d.ts +91 -0
  80. package/types/utils/core.d.ts.map +1 -0
  81. package/types/utils/data-structures.d.ts +236 -0
  82. package/types/utils/data-structures.d.ts.map +1 -0
  83. package/types/utils/devices.d.ts +8 -0
  84. package/types/utils/devices.d.ts.map +1 -0
  85. package/types/utils/dtypes.d.ts +22 -0
  86. package/types/utils/dtypes.d.ts.map +1 -0
  87. package/types/utils/generic.d.ts +11 -0
  88. package/types/utils/generic.d.ts.map +1 -0
  89. package/types/utils/hub.d.ts +191 -0
  90. package/types/utils/hub.d.ts.map +1 -0
  91. package/types/utils/image.d.ts +119 -0
  92. package/types/utils/image.d.ts.map +1 -0
  93. package/types/utils/maths.d.ts +280 -0
  94. package/types/utils/maths.d.ts.map +1 -0
  95. package/types/utils/tensor.d.ts +392 -0
  96. package/types/utils/tensor.d.ts.map +1 -0
@@ -0,0 +1,204 @@
1
+
2
+ /**
3
+ * @module generation/logits_sampler
4
+ */
5
+
6
+ import { Callable } from "../utils/generic.js";
7
+ import { Tensor, topk } from "../utils/tensor.js";
8
+
9
+ import {
10
+ max,
11
+ softmax,
12
+ } from '../utils/maths.js';
13
+ import { GenerationConfig } from '../generation/configuration_utils.js';
14
+
15
/**
 * Sampler is a base class for all sampling methods used for text generation.
 */
export class LogitsSampler extends Callable {
    /**
     * Creates a new Sampler object with the specified generation config.
     * @param {GenerationConfig} generation_config The generation config.
     */
    constructor(generation_config) {
        super();
        this.generation_config = generation_config;
    }

    /**
     * Executes the sampler on the given logits.
     * @param {Tensor} logits Logits of dims `[batch, sequence_length, vocab_size]`.
     * @returns {Promise<[bigint, number][]>} Pairs of (token id, score).
     */
    async _call(logits) {
        // Delegate to the subclass-specific sampling strategy.
        return this.sample(logits);
    }

    /**
     * Abstract method for sampling the logits.
     * @param {Tensor} logits
     * @throws {Error} If not implemented in subclass.
     * @returns {Promise<[bigint, number][]>}
     */
    async sample(logits) {
        throw Error("sample should be implemented in subclasses.")
    }

    /**
     * Returns one vocabulary-sized slice of the flat logits buffer.
     * NOTE(review): the raw logit values are returned unchanged — no temperature
     * is applied here, despite what earlier docs suggested.
     * @param {Tensor} logits
     * @param {number} index Row index to extract, or -1 for the last row.
     * @returns {Float32Array} The selected slice of logits.
     */
    getLogits(logits, index) {
        const vocabSize = logits.dims.at(-1);
        const flat = /** @type {Float32Array} */(logits.data);

        if (index === -1) {
            // Last row of the buffer.
            return flat.slice(-vocabSize);
        }
        const start = index * vocabSize;
        return flat.slice(start, start + vocabSize);
    }

    /**
     * Selects an item randomly based on the specified probabilities.
     * @param {import("../transformers.js").DataArray} probabilities An array of probabilities to use for selection.
     * @returns {number} The index of the selected item.
     */
    randomSelect(probabilities) {
        // Weights need not be normalized: draw r in [0, sum) and walk the prefix sums.
        let total = 0;
        for (let i = 0; i < probabilities.length; ++i) {
            total += probabilities[i];
        }

        let r = Math.random() * total;
        for (let i = 0; i < probabilities.length; ++i) {
            r -= probabilities[i];
            if (r <= 0) {
                return i;
            }
        }
        return 0; // return first (most probable) as a fallback
    }

    /**
     * Returns a Sampler object based on the specified options.
     * @param {GenerationConfig} generation_config An object containing options for the sampler.
     * @returns {LogitsSampler} A Sampler object.
     */
    static getSampler(generation_config) {
        // - *greedy decoding*: `num_beams=1` and `do_sample=False`
        // - *contrastive search*: `penalty_alpha>0` and `top_k>1`
        // - *multinomial sampling*: `num_beams=1` and `do_sample=True`
        // - *beam-search decoding*: `num_beams>1` and `do_sample=False`
        // - *beam-search multinomial sampling*: `num_beams>1` and `do_sample=True`
        // - *diverse beam-search decoding*: `num_beams>1` and `num_beam_groups>1`
        // - *constrained beam-search decoding*: `constraints!=None` or `force_words_ids!=None`

        // NOTE: beam search is implemented directly into the generation function
        if (generation_config.do_sample) {
            return new MultinomialSampler(generation_config);
        }
        if (generation_config.num_beams > 1) {
            return new BeamSearchSampler(generation_config);
        }
        if (generation_config.num_return_sequences > 1) {
            throw Error(`num_return_sequences has to be 1 when doing greedy search, but is ${generation_config.num_return_sequences}.`)
        }
        return new GreedySampler(generation_config);
    }
}
120
+
121
/**
 * Class representing a Greedy Sampler.
 */
class GreedySampler extends LogitsSampler {
    /**
     * Picks the single highest-probability token.
     * @param {Tensor} logits
     * @returns {Promise<[bigint, number][]>} An array with a single (token id, score) tuple.
     * The score is always 0: greedy search is deterministic (p = 1 => log(p) = 0).
     */
    async sample(logits) {
        // No log_softmax needed: argmax over raw logits equals argmax over probabilities.
        const tokenId = max(logits.data)[1];
        return [[BigInt(tokenId), 0]];
    }
}
141
+
142
/**
 * Class representing a MultinomialSampler.
 */
class MultinomialSampler extends LogitsSampler {

    /**
     * Draws `num_beams` tokens at random, each from the (optionally top-k restricted)
     * softmax distribution over the logits.
     * @param {Tensor} logits
     * @returns {Promise<[bigint, number][]>} Pairs of (token id, log-probability score).
     */
    async sample(logits) {
        const vocabSize = logits.dims.at(-1);
        const top_k = this.generation_config.top_k;
        // Restrict to the top-k candidates when requested; otherwise keep the full vocabulary.
        const k = top_k > 0 ? Math.min(top_k, vocabSize) : vocabSize;

        // Get top k tokens
        const [v, i] = await topk(logits, k);

        // Normalize the retained logits into a probability distribution.
        const probabilities = softmax(/** @type {Float32Array} */(v.data));

        const samples = [];
        for (let beam = 0; beam < this.generation_config.num_beams; ++beam) {
            const pick = this.randomSelect(probabilities);
            samples.push([
                i.data[pick], // token id
                Math.log(probabilities[pick]), // score
            ]);
        }
        return samples;
    }
}
173
+
174
+
175
/**
 * Class representing a BeamSearchSampler.
 */
class BeamSearchSampler extends LogitsSampler {

    /**
     * Returns the `num_beams` highest-scoring tokens, in descending probability order
     * (relies on `topk` returning candidates sorted by value).
     * @param {Tensor} logits
     * @returns {Promise<[bigint, number][]>} Pairs of (token id, log-probability score).
     */
    async sample(logits) {
        const vocabSize = logits.dims.at(-1);
        const top_k = this.generation_config.top_k;
        // Restrict to the top-k candidates when requested; otherwise keep the full vocabulary.
        const k = top_k > 0 ? Math.min(top_k, vocabSize) : vocabSize;

        // Get top k tokens
        const [v, i] = await topk(logits, k);

        // Normalize the retained logits into a probability distribution.
        const probabilities = softmax(/** @type {Float32Array} */(v.data));

        const beams = [];
        for (let x = 0; x < this.generation_config.num_beams; ++x) {
            beams.push([
                i.data[x], // token id
                Math.log(probabilities[x]), // score
            ]);
        }
        return beams;
    }
}
@@ -0,0 +1,35 @@
1
+
2
+ /**
3
+ * @module generation/parameters
4
+ */
5
+
6
+ /**
7
+ * @typedef {Object} GenerationFunctionParameters
8
+ * @property {import('../utils/tensor.js').Tensor} [inputs=null] (`Tensor` of varying shape depending on the modality, *optional*):
9
+ * The sequence used as a prompt for the generation or as model inputs to the encoder. If `null` the
10
+ * method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
11
+ * should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
12
+ * `input_ids`, `input_values`, `input_features`, or `pixel_values`.
13
+ * @property {import('./configuration_utils.js').GenerationConfig} [generation_config=null] (`GenerationConfig`, *optional*):
14
+ * The generation configuration to be used as base parametrization for the generation call.
15
+ * `**kwargs` passed to generate matching the attributes of `generation_config` will override them.
16
+ * If `generation_config` is not provided, the default will be used, which has the following loading
17
+ * priority:
18
+ * - (1) from the `generation_config.json` model file, if it exists;
19
+ * - (2) from the model configuration. Please note that unspecified parameters will inherit [`GenerationConfig`]'s
20
+ * default values, whose documentation should be checked to parameterize generation.
21
+ * @property {import('./logits_process.js').LogitsProcessorList} [logits_processor=null] (`LogitsProcessorList`, *optional*):
22
+ * Custom logits processors that complement the default logits processors built from arguments and
23
+ * generation config. If a logit processor is passed that is already created with the arguments or a
24
+ * generation config an error is thrown. This feature is intended for advanced users.
25
+ * @property {import('./stopping_criteria.js').StoppingCriteriaList} [stopping_criteria=null] (`StoppingCriteriaList`, *optional*):
26
+ * Custom stopping criteria that complements the default stopping criteria built from arguments and a
27
+ * generation config. If a stopping criteria is passed that is already created with the arguments or a
28
+ * generation config an error is thrown. This feature is intended for advanced users.
29
+ * @property {import('./streamers.js').BaseStreamer} [streamer=null] (`BaseStreamer`, *optional*):
30
+ * Streamer object that will be used to stream the generated sequences. Generated tokens are passed
31
+ * through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
32
+ * @property {number[]} [decoder_input_ids=null] (`number[]`, *optional*):
33
+ * If the model is an encoder-decoder model, this argument is used to pass the `decoder_input_ids`.
34
+ * @property {any} [kwargs] (`Dict[str, any]`, *optional*): Remaining keyword arguments, matched against attributes of `generation_config` to override them.
35
+ */
@@ -0,0 +1,156 @@
1
+
2
+ /**
3
+ * @module generation/stopping_criteria
4
+ */
5
+
6
+ import { Callable } from "../utils/generic.js";
7
+
8
+ // NOTE:
9
+ // Stopping Criteria returns a list of `batch_size` booleans, indicating whether each sequence in the batch should be stopped.
10
+
11
/**
 * Abstract base class for all stopping criteria that can be applied during generation.
 */
export class StoppingCriteria extends Callable {
    /**
     * @param {number[][]} input_ids (`number[][]` of shape `(batch_size, sequence_length)`):
     * Indices of input sequence tokens in the vocabulary.
     * @param {number[][]} scores (`number[][]` of shape `(batch_size, config.vocab_size)`):
     * Prediction scores of a language modeling head, before or after SoftMax.
     * @returns {boolean[]} For each sequence in the batch, whether generation should stop.
     */
    _call(input_ids, scores) {
        // Abstract: concrete criteria must override this.
        throw Error("StoppingCriteria needs to be subclassed");
    }
}
28
/**
 * An ordered collection of stopping criteria; a sequence stops when ANY criterion fires.
 */
export class StoppingCriteriaList extends Callable {
    /**
     * Constructs a new, empty `StoppingCriteriaList`.
     */
    constructor() {
        super();
        this.criteria = [];
    }

    /**
     * Adds a new stopping criterion to the list.
     *
     * @param {StoppingCriteria} item The stopping criterion to add.
     */
    push(item) {
        this.criteria.push(item);
    }

    /**
     * Adds multiple stopping criteria to the list.
     *
     * @param {StoppingCriteria|StoppingCriteriaList|StoppingCriteria[]} items The stopping criteria to add.
     */
    extend(items) {
        // Normalize the input to a plain array of criteria before appending.
        let toAdd;
        if (items instanceof StoppingCriteriaList) {
            toAdd = items.criteria;
        } else if (items instanceof StoppingCriteria) {
            toAdd = [items];
        } else {
            toAdd = items;
        }
        this.criteria.push(...toAdd);
    }

    /**
     * Evaluates every criterion and ORs the per-sequence results together.
     * @param {number[][]} input_ids
     * @param {number[][]} scores
     * @returns {boolean[]} One stop-flag per sequence in the batch.
     */
    _call(input_ids, scores) {
        const is_done = input_ids.map(() => false);
        for (const criterion of this.criteria) {
            const criterion_done = criterion(input_ids, scores);
            for (let i = 0; i < is_done.length; ++i) {
                if (criterion_done[i]) {
                    is_done[i] = true;
                }
            }
        }
        return is_done;
    }

    [Symbol.iterator]() {
        return this.criteria.values();
    }
}
77
+
78
/**
 * This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`.
 * Keep in mind for decoder-only type of transformers, this will include the initial prompted tokens.
 */
export class MaxLengthCriteria extends StoppingCriteria {

    /**
     * @param {number} max_length The maximum length that the output sequence can have in number of tokens.
     * @param {number} [max_position_embeddings=null] The maximum model length, as defined by the model's `config.max_position_embeddings` attribute.
     * NOTE(review): `max_position_embeddings` is stored but never consulted in `_call` —
     * presumably reserved for a length warning as in the Python implementation; confirm upstream.
     */
    constructor(max_length, max_position_embeddings = null) {
        super();
        this.max_length = max_length;
        this.max_position_embeddings = max_position_embeddings;
    }

    /**
     * @param {number[][]} input_ids
     * @returns {boolean[]} `true` for each sequence whose length has reached `max_length`.
     */
    _call(input_ids) {
        const stopped = [];
        for (const ids of input_ids) {
            stopped.push(ids.length >= this.max_length);
        }
        return stopped;
    }
}
99
+
100
+ // TODO: add MaxTimeCriteria
101
+
102
/**
 * This class can be used to stop generation whenever the "end-of-sequence" token is generated.
 * By default, it uses the `model.generation_config.eos_token_id`.
 */
export class EosTokenCriteria extends StoppingCriteria {

    /**
     * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token.
     * Optionally, use a list to set multiple *end-of-sequence* tokens.
     */
    constructor(eos_token_id) {
        super();
        // Normalize to an array so `_call` can treat single and multiple ids uniformly.
        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
    }

    /**
     * @param {number[][]} input_ids
     * @param {number[][]} scores
     * @returns {boolean[]} `true` for each sequence whose last token is an EOS token.
     */
    _call(input_ids, scores) {
        const stopped = [];
        for (const ids of input_ids) {
            const last = ids.at(-1);
            // NOTE: We use == instead of === to allow for number/bigint comparison
            stopped.push(this.eos_token_id.some((eos_id) => last == eos_id));
        }
        return stopped;
    }
}
135
+
136
/**
 * This class can be used to stop generation whenever the user interrupts the process.
 */
export class InterruptableStoppingCriteria extends StoppingCriteria {
    constructor() {
        super();
        // Flag flipped by `interrupt()` and checked on every generation step.
        this.interrupted = false;
    }

    /** Requests that generation stop at the next step. */
    interrupt() {
        this.interrupted = true;
    }

    /** Clears a previous interruption so the criteria can be reused. */
    reset() {
        this.interrupted = false;
    }

    /**
     * @param {number[][]} input_ids
     * @param {number[][]} scores
     * @returns {boolean[]} The current interruption state, repeated for every sequence.
     */
    _call(input_ids, scores) {
        return input_ids.map(() => this.interrupted);
    }
}
@@ -0,0 +1,212 @@
1
+
2
+ /**
3
+ * @module generation/streamers
4
+ */
5
+
6
+ import { mergeArrays } from '../utils/core.js';
7
+ import { is_chinese_char } from '../tokenizers.js';
8
+ import { apis } from '../env.js';
9
+
10
/**
 * Abstract base class for generation streamers. Subclasses receive tokens as they
 * are produced by `.generate()` and decide how to surface them.
 */
export class BaseStreamer {
    /**
     * Called by `.generate()` to push newly produced tokens.
     * Must be overridden by subclasses.
     * @param {bigint[][]} value Batch of new token ids.
     */
    put(value) {
        throw Error('Not implemented');
    }

    /**
     * Called by `.generate()` to signal the end of generation.
     * Must be overridden by subclasses.
     */
    end() {
        throw Error('Not implemented');
    }
}
26
+
27
// Prefer writing straight to stdout (no trailing newline) when a Node-style
// `process` is available; otherwise fall back to `console.log`.
const stdout_write = apis.IS_PROCESS_AVAILABLE
    ? (text) => process.stdout.write(text)
    : (text) => console.log(text);
30
+
31
/**
 * Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.
 */
export class TextStreamer extends BaseStreamer {
    /**
     * @param {import('../tokenizers.js').PreTrainedTokenizer} tokenizer The tokenizer used to decode token ids into text.
     * @param {Object} [options]
     * @param {boolean} [options.skip_prompt=false] Whether to suppress the first `put()` call (the prompt tokens).
     * @param {function(string): void} [options.callback_function=null] Called with each chunk of finalized text; defaults to writing to stdout.
     * @param {function(bigint[]): void} [options.token_callback_function=null] Called with the raw tokens of every `put()`.
     * @param {Object} [options.decode_kwargs={}] Additional keyword arguments passed to the tokenizer's `decode` method.
     */
    constructor(tokenizer, {
        skip_prompt = false,
        callback_function = null,
        token_callback_function = null,
        decode_kwargs = {},
        ...kwargs
    } = {}) {
        super();
        this.tokenizer = tokenizer;
        this.skip_prompt = skip_prompt;
        this.callback_function = callback_function ?? stdout_write;
        this.token_callback_function = token_callback_function;
        // Any extra top-level options are also forwarded to `decode`.
        this.decode_kwargs = { ...decode_kwargs, ...kwargs };

        // variables used in the streaming process
        this.token_cache = []; // tokens decoded so far but not yet fully emitted
        this.print_len = 0; // number of characters of the decoded text already emitted
        this.next_tokens_are_prompt = true; // true until the first put() has been seen
    }

    /**
     * Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
     * @param {bigint[][]} value Batch of newly generated token ids (batch size must be 1).
     */
    put(value) {
        if (value.length > 1) {
            throw Error('TextStreamer only supports batch size of 1');
        }

        const tokens = value[0];
        this.token_callback_function?.(tokens)

        // Optionally swallow the first call, which carries the prompt tokens.
        if (this.skip_prompt && this.next_tokens_are_prompt) {
            this.next_tokens_are_prompt = false;
            return;
        }

        // Add the new token to the cache and decodes the entire thing.
        this.token_cache = mergeArrays(this.token_cache, tokens);
        const text = this.tokenizer.decode(this.token_cache, this.decode_kwargs);

        let printable_text;
        if (text.endsWith('\n')) {
            // After the symbol for a new line, we flush the cache.
            printable_text = text.slice(this.print_len);
            this.token_cache = [];
            this.print_len = 0;
        } else if (text.length > 0 && is_chinese_char(text.charCodeAt(text.length - 1))) {
            // If the last token is a CJK character, we print the characters
            // immediately (CJK text has no word-separating spaces to wait for).
            printable_text = text.slice(this.print_len);
            this.print_len += printable_text.length;
        } else {
            // Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
            // which may change with the subsequent token -- there are probably smarter ways to do this!)
            printable_text = text.slice(this.print_len, text.lastIndexOf(' ') + 1);
            this.print_len += printable_text.length;
        }

        this.on_finalized_text(printable_text, false);
    }

    /**
     * Flushes any remaining cache and prints a newline to stdout.
     */
    end() {
        let printable_text;
        if (this.token_cache.length > 0) {
            // Emit whatever tail of the decoded cache has not been printed yet.
            const text = this.tokenizer.decode(this.token_cache, this.decode_kwargs);
            printable_text = text.slice(this.print_len);
            this.token_cache = [];
            this.print_len = 0;
        } else {
            printable_text = '';
        }
        this.next_tokens_are_prompt = true;
        this.on_finalized_text(printable_text, true);
    }

    /**
     * Prints the new text to stdout. If the stream is ending, also prints a newline.
     * @param {string} text
     * @param {boolean} stream_end
     */
    on_finalized_text(text, stream_end) {
        if (text.length > 0) {
            this.callback_function?.(text);
        }
        // Only emit a trailing newline when we are writing to stdout ourselves;
        // a user-supplied callback is responsible for its own formatting.
        if (stream_end && this.callback_function === stdout_write && apis.IS_PROCESS_AVAILABLE) {
            this.callback_function?.('\n');
        }
    }
}
131
+
132
/**
 * Utility class to handle streaming of tokens generated by whisper speech-to-text models.
 * Callback functions are invoked when each of the following events occur:
 *  - A new chunk starts (on_chunk_start)
 *  - A new token is generated (callback_function)
 *  - A chunk ends (on_chunk_end)
 *  - The stream is finalized (on_finalize)
 */
export class WhisperTextStreamer extends TextStreamer {
    /**
     * @param {import('../tokenizers.js').WhisperTokenizer} tokenizer
     * @param {Object} options
     * @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens
     * @param {function(string): void} [options.callback_function=null] Function to call when a piece of text is ready to display
     * @param {function(string): void} [options.token_callback_function=null] Function to call when a new token is generated
     * @param {function(number): void} [options.on_chunk_start=null] Function to call when a new chunk starts
     * @param {function(number): void} [options.on_chunk_end=null] Function to call when a chunk ends
     * @param {function(): void} [options.on_finalize=null] Function to call when the stream is finalized
     * @param {number} [options.time_precision=0.02] Precision of the timestamps
     * @param {boolean} [options.skip_special_tokens=true] Whether to skip special tokens when decoding
     * @param {Object} [options.decode_kwargs={}] Additional keyword arguments to pass to the tokenizer's decode method
     */
    constructor(tokenizer, {
        skip_prompt = false,
        callback_function = null,
        token_callback_function = null,
        on_chunk_start = null,
        on_chunk_end = null,
        on_finalize = null,
        time_precision = 0.02,
        skip_special_tokens = true,
        decode_kwargs = {},
    } = {}) {
        super(tokenizer, {
            skip_prompt,
            callback_function,
            token_callback_function,
            decode_kwargs: { skip_special_tokens, ...decode_kwargs },
        });
        // Token ids at or above this value encode timestamps rather than text.
        this.timestamp_begin = tokenizer.timestamp_begin;

        this.on_chunk_start = on_chunk_start;
        this.on_chunk_end = on_chunk_end;
        this.on_finalize = on_finalize;

        this.time_precision = time_precision;

        // Timestamp tokens come in pairs (chunk start, then chunk end); this flag
        // tracks which of the two we expect next.
        this.waiting_for_timestamp = false;
    }

    /**
     * Intercepts timestamp tokens to fire the chunk callbacks; forwards all other
     * tokens to `TextStreamer.put`.
     * @param {bigint[][]} value
     */
    put(value) {
        if (value.length > 1) {
            throw Error('WhisperTextStreamer only supports batch size of 1');
        }
        const tokens = value[0];

        // Check if the token is a timestamp
        if (tokens.length === 1) {
            // A non-negative offset means this token encodes a timestamp.
            const offset = Number(tokens[0]) - this.timestamp_begin;
            if (offset >= 0) {
                const time = offset * this.time_precision;
                if (this.waiting_for_timestamp) {
                    this.on_chunk_end?.(time);
                } else {
                    this.on_chunk_start?.(time);
                }
                this.waiting_for_timestamp = !this.waiting_for_timestamp; // Toggle
                value = [[]]; // Skip timestamp
            }
        }
        return super.put(value);
    }

    /**
     * Finalizes the underlying text stream, then invokes `on_finalize`.
     */
    end() {
        super.end();
        this.on_finalize?.();
    }
}