spec-cat 0.1.0 → 0.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110) hide show
  1. package/.output/nitro.json +1 -1
  2. package/.output/public/_nuxt/{Bqlz6CoK.js → BE_75kPa.js} +1 -1
  3. package/.output/public/_nuxt/{B2wdmh_w.js → BJ7m4fRW.js} +53 -53
  4. package/.output/public/_nuxt/{KNuzSjk0.js → CCNYUZ9m.js} +1 -1
  5. package/.output/public/_nuxt/{BvosqTnx.js → DGtcdWVl.js} +1 -1
  6. package/.output/public/_nuxt/DxEx-kFx.js +1 -0
  7. package/.output/public/_nuxt/{BwcbSlWF.js → DyMq_cQC.js} +2 -2
  8. package/.output/public/_nuxt/{COTT6rNZ.js → _cj5lOdZ.js} +1 -1
  9. package/.output/public/_nuxt/builds/latest.json +1 -1
  10. package/.output/public/_nuxt/builds/meta/3a0aacc1-0bd1-4d15-8b8a-3cee48cbfc69.json +1 -0
  11. package/.output/public/_nuxt/{BUOk7wkI.js → gDut6QrP.js} +1 -1
  12. package/.output/public/_nuxt/{C5wk2twv.js → nJpWpjzg.js} +1 -1
  13. package/.output/public/_nuxt/{DBab5Zcv.js → waQ9fPC1.js} +1 -1
  14. package/.output/server/chunks/_/codexProvider.mjs +64 -18
  15. package/.output/server/chunks/_/codexProvider.mjs.map +1 -1
  16. package/.output/server/chunks/build/client.precomputed.mjs +1 -1
  17. package/.output/server/chunks/build/client.precomputed.mjs.map +1 -1
  18. package/.output/server/chunks/nitro/nitro.mjs +702 -703
  19. package/.output/server/chunks/routes/_ws.mjs +37 -7
  20. package/.output/server/chunks/routes/_ws.mjs.map +1 -1
  21. package/.output/server/node_modules/@huggingface/jinja/dist/index.js +1572 -0
  22. package/.output/server/node_modules/@huggingface/jinja/package.json +55 -0
  23. package/.output/server/node_modules/@xenova/transformers/package.json +84 -0
  24. package/.output/server/node_modules/@xenova/transformers/src/backends/onnx.js +50 -0
  25. package/.output/server/node_modules/@xenova/transformers/src/configs.js +107 -0
  26. package/.output/server/node_modules/@xenova/transformers/src/env.js +128 -0
  27. package/.output/server/node_modules/@xenova/transformers/src/models.js +6267 -0
  28. package/.output/server/node_modules/@xenova/transformers/src/pipelines.js +3287 -0
  29. package/.output/server/node_modules/@xenova/transformers/src/processors.js +2248 -0
  30. package/.output/server/node_modules/@xenova/transformers/src/tokenizers.js +4479 -0
  31. package/.output/server/node_modules/@xenova/transformers/src/transformers.js +24 -0
  32. package/.output/server/node_modules/@xenova/transformers/src/utils/audio.js +672 -0
  33. package/.output/server/node_modules/@xenova/transformers/src/utils/core.js +175 -0
  34. package/.output/server/node_modules/@xenova/transformers/src/utils/data-structures.js +415 -0
  35. package/.output/server/node_modules/@xenova/transformers/src/utils/generation.js +873 -0
  36. package/.output/server/node_modules/@xenova/transformers/src/utils/hub.js +658 -0
  37. package/.output/server/node_modules/@xenova/transformers/src/utils/image.js +731 -0
  38. package/.output/server/node_modules/@xenova/transformers/src/utils/maths.js +985 -0
  39. package/.output/server/node_modules/@xenova/transformers/src/utils/tensor.js +1239 -0
  40. package/.output/server/node_modules/color/index.js +496 -0
  41. package/.output/server/node_modules/color/package.json +47 -0
  42. package/.output/server/node_modules/color-convert/conversions.js +839 -0
  43. package/.output/server/node_modules/color-convert/index.js +81 -0
  44. package/.output/server/node_modules/color-convert/package.json +48 -0
  45. package/.output/server/node_modules/color-convert/route.js +97 -0
  46. package/.output/server/node_modules/color-name/index.js +152 -0
  47. package/.output/server/node_modules/color-name/package.json +28 -0
  48. package/.output/server/node_modules/color-string/index.js +242 -0
  49. package/.output/server/node_modules/color-string/package.json +39 -0
  50. package/.output/server/node_modules/detect-libc/lib/detect-libc.js +313 -0
  51. package/.output/server/node_modules/detect-libc/lib/elf.js +39 -0
  52. package/.output/server/node_modules/detect-libc/lib/filesystem.js +51 -0
  53. package/.output/server/node_modules/detect-libc/lib/process.js +24 -0
  54. package/.output/server/node_modules/detect-libc/package.json +44 -0
  55. package/.output/server/node_modules/is-arrayish/index.js +9 -0
  56. package/.output/server/node_modules/is-arrayish/package.json +45 -0
  57. package/.output/server/node_modules/onnxruntime-common/dist/ort-common.node.js +7 -0
  58. package/.output/server/node_modules/onnxruntime-common/package.json +31 -0
  59. package/.output/server/node_modules/onnxruntime-node/bin/napi-v3/darwin/arm64/onnxruntime_binding.node +0 -0
  60. package/.output/server/node_modules/onnxruntime-node/bin/napi-v3/darwin/x64/onnxruntime_binding.node +0 -0
  61. package/.output/server/node_modules/onnxruntime-node/bin/napi-v3/linux/arm64/libonnxruntime.so.1.14.0 +0 -0
  62. package/.output/server/node_modules/onnxruntime-node/bin/napi-v3/linux/arm64/onnxruntime_binding.node +0 -0
  63. package/.output/server/node_modules/onnxruntime-node/bin/napi-v3/linux/x64/libonnxruntime.so.1.14.0 +0 -0
  64. package/.output/server/node_modules/onnxruntime-node/bin/napi-v3/linux/x64/onnxruntime_binding.node +0 -0
  65. package/.output/server/node_modules/onnxruntime-node/bin/napi-v3/win32/arm64/onnxruntime_binding.node +0 -0
  66. package/.output/server/node_modules/onnxruntime-node/bin/napi-v3/win32/x64/onnxruntime_binding.node +0 -0
  67. package/.output/server/node_modules/onnxruntime-node/dist/backend.js +75 -0
  68. package/.output/server/node_modules/onnxruntime-node/dist/binding.js +10 -0
  69. package/.output/server/node_modules/onnxruntime-node/dist/index.js +23 -0
  70. package/.output/server/node_modules/onnxruntime-node/package.json +58 -0
  71. package/.output/server/node_modules/onnxruntime-web/dist/ort-web.node.js +7 -0
  72. package/.output/server/node_modules/onnxruntime-web/package.json +84 -0
  73. package/.output/server/node_modules/semver/classes/semver.js +333 -0
  74. package/.output/server/node_modules/semver/functions/coerce.js +62 -0
  75. package/.output/server/node_modules/semver/functions/compare.js +7 -0
  76. package/.output/server/node_modules/semver/functions/gte.js +5 -0
  77. package/.output/server/node_modules/semver/functions/parse.js +18 -0
  78. package/.output/server/node_modules/semver/internal/constants.js +37 -0
  79. package/.output/server/node_modules/semver/internal/debug.js +11 -0
  80. package/.output/server/node_modules/semver/internal/identifiers.js +29 -0
  81. package/.output/server/node_modules/semver/internal/parse-options.js +17 -0
  82. package/.output/server/node_modules/semver/internal/re.js +223 -0
  83. package/.output/server/node_modules/semver/package.json +78 -0
  84. package/.output/server/node_modules/sharp/build/Release/sharp-linux-x64.node +0 -0
  85. package/.output/server/node_modules/sharp/lib/channel.js +174 -0
  86. package/.output/server/node_modules/sharp/lib/colour.js +184 -0
  87. package/.output/server/node_modules/sharp/lib/composite.js +210 -0
  88. package/.output/server/node_modules/sharp/lib/constructor.js +439 -0
  89. package/.output/server/node_modules/sharp/lib/index.js +16 -0
  90. package/.output/server/node_modules/sharp/lib/input.js +631 -0
  91. package/.output/server/node_modules/sharp/lib/is.js +155 -0
  92. package/.output/server/node_modules/sharp/lib/libvips.js +140 -0
  93. package/.output/server/node_modules/sharp/lib/operation.js +919 -0
  94. package/.output/server/node_modules/sharp/lib/output.js +1413 -0
  95. package/.output/server/node_modules/sharp/lib/platform.js +30 -0
  96. package/.output/server/node_modules/sharp/lib/resize.js +582 -0
  97. package/.output/server/node_modules/sharp/lib/sharp.js +38 -0
  98. package/.output/server/node_modules/sharp/lib/utility.js +287 -0
  99. package/.output/server/node_modules/sharp/package.json +204 -0
  100. package/.output/server/node_modules/sharp/vendor/8.14.5/linux-x64/THIRD-PARTY-NOTICES.md +43 -0
  101. package/.output/server/node_modules/sharp/vendor/8.14.5/linux-x64/lib/libvips-cpp.so.42 +0 -0
  102. package/.output/server/node_modules/sharp/vendor/8.14.5/linux-x64/platform.json +1 -0
  103. package/.output/server/node_modules/sharp/vendor/8.14.5/linux-x64/versions.json +31 -0
  104. package/.output/server/node_modules/simple-swizzle/index.js +29 -0
  105. package/.output/server/node_modules/simple-swizzle/package.json +36 -0
  106. package/.output/server/package.json +15 -1
  107. package/README.md +2 -0
  108. package/package.json +12 -19
  109. package/.output/public/_nuxt/5FxpIoe_.js +0 -1
  110. package/.output/public/_nuxt/builds/meta/21578a05-1b7e-4847-a8ff-7480800ea4a6.json +0 -1
@@ -0,0 +1,873 @@
1
+
2
+ /**
3
+ * @file Classes, functions, and utilities for generation.
4
+ *
5
+ * @todo Describe how to create a custom `GenerationConfig`.
6
+ *
7
+ * @module utils/generation
8
+ */
9
+ import { Tensor } from './tensor.js';
10
+ import {
11
+ Callable,
12
+ exists,
13
+ } from './core.js';
14
+ import {
15
+ max,
16
+ softmax,
17
+ log_softmax,
18
+ getTopItems,
19
+ } from './maths.js';
20
+
21
/**
 * An ordered collection of logits processors. Calling an instance runs every
 * registered processor, in insertion order, over each entry of a batch of logits.
 *
 * @extends Callable
 */
export class LogitsProcessorList extends Callable {
    /**
     * Creates an empty `LogitsProcessorList`.
     */
    constructor() {
        super();
        this.processors = [];
    }

    /**
     * Appends one logits processor to the end of the list.
     *
     * @param {LogitsProcessor} item The processor to append.
     */
    push(item) {
        this.processors.push(item);
    }

    /**
     * Appends several logits processors, keeping their relative order.
     *
     * @param {LogitsProcessor[]} items The processors to append.
     */
    extend(items) {
        this.processors.push(...items);
    }

    /**
     * Runs every processor over every entry of the batch. Processors are expected
     * to modify each entry in place; their return values are ignored and this
     * method returns nothing.
     *
     * @param {number[]} input_ids The token ids generated so far; passed unchanged to every processor.
     * @param {Iterable<Tensor>} batchedLogits The batch of logits; each entry is handed to the
     *   processors as their `logits` argument.
     */
    _call(input_ids, batchedLogits) {
        // Unlike the Python reference, vanilla JS has no vectorised operations,
        // so the processors are applied to one batch entry at a time.
        for (const rowLogits of batchedLogits) {
            const applyTo = (processor) => processor(input_ids, rowLogits);
            this.processors.forEach(applyTo);
        }
    }

    [Symbol.iterator]() {
        return this.processors.values();
    }
}
77
+
78
/**
 * Abstract base class for logits processors. Subclasses override `_call` to
 * inspect the ids generated so far and adjust the logits accordingly.
 *
 * @extends Callable
 */
export class LogitsProcessor extends Callable {
    /**
     * Processes the logits for one sequence. Must be overridden by subclasses.
     *
     * @abstract
     * @param {Array} input_ids The token ids generated so far.
     * @param {Tensor} logits The logits to process.
     * @throws {Error} Always, because the base class has no implementation.
     */
    _call(input_ids, logits) {
        const message = "`_call` should be implemented in a subclass";
        throw new Error(message);
    }
}
95
+
96
/**
 * A logits processor that forces specific tokens at specific generation steps.
 * The entry keyed by `n` applies when exactly `n` ids have been generated.
 *
 * @extends LogitsProcessor
 */
export class ForceTokensLogitsProcessor extends LogitsProcessor {
    /**
     * Constructs a new instance of `ForceTokensLogitsProcessor`.
     *
     * @param {Array} forced_decoder_ids `[step, token_id]` pairs; a nullish value forces nothing.
     */
    constructor(forced_decoder_ids) {
        super();
        const pairs = forced_decoder_ids ?? [];
        this.force_token_map = Object.fromEntries(pairs);
    }

    /**
     * If a token is registered for the current step (`input_ids.length`), make it
     * the only selectable token: its logit becomes 0 and every other logit -Infinity.
     *
     * @param {Array} input_ids The input ids.
     * @param {Tensor} logits The logits to process (modified in place).
     * @returns {Tensor} The same `logits` object.
     */
    _call(input_ids, logits) {
        const forcedToken = this.force_token_map[input_ids.length];
        if (!exists(forcedToken)) {
            // Nothing is forced at this step.
            return logits;
        }
        logits.data.fill(-Infinity);
        logits.data[forcedToken] = 0;
        return logits;
    }
}
128
+
129
/**
 * A LogitsProcessor that forces the BOS token when exactly one id has been
 * generated so far.
 * @extends LogitsProcessor
 */
export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
    /**
     * Create a ForcedBOSTokenLogitsProcessor.
     * @param {number} bos_token_id The id of the token to force.
     */
    constructor(bos_token_id) {
        super();
        this.bos_token_id = bos_token_id;
    }

    /**
     * When `input_ids` has length 1, leave only `bos_token_id` selectable
     * (logit 0, all others -Infinity); otherwise leave the logits untouched.
     * @param {Array} input_ids The input IDs.
     * @param {Object} logits The logits (modified in place).
     * @returns {Object} The same `logits` object.
     */
    _call(input_ids, logits) {
        if (input_ids.length !== 1) {
            return logits;
        }
        const scores = logits.data;
        scores.fill(-Infinity);
        scores[this.bos_token_id] = 0;
        return logits;
    }
}
157
+
158
/**
 * A logits processor intended to force the end-of-sequence token once the
 * sequence reaches `max_length`.
 *
 * NOTE: the forcing logic is not implemented yet, so this processor currently
 * leaves the logits unchanged. It only records its configuration.
 *
 * @extends LogitsProcessor
 */
export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
    /**
     * Create a ForcedEOSTokenLogitsProcessor.
     * @param {number} max_length Max length of the sequence.
     * @param {number|number[]} forced_eos_token_id The ID of the end-of-sequence token to be forced.
     */
    constructor(max_length, forced_eos_token_id) {
        super();
        this.max_length = max_length;
        this.forced_eos_token_id = forced_eos_token_id;
    }

    /**
     * Apply the processor to input_ids and logits.
     *
     * Currently a no-op. It returns `logits` so it honours the same contract as
     * every other processor (return the processed logits), instead of `undefined`.
     *
     * @param {number[]} input_ids The input ids.
     * @param {Tensor} logits The logits tensor.
     * @returns {Tensor} The same, unmodified `logits` object.
     */
    _call(input_ids, logits) {
        // TODO: force `forced_eos_token_id` near `max_length` (presumably mirroring
        // Hugging Face's Python ForcedEOSTokenLogitsProcessor — confirm how
        // `max_length` relates to `max_new_tokens` in this codebase first).
        return logits;
    }
}
186
+
187
/**
 * A LogitsProcessor that bans a set of tokens at exactly one generation step:
 * the step at which `input_ids` has length `begin_index`. It is used to keep
 * `begin_suppress_tokens` from being sampled at the beginning of generation.
 * @extends LogitsProcessor
 */
export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor {
    /**
     * Create a SuppressTokensAtBeginLogitsProcessor.
     * @param {number[]} begin_suppress_tokens The IDs of the tokens to suppress.
     * @param {number} begin_index The `input_ids` length at which the tokens are suppressed.
     */
    constructor(begin_suppress_tokens, begin_index) {
        super();
        this.begin_suppress_tokens = begin_suppress_tokens;
        this.begin_index = begin_index;
    }

    /**
     * Set the logits of `begin_suppress_tokens` to -Infinity when
     * `input_ids.length === begin_index`; otherwise leave them untouched.
     * @param {Array} input_ids The input IDs.
     * @param {Object} logits The logits (modified in place).
     * @returns {Object} The same `logits` object.
     */
    _call(input_ids, logits) {
        if (input_ids.length !== this.begin_index) {
            return logits;
        }
        const scores = logits.data;
        for (const suppressedId of this.begin_suppress_tokens) {
            scores[suppressedId] = -Infinity;
        }
        return logits;
    }
}
220
+
221
/**
 * A LogitsProcessor that applies Whisper's timestamp-token rules to the logits:
 * - `<|notimestamps|>` is never sampled;
 * - at the step just before `begin_index` the first timestamp token is forced;
 * - timestamps come in pairs (except directly before EOS);
 * - the first sampled timestamp can be capped by `max_initial_timestamp_index`;
 * - if the total probability of the timestamp tokens exceeds that of every single
 *   text token, only timestamp tokens remain selectable.
 * @extends LogitsProcessor
 */
export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
    /**
     * Constructs a new WhisperTimeStampLogitsProcessor.
     * @param {Object} generate_config The config object passed to the `generate()` method of a transformer model.
     * @param {number} generate_config.eos_token_id The ID of the end-of-sequence token.
     * @param {number} generate_config.no_timestamps_token_id The ID of the token used to indicate that a token should not have a timestamp.
     * @param {number[][]} [generate_config.forced_decoder_ids] An array of two-element arrays representing decoder IDs that are forced to appear in the output. A missing value is treated as an empty list.
     * @param {number} [generate_config.max_initial_timestamp_index] The maximum index at which an initial timestamp can appear. When omitted (null/undefined), no cap is applied.
     */
    constructor(generate_config) {
        super();
        this.eos_token_id = generate_config.eos_token_id;
        this.no_timestamps_token_id = generate_config.no_timestamps_token_id;
        // Ids from here upwards are treated as timestamp tokens.
        this.timestamp_begin = this.no_timestamps_token_id + 1;

        // `forced_decoder_ids` is optional: fall back to an empty list and only inspect
        // the last pair if there is one (previously `.slice(-1)[0][1]` threw here).
        const forced_decoder_ids = generate_config.forced_decoder_ids ?? [];
        this.begin_index = forced_decoder_ids.length + 2;
        const last_forced = forced_decoder_ids.at(-1);
        if (last_forced !== undefined && last_forced[1] === this.no_timestamps_token_id) {
            this.begin_index -= 1;
        }
        this.max_initial_timestamp_index = generate_config.max_initial_timestamp_index;
    }

    /**
     * Modify the logits to handle timestamp tokens.
     * @param {Array} input_ids The input sequence of tokens.
     * @param {Tensor} logits The logits output by the model (modified in place).
     * @returns {Tensor} The modified logits.
     */
    _call(input_ids, logits) {
        const logitsData = /** @type {Float32Array} */(logits.data);

        // suppress <|notimestamps|> which is handled by without_timestamps
        logitsData[this.no_timestamps_token_id] = -Infinity;

        if (input_ids.length === this.begin_index - 1) {
            logitsData.fill(-Infinity);
            logitsData[this.timestamp_begin] = 0;
            return logits;
        }

        // timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
        const seq = input_ids.slice(this.begin_index);
        const last_was_timestamp = seq.length >= 1 && seq[seq.length - 1] >= this.timestamp_begin;
        const penultimate_was_timestamp = seq.length < 2 || seq[seq.length - 2] >= this.timestamp_begin;

        if (last_was_timestamp) {
            if (penultimate_was_timestamp) { // has to be non-timestamp
                logitsData.subarray(this.timestamp_begin).fill(-Infinity);
            } else { // cannot be normal text tokens
                logitsData.subarray(0, this.eos_token_id).fill(-Infinity);
            }
        }

        // apply the `max_initial_timestamp` option. `!= null` also skips `undefined`:
        // an unset option would otherwise give `subarray(NaN)` (== `subarray(0)`),
        // masking every token.
        if (input_ids.length === this.begin_index && this.max_initial_timestamp_index != null) {
            const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index;
            logitsData.subarray(last_allowed + 1).fill(-Infinity);
        }

        // if sum of probability over timestamps is above any other token, sample timestamp
        const logprobs = log_softmax(logitsData);
        // Initial value 0 keeps `reduce` from throwing on an empty timestamp range.
        const timestamp_logprob = Math.log(
            logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b, 0)
        );
        const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0];

        if (timestamp_logprob > max_text_token_logprob) {
            logitsData.subarray(0, this.timestamp_begin).fill(-Infinity);
        }

        return logits;
    }
}
297
+
298
/**
 * A logits processor that disallows ngrams of a certain size to be repeated:
 * if the last `n - 1` generated tokens already occurred earlier followed by some
 * token `t`, then `t` is banned at the current step.
 *
 * @extends LogitsProcessor
 */
export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
    /**
     * Create a NoRepeatNGramLogitsProcessor.
     * @param {number} no_repeat_ngram_size The no-repeat-ngram size. All ngrams of this size can only occur once.
     */
    constructor(no_repeat_ngram_size) {
        super();
        this.no_repeat_ngram_size = no_repeat_ngram_size;
    }

    /**
     * Index every n-gram of `prevInputIds` by its leading `n - 1` tokens.
     * @param {number[]} prevInputIds List of previous input ids
     * @returns {Map<string, number[]>} Maps the JSON-encoded prefix of each n-gram to the
     *   tokens that followed that prefix, in order of occurrence (duplicates kept).
     */
    getNgrams(prevInputIds) {
        const size = this.no_repeat_ngram_size;
        const windowCount = prevInputIds.length + 1 - size;

        /** @type {Map<string, number[]>} */
        const continuations = new Map();
        for (let start = 0; start < windowCount; ++start) {
            const gram = [];
            for (let offset = 0; offset < size; ++offset) {
                gram.push(prevInputIds[start + offset]);
            }
            const prefixKey = JSON.stringify(gram.slice(0, -1));
            const nextToken = gram[gram.length - 1];
            const bucket = continuations.get(prefixKey);
            if (bucket === undefined) {
                continuations.set(prefixKey, [nextToken]);
            } else {
                bucket.push(nextToken);
            }
        }
        return continuations;
    }

    /**
     * Look up the tokens that have previously followed the current `n - 1`-token suffix.
     * @param {Map<string, number[]>} bannedNgrams Map produced by `getNgrams`
     * @param {number[]} prevInputIds List of previous input ids
     * @returns {number[]} The banned token ids (empty if the suffix was never seen).
     */
    getGeneratedNgrams(bannedNgrams, prevInputIds) {
        const end = prevInputIds.length;
        const suffix = prevInputIds.slice(end + 1 - this.no_repeat_ngram_size, end);
        return bannedNgrams.get(JSON.stringify(suffix)) ?? [];
    }

    /**
     * Calculate banned n-gram tokens
     * @param {number[]} prevInputIds List of previous input ids
     * @returns {number[]} Token ids that would complete an already-seen n-gram.
     */
    calcBannedNgramTokens(prevInputIds) {
        // Too few tokens so far to complete even one n-gram: nothing to ban.
        if (prevInputIds.length + 1 < this.no_repeat_ngram_size) {
            return [];
        }
        const ngramIndex = this.getNgrams(prevInputIds);
        return this.getGeneratedNgrams(ngramIndex, prevInputIds);
    }

    /**
     * Apply the no-repeat-ngram processor to the logits.
     * @param {Array} input_ids The input IDs.
     * @param {Object} logits The logits (modified in place).
     * @returns {Object} The same `logits` object.
     */
    _call(input_ids, logits) {
        for (const banned of this.calcBannedNgramTokens(input_ids)) {
            logits.data[banned] = -Infinity;
        }
        return logits;
    }
}
388
+
389
/**
 * A logits processor that penalises tokens which already appear in the output.
 *
 * @extends LogitsProcessor
 */
export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
    /**
     * Create a RepetitionPenaltyLogitsProcessor.
     * @param {number} penalty The penalty to apply for repeated tokens.
     */
    constructor(penalty) {
        super();
        this.penalty = penalty;
    }

    /**
     * Apply the repetition penalty to the logits. A negative logit is multiplied
     * by `penalty`; any other logit is divided by it.
     * @param {Array} input_ids The input IDs.
     * @param {Object} logits The logits (modified in place).
     * @returns {Object} The same `logits` object.
     */
    _call(input_ids, logits) {
        const scores = logits.data;
        // The penalty is applied once per occurrence, so a token that appears
        // several times in `input_ids` is penalised repeatedly.
        for (const tokenId of input_ids) {
            const current = scores[tokenId];
            scores[tokenId] = current < 0 ? current * this.penalty : current / this.penalty;
        }
        return logits;
    }
}
424
+
425
/**
 * A logits processor that forbids end-of-sequence tokens until `input_ids`
 * reaches `min_length`.
 *
 * @extends LogitsProcessor
 */
export class MinLengthLogitsProcessor extends LogitsProcessor {
    /**
     * Create a MinLengthLogitsProcessor.
     * @param {number} min_length The minimum length below which the score of `eos_token_id` is set to negative infinity.
     * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token.
     */
    constructor(min_length, eos_token_id) {
        super();
        this.min_length = min_length;
        // Normalised to an array so a single id and a list of ids are handled alike.
        const eosIds = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
        this.eos_token_id = eosIds;
    }

    /**
     * Set every end-of-sequence logit to -Infinity while `input_ids.length < min_length`.
     * @param {Array} input_ids The input IDs.
     * @param {Object} logits The logits (modified in place).
     * @returns {Object} The same `logits` object.
     */
    _call(input_ids, logits) {
        const tooShort = input_ids.length < this.min_length;
        if (tooShort) {
            const scores = logits.data;
            for (let i = 0; i < this.eos_token_id.length; ++i) {
                scores[this.eos_token_id[i]] = -Infinity;
            }
        }
        return logits;
    }
}
458
+
459
/**
 * A logits processor that forbids end-of-sequence tokens until at least
 * `min_new_tokens` ids beyond the first `prompt_length_to_skip` have been generated.
 *
 * @extends LogitsProcessor
 */
export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
    /**
     * Create a MinNewTokensLengthLogitsProcessor.
     * @param {number} prompt_length_to_skip Number of leading ids that do not count as new tokens.
     * @param {number} min_new_tokens The minimum *new* tokens length below which the score of `eos_token_id` is set to negative infinity.
     * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token.
     */
    constructor(prompt_length_to_skip, min_new_tokens, eos_token_id) {
        super();
        this.prompt_length_to_skip = prompt_length_to_skip;
        this.min_new_tokens = min_new_tokens;
        // Normalised to an array so a single id and a list of ids are handled alike.
        const eosIds = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
        this.eos_token_id = eosIds;
    }

    /**
     * Set every end-of-sequence logit to -Infinity while fewer than
     * `min_new_tokens` new ids exist.
     * @param {Array} input_ids The input IDs.
     * @param {Object} logits The logits (modified in place).
     * @returns {Object} The same `logits` object.
     */
    _call(input_ids, logits) {
        const generatedCount = input_ids.length - this.prompt_length_to_skip;
        const mustContinue = generatedCount < this.min_new_tokens;
        if (mustContinue) {
            const scores = logits.data;
            for (let i = 0; i < this.eos_token_id.length; ++i) {
                scores[this.eos_token_id[i]] = -Infinity;
            }
        }
        return logits;
    }
}
495
+
496
/**
 * A logits processor that prevents "bad word" token sequences from being generated.
 * For each bad word, its final token is banned whenever the generated ids already
 * end with the rest of the sequence. A single-token bad word is always banned.
 *
 * @extends LogitsProcessor
 */
export class NoBadWordsLogitsProcessor extends LogitsProcessor {
    /**
     * Create a `NoBadWordsLogitsProcessor`.
     * @param {number[][]} bad_words_ids List of list of token ids that are not allowed to be generated.
     * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
     */
    constructor(bad_words_ids, eos_token_id) {
        super();
        this.bad_words_ids = bad_words_ids;
        // Normalised to an array; stored for API compatibility (not read by `_call`).
        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
    }

    /**
     * Apply logit processor.
     * @param {Array} input_ids The input IDs.
     * @param {Object} logits The logits (modified in place).
     * @returns {Object} The same `logits` object.
     */
    _call(input_ids, logits) {
        for (const bad_word_ids of this.bad_words_ids) {
            if (bad_word_ids.length === 0) {
                // An empty sequence has no final token to ban.
                continue;
            }

            // Every token except the last must already be the tail of `input_ids`.
            const prefixLength = bad_word_ids.length - 1;
            if (prefixLength > input_ids.length) {
                // Not enough ids generated for the prefix to be present. Previously this
                // case (and any input not longer than the bad word) skipped the comparison
                // and banned the last token unconditionally.
                continue;
            }

            let matches = true;
            for (let i = 1; i <= prefixLength; ++i) {
                if (bad_word_ids.at(-i - 1) !== input_ids.at(-i)) {
                    matches = false;
                    break;
                }
            }
            if (matches) {
                logits.data[bad_word_ids.at(-1)] = -Infinity;
            }
        }

        return logits;
    }
}
538
+
539
+ /**
540
+ * @typedef {Object} GenerationConfigType The default configuration parameters.
541
+ * @property {number} [max_length=20] The maximum length the generated tokens can have. Corresponds to the length of the input prompt + `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set.
542
+ * @property {number} [max_new_tokens=null] The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
543
+ * @property {number} [min_length=0] The minimum length of the sequence to be generated. Corresponds to the length of the input prompt + `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.
544
+ * @property {number} [min_new_tokens=null] The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.
545
+ * @property {boolean|"never"} [early_stopping=false] Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
546
+ * - `true`, where the generation stops as soon as there are `num_beams` complete candidates;
547
+ * - `false`, where an heuristic is applied and the generation stops when is it very unlikely to find better candidates;
548
+ * - `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm).
549
+ * @property {number} [max_time=null] The maximum amount of time you allow the computation to run for in seconds. Generation will still finish the current pass after allocated time has been passed.
550
+ *
551
+ * @property {boolean} [do_sample=false] Whether or not to use sampling; use greedy decoding otherwise.
552
+ * @property {number} [num_beams=1] Number of beams for beam search. 1 means no beam search.
553
+ * @property {number} [num_beam_groups=1] Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
554
+ * @property {number} [penalty_alpha=null] The values balance the model confidence and the degeneration penalty in contrastive search decoding.
555
+ * @property {boolean} [use_cache=true] Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
556
+ *
557
+ * @property {number} [temperature=1.0] The value used to modulate the next token probabilities.
558
+ * @property {number} [top_k=50] The number of highest probability vocabulary tokens to keep for top-k-filtering.
559
+ * @property {number} [top_p=1.0] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.
560
+ * @property {number} [typical_p=1.0] Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional probability of predicting a random token next, given the partial text already generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that add up to `typical_p` or higher are kept for generation. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
561
+ * @property {number} [epsilon_cutoff=0.0] If set to float strictly between 0 and 1, only tokens with a conditional probability greater than `epsilon_cutoff` will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more details.
562
+ * @property {number} [eta_cutoff=0.0] Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly between 0 and 1, a token is only considered if it is greater than either `eta_cutoff` or `sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits)))`. The latter term is intuitively the expected next token probability, scaled by `sqrt(eta_cutoff)`. In the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more details.
563
+ * @property {number} [diversity_penalty=0.0] This value is subtracted from a beam's score if it generates a token same as any beam from other group at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled.
564
+ * @property {number} [repetition_penalty=1.0] The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
565
+ * @property {number} [encoder_repetition_penalty=1.0] The parameter for encoder_repetition_penalty. An exponential penalty on sequences that are not in the original input. 1.0 means no penalty.
566
+ * @property {number} [length_penalty=1.0] Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.
567
+ * @property {number} [no_repeat_ngram_size=0] If set to int > 0, all ngrams of that size can only occur once.
568
+ * @property {number[][]} [bad_words_ids=null] List of token ids that are not allowed to be generated. In order to get the token ids of the words that should not appear in the generated text, use `(await tokenizer(bad_words, {add_prefix_space: true, add_special_tokens: false})).input_ids`.
569
+ * @property {number[][]|number[][][]} [force_words_ids=null] List of token ids that must be generated. If given a `number[][]`, this is treated as a simple list of words that must be included, the opposite to `bad_words_ids`. If given `number[][][]`, this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one can allow different forms of each word.
570
+ * @property {boolean} [renormalize_logits=false] Whether to renormalize the logits after applying all the logits processors or warpers (including the custom ones). It's highly recommended to set this flag to `true` as the search algorithms suppose the score logits are normalized but some logit processors or warpers break the normalization.
571
+ * @property {Object[]} [constraints=null] Custom constraints that can be added to the generation to ensure that the output will contain the use of certain tokens as defined by `Constraint` objects, in the most sensible way possible.
572
+ *
573
+ * @property {number} [forced_bos_token_id=null] The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for multilingual models like mBART where the first generated token needs to be the target language token.
574
+ * @property {number|number[]} [forced_eos_token_id=null] The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a list to set multiple *end-of-sequence* tokens.
575
+ * @property {boolean} [remove_invalid_values=false] Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method from crashing. Note that using `remove_invalid_values` can slow down generation.
576
+ * @property {number[]} [exponential_decay_length_penalty=null] This tuple adds an exponentially increasing length penalty after a certain number of tokens have been generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay.
577
+ * @property {number[]} [suppress_tokens=null] A list of tokens that will be suppressed at generation. The `SuppressTokens` logit processor will set their log probs to `-inf` so that they are not sampled.
578
+ * @property {number[]} [begin_suppress_tokens=null] A list of tokens that will be suppressed at the beginning of the generation. The `SuppressBeginTokens` logit processor will set their log probs to `-inf` so that they are not sampled.
579
+ * @property {number[][]} [forced_decoder_ids=null] A list of pairs of integers which indicates a mapping from generation indices to token indices that will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token of index 123.
580
+ *
581
+ * @property {number} [num_return_sequences=1] The number of independently computed returned sequences for each element in the batch.
582
+ * @property {boolean} [output_attentions=false] Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details.
583
+ * @property {boolean} [output_hidden_states=false] Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details.
584
+ * @property {boolean} [output_scores=false] Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
585
+ * @property {boolean} [return_dict_in_generate=false] Whether or not to return a `ModelOutput` instead of a plain tuple.
586
+ *
587
+ * @property {number} [pad_token_id=null] The id of the *padding* token.
588
+ * @property {number} [bos_token_id=null] The id of the *beginning-of-sequence* token.
589
+ * @property {number|number[]} [eos_token_id=null] The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
590
+ *
591
+ * @property {number} [encoder_no_repeat_ngram_size=0] If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the `decoder_input_ids`.
592
+ * @property {number} [decoder_start_token_id=null] If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
593
+ *
594
+ * @property {Object} [generation_kwargs={}] Additional generation kwargs will be forwarded to the `generate` function of the model. Kwargs that are not present in `generate`'s signature will be used in the model forward pass.
595
+ */
596
+
597
/**
 * Holds every option that controls a generation task.
 *
 * Each option takes the value supplied in `kwargs`, or its default when the
 * key is missing, `undefined`, or `null`.
 * @type {new (kwargs?: GenerationConfigType) => GenerationConfigType}
 */
export const GenerationConfig = /** @type {any} */ (class {

    /**
     * Build a configuration from user-supplied options.
     * @param {GenerationConfigType} [kwargs={}] Options overriding the defaults.
     */
    constructor(kwargs = {}) {
        // Uses `??` rather than destructuring defaults so that explicit `null`
        // values also fall back to the default.
        const option = (name, fallback) => kwargs[name] ?? fallback;

        // Length of the output
        this.max_length = option('max_length', 20);
        this.max_new_tokens = option('max_new_tokens', null);
        this.min_length = option('min_length', 0);
        this.min_new_tokens = option('min_new_tokens', null);
        this.early_stopping = option('early_stopping', false);
        this.max_time = option('max_time', null);

        // Generation strategy
        this.do_sample = option('do_sample', false);
        this.num_beams = option('num_beams', 1);
        this.num_beam_groups = option('num_beam_groups', 1);
        this.penalty_alpha = option('penalty_alpha', null);
        this.use_cache = option('use_cache', true);

        // Manipulation of the model output logits
        this.temperature = option('temperature', 1.0);
        this.top_k = option('top_k', 50);
        this.top_p = option('top_p', 1.0);
        this.typical_p = option('typical_p', 1.0);
        this.epsilon_cutoff = option('epsilon_cutoff', 0.0);
        this.eta_cutoff = option('eta_cutoff', 0.0);
        this.diversity_penalty = option('diversity_penalty', 0.0);
        this.repetition_penalty = option('repetition_penalty', 1.0);
        this.encoder_repetition_penalty = option('encoder_repetition_penalty', 1.0);
        this.length_penalty = option('length_penalty', 1.0);
        this.no_repeat_ngram_size = option('no_repeat_ngram_size', 0);
        this.bad_words_ids = option('bad_words_ids', null);
        this.force_words_ids = option('force_words_ids', null);
        this.renormalize_logits = option('renormalize_logits', false);
        this.constraints = option('constraints', null);
        this.forced_bos_token_id = option('forced_bos_token_id', null);
        this.forced_eos_token_id = option('forced_eos_token_id', null);
        this.remove_invalid_values = option('remove_invalid_values', false);
        this.exponential_decay_length_penalty = option('exponential_decay_length_penalty', null);
        this.suppress_tokens = option('suppress_tokens', null);
        this.begin_suppress_tokens = option('begin_suppress_tokens', null);
        this.forced_decoder_ids = option('forced_decoder_ids', null);

        // Output variables of `generate`
        this.num_return_sequences = option('num_return_sequences', 1);
        this.output_attentions = option('output_attentions', false);
        this.output_hidden_states = option('output_hidden_states', false);
        this.output_scores = option('output_scores', false);
        this.return_dict_in_generate = option('return_dict_in_generate', false);

        // Special tokens usable at generation time
        this.pad_token_id = option('pad_token_id', null);
        this.bos_token_id = option('bos_token_id', null);
        this.eos_token_id = option('eos_token_id', null);

        // Options exclusive to encoder-decoder models
        this.encoder_no_repeat_ngram_size = option('encoder_no_repeat_ngram_size', 0);
        this.decoder_start_token_id = option('decoder_start_token_id', null);

        // Wild card. The `{}` fallback is a new literal on every call, so
        // instances never share the same default object.
        this.generation_kwargs = option('generation_kwargs', {});
    }
});
667
+
668
/**
 * Base class for the token-sampling strategies used during text generation.
 * Subclasses provide `sample`; `_call` simply forwards to it (presumably
 * `_call` is what the `Callable` base invokes when an instance is called —
 * TODO confirm against `Callable`).
 */
export class Sampler extends Callable {
    /**
     * @param {GenerationConfigType} generation_config The generation config.
     */
    constructor(generation_config) {
        super();
        this.generation_config = generation_config;
    }

    /**
     * Run the sampler on `logits` by delegating to `sample`.
     * @param {Tensor} logits Logits of dims [batch, sequence_length, vocab_size].
     * @param {number} index Position to sample from ([batch, index, vocab_size]); -1 means the last row.
     * @returns {Array} Whatever the subclass's `sample` returns.
     */
    _call(logits, index = -1) {
        return this.sample(logits, index);
    }

    /**
     * Abstract sampling hook; subclasses must override it.
     * @param {Tensor} logits
     * @param {number} index
     * @throws {Error} Always, when not overridden.
     */
    sample(logits, index) {
        throw new Error("sample should be implemented in subclasses.");
    }

    /**
     * Copy one vocabulary-sized row out of `logits` and scale it by the
     * temperature (only when the temperature is positive).
     * @param {Tensor} logits
     * @param {number} index Row to read; -1 reads the final `vocab_size` values.
     * @returns {Float32Array} A fresh array; `logits.data` is not modified.
     */
    getLogits(logits, index) {
        const vocabSize = logits.dims.at(-1);
        const data = /** @type {Float32Array} */ (logits.data);

        // -1 is special-cased because slice(-V, 0) would be empty.
        const row = index === -1
            ? data.slice(-vocabSize)
            : data.slice(index * vocabSize, index * vocabSize + vocabSize);

        return this.generation_config.temperature > 0
            ? row.map(value => value / this.generation_config.temperature)
            : row;
    }

    /**
     * Pick an index at random, weighted by `probabilities` (which need not sum to 1).
     * @param {Array} probabilities Non-negative weights.
     * @returns {number} The selected index; 0 if round-off leaves no index selected.
     */
    randomSelect(probabilities) {
        const total = probabilities.reduce((sum, weight) => sum + weight, 0);
        let remaining = Math.random() * total;

        for (let idx = 0; idx < probabilities.length; ++idx) {
            remaining -= probabilities[idx];
            if (remaining <= 0) {
                return idx;
            }
        }
        // Floating-point round-off can leave `remaining` slightly positive;
        // fall back to the first (assumed most probable) entry.
        return 0;
    }

    /**
     * Choose the sampler implied by the generation config.
     *   - *greedy decoding*: `num_beams=1` and `do_sample=False`
     *   - *contrastive search*: `penalty_alpha>0` and `top_k>1`
     *   - *multinomial sampling*: `num_beams=1` and `do_sample=True`
     *   - *beam-search decoding*: `num_beams>1` and `do_sample=False`
     *   - *beam-search multinomial sampling*: `num_beams>1` and `do_sample=True`
     *   - *diverse beam-search decoding*: `num_beams>1` and `num_beam_groups>1`
     *   - *constrained beam-search decoding*: `constraints!=None` or `force_words_ids!=None`
     * Only multinomial, beam-search and greedy samplers are returned here; the
     * beam search itself is implemented directly in the generation function.
     * @param {GenerationConfigType} generation_config An object containing options for the sampler.
     * @returns {Sampler} A Sampler object.
     * @throws {Error} If greedy decoding is selected with `num_return_sequences > 1`.
     */
    static getSampler(generation_config) {
        if (generation_config.do_sample) {
            return new MultinomialSampler(generation_config);
        }
        if (generation_config.num_beams > 1) {
            return new BeamSearchSampler(generation_config);
        }
        if (generation_config.num_return_sequences > 1) {
            throw new Error(`num_return_sequences has to be 1 when doing greedy search, but is ${generation_config.num_return_sequences}.`);
        }
        return new GreedySampler(generation_config);
    }
}
776
+
777
/**
 * Greedy sampler: always selects the highest-scoring token.
 * @extends Sampler
 */
class GreedySampler extends Sampler {
    /**
     * Select the token with the largest (temperature-scaled) logit.
     * @param {Tensor} logits
     * @param {number} [index=-1] Row of the logits to read; -1 reads the last row.
     * @returns {Array} A single `[tokenId, score]` pair whose score is always 0.
     */
    sample(logits, index = -1) {
        // log_softmax would not change which entry is largest, so raw logits suffice.
        const scores = this.getLogits(logits, index);
        // Element 1 of `max`'s result is used as the token id (presumably the argmax index).
        const bestTokenId = max(scores)[1];
        // Greedy choice is treated as certain (p = 1 => log(p) = 0), so the score carries no information.
        return [[bestTokenId, 0]];
    }
}
800
+
801
/**
 * Multinomial sampler: draws tokens at random from the top-k softmax distribution.
 * @extends Sampler
 */
class MultinomialSampler extends Sampler {

    /**
     * Draw `num_beams` tokens independently from the softmax over the top-k logits.
     * @param {Tensor} logits
     * @param {number} [index=-1] Row of the logits to read; -1 reads the last row.
     * @returns {Array} `[tokenId, logProbability]` pairs, one per draw.
     */
    sample(logits, index = -1) {
        const vocabSize = logits.dims.at(-1);
        // A non-positive top_k disables truncation: every vocabulary entry is a candidate.
        const topK = this.generation_config.top_k > 0
            ? Math.min(this.generation_config.top_k, vocabSize)
            : vocabSize;

        const rowLogits = this.getLogits(logits, index);
        // Each candidate is used as [tokenId, logit] below — TODO confirm against getTopItems.
        const candidates = getTopItems(rowLogits, topK);
        const probabilities = softmax(candidates.map(pair => pair[1]));

        const drawOne = () => {
            const choice = this.randomSelect(probabilities);
            return [candidates[choice][0], Math.log(probabilities[choice])];
        };
        return Array.from({ length: this.generation_config.num_beams }, drawOne);
    }
}
837
+
838
+
839
/**
 * Beam-search sampler: proposes the highest-probability next tokens as beam candidates.
 * @extends Sampler
 */
class BeamSearchSampler extends Sampler {

    /**
     * Return up to `num_beams` of the top-k tokens, each paired with its
     * log-probability under the softmax over those top-k logits.
     *
     * If fewer than `num_beams` candidates survive top-k truncation (for example
     * `top_k < num_beams`, or a vocabulary smaller than `num_beams`), only the
     * available candidates are returned. Reading past the end of the list
     * would throw a TypeError.
     * @param {Tensor} logits
     * @param {number} [index=-1] Row of the logits to read; -1 reads the last row.
     * @returns {Array} `[tokenId, logProbability]` pairs, at most `num_beams` long.
     */
    sample(logits, index = -1) {
        let k = logits.dims.at(-1); // defaults to vocab size
        if (this.generation_config.top_k > 0) {
            k = Math.min(this.generation_config.top_k, k);
        }

        // Logits of the requested position, temperature-scaled by getLogits.
        const logs = this.getLogits(logits, index);

        // Top-k candidates; each entry is used as [tokenId, logit] below.
        const topLogits = getTopItems(logs, k);

        // Softmax over the surviving candidates only.
        const probabilities = softmax(topLogits.map(x => x[1]));

        // Never request more candidates than exist.
        const numCandidates = Math.min(this.generation_config.num_beams, topLogits.length);
        return Array.from({ length: numCandidates }, (_, i) => [
            topLogits[i][0], // token id
            Math.log(probabilities[i]), // score
        ]);
    }
}