@weavelogic/knowledge-graph-agent 0.7.4 → 0.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/_virtual/__vite-browser-external.js +2 -2
- package/dist/_virtual/__vite-browser-external.js.map +1 -1
- package/dist/_virtual/browser.js +2 -3
- package/dist/_virtual/browser.js.map +1 -1
- package/dist/_virtual/index10.js +2 -4
- package/dist/_virtual/index10.js.map +1 -1
- package/dist/_virtual/index11.js +2 -2
- package/dist/cli/commands/hive-mind/add-frontmatter.js +2 -2
- package/dist/cli/commands/hive-mind/add-frontmatter.js.map +1 -1
- package/dist/cli/commands/hive-mind/analyze-links.js +2 -2
- package/dist/cli/commands/hive-mind/analyze-links.js.map +1 -1
- package/dist/cli/commands/hive-mind/find-connections.js +2 -2
- package/dist/cli/commands/hive-mind/find-connections.js.map +1 -1
- package/dist/cli/commands/hive-mind/validate-names.js +2 -2
- package/dist/cli/commands/hive-mind/validate-names.js.map +1 -1
- package/dist/graphql/server.js +2 -2
- package/dist/graphql/server.js.map +1 -1
- package/dist/mcp-server/tools/audit/index.d.ts +4 -0
- package/dist/mcp-server/tools/audit/index.d.ts.map +1 -1
- package/dist/node_modules/@typescript-eslint/project-service/dist/index.js +1 -1
- package/dist/node_modules/debug/src/browser.js +1 -1
- package/dist/node_modules/fdir/dist/index.js +14 -14
- package/dist/node_modules/fdir/dist/index.js.map +1 -1
- package/dist/node_modules/tinyglobby/dist/index.js +14 -14
- package/dist/node_modules/tinyglobby/dist/index.js.map +1 -1
- package/dist/node_modules/ts-api-utils/lib/index.js +1 -1
- package/dist/node_modules/typescript/lib/typescript.js +24 -24
- package/dist/node_modules/typescript/lib/typescript.js.map +1 -1
- package/dist/vector/services/embedding-service.js +1 -7
- package/dist/vector/services/embedding-service.js.map +1 -1
- package/package.json +2 -1
- package/dist/_virtual/browser2.js +0 -5
- package/dist/_virtual/browser2.js.map +0 -1
- package/dist/_virtual/index12.js +0 -5
- package/dist/_virtual/index12.js.map +0 -1
- package/dist/_virtual/ort-web.min.js +0 -8
- package/dist/_virtual/ort-web.min.js.map +0 -1
- package/dist/_virtual/ort-web.min2.js +0 -5
- package/dist/_virtual/ort-web.min2.js.map +0 -1
- package/dist/node_modules/@huggingface/jinja/dist/index.js +0 -118
- package/dist/node_modules/@huggingface/jinja/dist/index.js.map +0 -1
- package/dist/node_modules/@xenova/transformers/src/backends/onnx.js +0 -24
- package/dist/node_modules/@xenova/transformers/src/backends/onnx.js.map +0 -1
- package/dist/node_modules/@xenova/transformers/src/configs.js +0 -52
- package/dist/node_modules/@xenova/transformers/src/configs.js.map +0 -1
- package/dist/node_modules/@xenova/transformers/src/env.js +0 -35
- package/dist/node_modules/@xenova/transformers/src/env.js.map +0 -1
- package/dist/node_modules/@xenova/transformers/src/models.js +0 -3852
- package/dist/node_modules/@xenova/transformers/src/models.js.map +0 -1
- package/dist/node_modules/@xenova/transformers/src/tokenizers.js +0 -144
- package/dist/node_modules/@xenova/transformers/src/tokenizers.js.map +0 -1
- package/dist/node_modules/@xenova/transformers/src/utils/core.js +0 -52
- package/dist/node_modules/@xenova/transformers/src/utils/core.js.map +0 -1
- package/dist/node_modules/@xenova/transformers/src/utils/generation.js +0 -623
- package/dist/node_modules/@xenova/transformers/src/utils/generation.js.map +0 -1
- package/dist/node_modules/@xenova/transformers/src/utils/hub.js +0 -395
- package/dist/node_modules/@xenova/transformers/src/utils/hub.js.map +0 -1
- package/dist/node_modules/@xenova/transformers/src/utils/image.js +0 -12
- package/dist/node_modules/@xenova/transformers/src/utils/image.js.map +0 -1
- package/dist/node_modules/@xenova/transformers/src/utils/maths.js +0 -89
- package/dist/node_modules/@xenova/transformers/src/utils/maths.js.map +0 -1
- package/dist/node_modules/@xenova/transformers/src/utils/tensor.js +0 -750
- package/dist/node_modules/@xenova/transformers/src/utils/tensor.js.map +0 -1
- package/dist/node_modules/onnxruntime-common/dist/lib/backend-impl.js +0 -67
- package/dist/node_modules/onnxruntime-common/dist/lib/backend-impl.js.map +0 -1
- package/dist/node_modules/onnxruntime-common/dist/lib/env-impl.js +0 -24
- package/dist/node_modules/onnxruntime-common/dist/lib/env-impl.js.map +0 -1
- package/dist/node_modules/onnxruntime-common/dist/lib/env.js +0 -6
- package/dist/node_modules/onnxruntime-common/dist/lib/env.js.map +0 -1
- package/dist/node_modules/onnxruntime-common/dist/lib/index.js +0 -11
- package/dist/node_modules/onnxruntime-common/dist/lib/index.js.map +0 -1
- package/dist/node_modules/onnxruntime-common/dist/lib/inference-session-impl.js +0 -162
- package/dist/node_modules/onnxruntime-common/dist/lib/inference-session-impl.js.map +0 -1
- package/dist/node_modules/onnxruntime-common/dist/lib/inference-session.js +0 -6
- package/dist/node_modules/onnxruntime-common/dist/lib/inference-session.js.map +0 -1
- package/dist/node_modules/onnxruntime-common/dist/lib/tensor-impl.js +0 -393
- package/dist/node_modules/onnxruntime-common/dist/lib/tensor-impl.js.map +0 -1
- package/dist/node_modules/onnxruntime-common/dist/lib/tensor.js +0 -6
- package/dist/node_modules/onnxruntime-common/dist/lib/tensor.js.map +0 -1
- package/dist/node_modules/onnxruntime-web/dist/ort-web.min.js +0 -12919
- package/dist/node_modules/onnxruntime-web/dist/ort-web.min.js.map +0 -1
- package/dist/node_modules/ws/browser.js +0 -16
- package/dist/node_modules/ws/browser.js.map +0 -1
|
@@ -1,52 +0,0 @@
|
|
|
1
|
-
function dispatchCallback(progress_callback, data) {
|
|
2
|
-
if (progress_callback) progress_callback(data);
|
|
3
|
-
}
|
|
4
|
-
function reverseDictionary(data) {
|
|
5
|
-
return Object.fromEntries(Object.entries(data).map(([key, value]) => [value, key]));
|
|
6
|
-
}
|
|
7
|
-
const Callable = (
|
|
8
|
-
/** @type {any} */
|
|
9
|
-
class {
|
|
10
|
-
/**
|
|
11
|
-
* Creates a new instance of the Callable class.
|
|
12
|
-
*/
|
|
13
|
-
constructor() {
|
|
14
|
-
let closure = function(...args) {
|
|
15
|
-
return closure._call(...args);
|
|
16
|
-
};
|
|
17
|
-
return Object.setPrototypeOf(closure, new.target.prototype);
|
|
18
|
-
}
|
|
19
|
-
/**
|
|
20
|
-
* This method should be implemented in subclasses to provide the
|
|
21
|
-
* functionality of the callable object.
|
|
22
|
-
*
|
|
23
|
-
* @param {any[]} args
|
|
24
|
-
* @throws {Error} If the subclass does not implement the `_call` method.
|
|
25
|
-
*/
|
|
26
|
-
_call(...args) {
|
|
27
|
-
throw Error("Must implement _call method in subclass");
|
|
28
|
-
}
|
|
29
|
-
}
|
|
30
|
-
);
|
|
31
|
-
function isTypedArray(val) {
|
|
32
|
-
return val?.prototype?.__proto__?.constructor?.name === "TypedArray";
|
|
33
|
-
}
|
|
34
|
-
function isIntegralNumber(x) {
|
|
35
|
-
return Number.isInteger(x) || typeof x === "bigint";
|
|
36
|
-
}
|
|
37
|
-
function exists(x) {
|
|
38
|
-
return x !== void 0 && x !== null;
|
|
39
|
-
}
|
|
40
|
-
function mergeArrays(...arrs) {
|
|
41
|
-
return Array.prototype.concat.apply([], arrs);
|
|
42
|
-
}
|
|
43
|
-
export {
|
|
44
|
-
Callable,
|
|
45
|
-
dispatchCallback,
|
|
46
|
-
exists,
|
|
47
|
-
isIntegralNumber,
|
|
48
|
-
isTypedArray,
|
|
49
|
-
mergeArrays,
|
|
50
|
-
reverseDictionary
|
|
51
|
-
};
|
|
52
|
-
//# sourceMappingURL=core.js.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"core.js","sources":["../../../../../../node_modules/@xenova/transformers/src/utils/core.js"],"sourcesContent":["\n/**\n * @file Core utility functions/classes for Transformers.js.\n * \n * These are only used internally, meaning an end-user shouldn't\n * need to access anything here.\n * \n * @module utils/core\n */\n\n/**\n * Helper function to dispatch progress callbacks.\n *\n * @param {Function} progress_callback The progress callback function to dispatch.\n * @param {any} data The data to pass to the progress callback function.\n * @returns {void}\n * @private\n */\nexport function dispatchCallback(progress_callback, data) {\n if (progress_callback) progress_callback(data);\n}\n\n/**\n * Reverses the keys and values of an object.\n *\n * @param {Object} data The object to reverse.\n * @returns {Object} The reversed object.\n * @see https://ultimatecourses.com/blog/reverse-object-keys-and-values-in-javascript\n */\nexport function reverseDictionary(data) {\n // https://ultimatecourses.com/blog/reverse-object-keys-and-values-in-javascript\n return Object.fromEntries(Object.entries(data).map(([key, value]) => [value, key]));\n}\n\n/**\n * Escapes regular expression special characters from a string by replacing them with their escaped counterparts.\n *\n * @param {string} string The string to escape.\n * @returns {string} The escaped string.\n */\nexport function escapeRegExp(string) {\n return string.replace(/[.*+?^${}()|[\\]\\\\]/g, '\\\\$&'); // $& means the whole matched string\n}\n\n/**\n * A base class for creating callable objects.\n * \n * @type {new () => {(...args: any[]): any, _call(...args: any[]): any}}\n */\nexport const Callable = /** @type {any} */ (class {\n /**\n * Creates a new instance of the Callable class.\n */\n constructor() {\n /**\n * Creates a closure that delegates to a private method '_call' with the given arguments.\n * @type {any}\n * @param {...any} args Zero or more arguments to pass to the '_call' method.\n * 
@returns {*} The result of calling the '_call' method.\n */\n let closure = function (...args) {\n return closure._call(...args)\n }\n return Object.setPrototypeOf(closure, new.target.prototype)\n }\n\n /**\n * This method should be implemented in subclasses to provide the\n * functionality of the callable object.\n *\n * @param {any[]} args\n * @throws {Error} If the subclass does not implement the `_call` method.\n */\n _call(...args) {\n throw Error('Must implement _call method in subclass')\n }\n});\n\n/**\n * Check if a value is a typed array.\n * @param {*} val The value to check.\n * @returns {boolean} True if the value is a `TypedArray`, false otherwise.\n * \n * Adapted from https://stackoverflow.com/a/71091338/13989043\n */\nexport function isTypedArray(val) {\n return val?.prototype?.__proto__?.constructor?.name === 'TypedArray';\n}\n\n\n/**\n * Check if a value is an integer.\n * @param {*} x The value to check.\n * @returns {boolean} True if the value is a string, false otherwise.\n */\nexport function isIntegralNumber(x) {\n return Number.isInteger(x) || typeof x === 'bigint'\n}\n\n/**\n * Check if a value is exists.\n * @param {*} x The value to check.\n * @returns {boolean} True if the value exists, false otherwise.\n */\nexport function exists(x) {\n return x !== undefined && x !== null;\n}\n\n/**\n * Calculates the dimensions of a nested array.\n *\n * @param {any[]} arr The nested array to calculate dimensions for.\n * @returns {number[]} An array containing the dimensions of the input array.\n */\nexport function calculateDimensions(arr) {\n const dimensions = [];\n let current = arr;\n while (Array.isArray(current)) {\n dimensions.push(current.length);\n current = current[0];\n }\n return dimensions;\n}\n\n/**\n * Replicate python's .pop() method for objects.\n * @param {Object} obj The object to pop from.\n * @param {string} key The key to pop.\n * @param {*} defaultValue The default value to return if the key does not exist.\n * @returns {*} 
The value of the popped key.\n * @throws {Error} If the key does not exist and no default value is provided.\n */\nexport function pop(obj, key, defaultValue = undefined) {\n const value = obj[key];\n if (value !== undefined) {\n delete obj[key];\n return value;\n }\n if (defaultValue === undefined) {\n throw Error(`Key ${key} does not exist in object.`)\n }\n return defaultValue;\n}\n\n/**\n * Efficiently merge arrays, creating a new copy.\n * Adapted from https://stackoverflow.com/a/6768642/13989043\n * @param {Array[]} arrs Arrays to merge.\n * @returns {Array} The merged array.\n */\nexport function mergeArrays(...arrs) {\n return Array.prototype.concat.apply([], arrs);\n}\n\n/**\n * Compute the Cartesian product of given arrays\n * @param {...Array} a Arrays to compute the product\n * @returns {Array} Returns the computed Cartesian product as an array\n * @private\n */\nexport function product(...a) {\n // Cartesian product of items\n // Adapted from https://stackoverflow.com/a/43053803\n return a.reduce((a, b) => a.flatMap(d => b.map(e => [d, e])));\n}\n\n/**\n * Calculates the index offset for a given index and window size.\n * @param {number} i The index.\n * @param {number} w The window size.\n * @returns {number} The index offset.\n */\nexport function calculateReflectOffset(i, w) {\n return Math.abs((i + w) % (2 * w) - 
w);\n}\n"],"names":[],"mappings":"AAkBO,SAAS,iBAAiB,mBAAmB,MAAM;AACtD,MAAI,kBAAmB,mBAAkB,IAAI;AACjD;AASO,SAAS,kBAAkB,MAAM;AAEpC,SAAO,OAAO,YAAY,OAAO,QAAQ,IAAI,EAAE,IAAI,CAAC,CAAC,KAAK,KAAK,MAAM,CAAC,OAAO,GAAG,CAAC,CAAC;AACtF;AAiBY,MAAC;AAAA;AAAA,EAA+B,MAAM;AAAA;AAAA;AAAA;AAAA,IAI9C,cAAc;AAOV,UAAI,UAAU,YAAa,MAAM;AAC7B,eAAO,QAAQ,MAAM,GAAG,IAAI;AAAA,MAChC;AACA,aAAO,OAAO,eAAe,SAAS,WAAW,SAAS;AAAA,IAC9D;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA,IASA,SAAS,MAAM;AACX,YAAM,MAAM,yCAAyC;AAAA,IACzD;AAAA,EACJ;AAAA;AASO,SAAS,aAAa,KAAK;AAC9B,SAAO,KAAK,WAAW,WAAW,aAAa,SAAS;AAC5D;AAQO,SAAS,iBAAiB,GAAG;AAChC,SAAO,OAAO,UAAU,CAAC,KAAK,OAAO,MAAM;AAC/C;AAOO,SAAS,OAAO,GAAG;AACtB,SAAO,MAAM,UAAa,MAAM;AACpC;AA4CO,SAAS,eAAe,MAAM;AACjC,SAAO,MAAM,UAAU,OAAO,MAAM,CAAA,GAAI,IAAI;AAChD;","x_google_ignoreList":[0]}
|
|
@@ -1,623 +0,0 @@
|
|
|
1
|
-
import "./tensor.js";
|
|
2
|
-
import { Callable, exists } from "./core.js";
|
|
3
|
-
import { log_softmax, max, getTopItems, softmax } from "./maths.js";
|
|
4
|
-
class LogitsProcessorList extends Callable {
|
|
5
|
-
/**
|
|
6
|
-
* Constructs a new instance of `LogitsProcessorList`.
|
|
7
|
-
*/
|
|
8
|
-
constructor() {
|
|
9
|
-
super();
|
|
10
|
-
this.processors = [];
|
|
11
|
-
}
|
|
12
|
-
/**
|
|
13
|
-
* Adds a new logits processor to the list.
|
|
14
|
-
*
|
|
15
|
-
* @param {LogitsProcessor} item The logits processor function to add.
|
|
16
|
-
*/
|
|
17
|
-
push(item) {
|
|
18
|
-
this.processors.push(item);
|
|
19
|
-
}
|
|
20
|
-
/**
|
|
21
|
-
* Adds multiple logits processors to the list.
|
|
22
|
-
*
|
|
23
|
-
* @param {LogitsProcessor[]} items The logits processor functions to add.
|
|
24
|
-
*/
|
|
25
|
-
extend(items) {
|
|
26
|
-
this.processors.push(...items);
|
|
27
|
-
}
|
|
28
|
-
/**
|
|
29
|
-
* Applies all logits processors in the list to a batch of logits, modifying them in-place.
|
|
30
|
-
*
|
|
31
|
-
* @param {number[]} input_ids The input IDs for the language model.
|
|
32
|
-
* @param {number[][]} batchedLogits A 2D array of logits, where each row corresponds to a single
|
|
33
|
-
* input sequence in the batch.
|
|
34
|
-
*/
|
|
35
|
-
_call(input_ids, batchedLogits) {
|
|
36
|
-
for (let logits of batchedLogits) {
|
|
37
|
-
this.processors.forEach(
|
|
38
|
-
(func) => func(input_ids, logits)
|
|
39
|
-
);
|
|
40
|
-
}
|
|
41
|
-
}
|
|
42
|
-
[Symbol.iterator]() {
|
|
43
|
-
return this.processors.values();
|
|
44
|
-
}
|
|
45
|
-
}
|
|
46
|
-
class LogitsProcessor extends Callable {
|
|
47
|
-
/**
|
|
48
|
-
* Apply the processor to the input logits.
|
|
49
|
-
*
|
|
50
|
-
* @abstract
|
|
51
|
-
* @param {Array} input_ids The input ids.
|
|
52
|
-
* @param {Tensor} logits The logits to process.
|
|
53
|
-
* @throws {Error} Throws an error if `_call` is not implemented in the subclass.
|
|
54
|
-
*/
|
|
55
|
-
_call(input_ids, logits) {
|
|
56
|
-
throw Error("`_call` should be implemented in a subclass");
|
|
57
|
-
}
|
|
58
|
-
}
|
|
59
|
-
class ForceTokensLogitsProcessor extends LogitsProcessor {
|
|
60
|
-
/**
|
|
61
|
-
* Constructs a new instance of `ForceTokensLogitsProcessor`.
|
|
62
|
-
*
|
|
63
|
-
* @param {Array} forced_decoder_ids The ids of tokens that should be forced.
|
|
64
|
-
*/
|
|
65
|
-
constructor(forced_decoder_ids) {
|
|
66
|
-
super();
|
|
67
|
-
this.force_token_map = Object.fromEntries(forced_decoder_ids ?? []);
|
|
68
|
-
}
|
|
69
|
-
/**
|
|
70
|
-
* Apply the processor to the input logits.
|
|
71
|
-
*
|
|
72
|
-
* @param {Array} input_ids The input ids.
|
|
73
|
-
* @param {Tensor} logits The logits to process.
|
|
74
|
-
* @returns {Tensor} The processed logits.
|
|
75
|
-
*/
|
|
76
|
-
_call(input_ids, logits) {
|
|
77
|
-
let map = this.force_token_map[input_ids.length];
|
|
78
|
-
if (exists(map)) {
|
|
79
|
-
logits.data.fill(-Infinity);
|
|
80
|
-
logits.data[map] = 0;
|
|
81
|
-
}
|
|
82
|
-
return logits;
|
|
83
|
-
}
|
|
84
|
-
}
|
|
85
|
-
class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
|
|
86
|
-
/**
|
|
87
|
-
* Create a ForcedBOSTokenLogitsProcessor.
|
|
88
|
-
* @param {number} bos_token_id The ID of the beginning-of-sequence token to be forced.
|
|
89
|
-
*/
|
|
90
|
-
constructor(bos_token_id) {
|
|
91
|
-
super();
|
|
92
|
-
this.bos_token_id = bos_token_id;
|
|
93
|
-
}
|
|
94
|
-
/**
|
|
95
|
-
* Apply the BOS token forcing to the logits.
|
|
96
|
-
* @param {Array} input_ids The input IDs.
|
|
97
|
-
* @param {Object} logits The logits.
|
|
98
|
-
* @returns {Object} The logits with BOS token forcing.
|
|
99
|
-
*/
|
|
100
|
-
_call(input_ids, logits) {
|
|
101
|
-
if (input_ids.length === 1) {
|
|
102
|
-
logits.data.fill(-Infinity);
|
|
103
|
-
logits.data[this.bos_token_id] = 0;
|
|
104
|
-
}
|
|
105
|
-
return logits;
|
|
106
|
-
}
|
|
107
|
-
}
|
|
108
|
-
class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
|
|
109
|
-
/**
|
|
110
|
-
* Create a ForcedEOSTokenLogitsProcessor.
|
|
111
|
-
* @param {number} max_length Max length of the sequence.
|
|
112
|
-
* @param {number|number[]} forced_eos_token_id The ID of the end-of-sequence token to be forced.
|
|
113
|
-
*/
|
|
114
|
-
constructor(max_length, forced_eos_token_id) {
|
|
115
|
-
super();
|
|
116
|
-
this.max_length = max_length;
|
|
117
|
-
this.forced_eos_token_id = forced_eos_token_id;
|
|
118
|
-
}
|
|
119
|
-
/**
|
|
120
|
-
* Apply the processor to input_ids and logits.
|
|
121
|
-
*
|
|
122
|
-
* @param {number[]} input_ids The input ids.
|
|
123
|
-
* @param {Tensor} logits The logits tensor.
|
|
124
|
-
*/
|
|
125
|
-
_call(input_ids, logits) {
|
|
126
|
-
}
|
|
127
|
-
}
|
|
128
|
-
class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor {
|
|
129
|
-
/**
|
|
130
|
-
* Create a SuppressTokensAtBeginLogitsProcessor.
|
|
131
|
-
* @param {number[]} begin_suppress_tokens The IDs of the tokens to suppress.
|
|
132
|
-
* @param {number} begin_index The number of tokens to generate before suppressing tokens.
|
|
133
|
-
*/
|
|
134
|
-
constructor(begin_suppress_tokens, begin_index) {
|
|
135
|
-
super();
|
|
136
|
-
this.begin_suppress_tokens = begin_suppress_tokens;
|
|
137
|
-
this.begin_index = begin_index;
|
|
138
|
-
}
|
|
139
|
-
/**
|
|
140
|
-
* Apply the BOS token forcing to the logits.
|
|
141
|
-
* @param {Array} input_ids The input IDs.
|
|
142
|
-
* @param {Object} logits The logits.
|
|
143
|
-
* @returns {Object} The logits with BOS token forcing.
|
|
144
|
-
*/
|
|
145
|
-
_call(input_ids, logits) {
|
|
146
|
-
if (input_ids.length === this.begin_index) {
|
|
147
|
-
for (let token_id of this.begin_suppress_tokens) {
|
|
148
|
-
logits.data[token_id] = -Infinity;
|
|
149
|
-
}
|
|
150
|
-
}
|
|
151
|
-
return logits;
|
|
152
|
-
}
|
|
153
|
-
}
|
|
154
|
-
class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
|
|
155
|
-
/**
|
|
156
|
-
* Constructs a new WhisperTimeStampLogitsProcessor.
|
|
157
|
-
* @param {Object} generate_config The config object passed to the `generate()` method of a transformer model.
|
|
158
|
-
* @param {number} generate_config.eos_token_id The ID of the end-of-sequence token.
|
|
159
|
-
* @param {number} generate_config.no_timestamps_token_id The ID of the token used to indicate that a token should not have a timestamp.
|
|
160
|
-
* @param {number[][]} [generate_config.forced_decoder_ids] An array of two-element arrays representing decoder IDs that are forced to appear in the output. The second element of each array indicates whether the token is a timestamp.
|
|
161
|
-
* @param {number} [generate_config.max_initial_timestamp_index] The maximum index at which an initial timestamp can appear.
|
|
162
|
-
*/
|
|
163
|
-
constructor(generate_config) {
|
|
164
|
-
super();
|
|
165
|
-
this.eos_token_id = generate_config.eos_token_id;
|
|
166
|
-
this.no_timestamps_token_id = generate_config.no_timestamps_token_id;
|
|
167
|
-
this.timestamp_begin = this.no_timestamps_token_id + 1;
|
|
168
|
-
this.begin_index = (generate_config.forced_decoder_ids || []).length + 2;
|
|
169
|
-
if (generate_config.forced_decoder_ids.slice(-1)[0][1] === this.no_timestamps_token_id) {
|
|
170
|
-
this.begin_index -= 1;
|
|
171
|
-
}
|
|
172
|
-
this.max_initial_timestamp_index = generate_config.max_initial_timestamp_index;
|
|
173
|
-
}
|
|
174
|
-
/**
|
|
175
|
-
* Modify the logits to handle timestamp tokens.
|
|
176
|
-
* @param {Array} input_ids The input sequence of tokens.
|
|
177
|
-
* @param {Tensor} logits The logits output by the model.
|
|
178
|
-
* @returns {Tensor} The modified logits.
|
|
179
|
-
*/
|
|
180
|
-
_call(input_ids, logits) {
|
|
181
|
-
const logitsData = (
|
|
182
|
-
/** @type {Float32Array} */
|
|
183
|
-
logits.data
|
|
184
|
-
);
|
|
185
|
-
logitsData[this.no_timestamps_token_id] = -Infinity;
|
|
186
|
-
if (input_ids.length === this.begin_index - 1) {
|
|
187
|
-
logitsData.fill(-Infinity);
|
|
188
|
-
logitsData[this.timestamp_begin] = 0;
|
|
189
|
-
return logits;
|
|
190
|
-
}
|
|
191
|
-
const seq = input_ids.slice(this.begin_index);
|
|
192
|
-
const last_was_timestamp = seq.length >= 1 && seq[seq.length - 1] >= this.timestamp_begin;
|
|
193
|
-
const penultimate_was_timestamp = seq.length < 2 || seq[seq.length - 2] >= this.timestamp_begin;
|
|
194
|
-
if (last_was_timestamp) {
|
|
195
|
-
if (penultimate_was_timestamp) {
|
|
196
|
-
logitsData.subarray(this.timestamp_begin).fill(-Infinity);
|
|
197
|
-
} else {
|
|
198
|
-
logitsData.subarray(0, this.eos_token_id).fill(-Infinity);
|
|
199
|
-
}
|
|
200
|
-
}
|
|
201
|
-
if (input_ids.length === this.begin_index && this.max_initial_timestamp_index !== null) {
|
|
202
|
-
const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index;
|
|
203
|
-
logitsData.subarray(last_allowed + 1).fill(-Infinity);
|
|
204
|
-
}
|
|
205
|
-
const logprobs = log_softmax(logitsData);
|
|
206
|
-
const timestamp_logprob = Math.log(logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b));
|
|
207
|
-
const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0];
|
|
208
|
-
if (timestamp_logprob > max_text_token_logprob) {
|
|
209
|
-
logitsData.subarray(0, this.timestamp_begin).fill(-Infinity);
|
|
210
|
-
}
|
|
211
|
-
return logits;
|
|
212
|
-
}
|
|
213
|
-
}
|
|
214
|
-
class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
|
|
215
|
-
/**
|
|
216
|
-
* Create a NoRepeatNGramLogitsProcessor.
|
|
217
|
-
* @param {number} no_repeat_ngram_size The no-repeat-ngram size. All ngrams of this size can only occur once.
|
|
218
|
-
*/
|
|
219
|
-
constructor(no_repeat_ngram_size) {
|
|
220
|
-
super();
|
|
221
|
-
this.no_repeat_ngram_size = no_repeat_ngram_size;
|
|
222
|
-
}
|
|
223
|
-
/**
|
|
224
|
-
* Generate n-grams from a sequence of token ids.
|
|
225
|
-
* @param {number[]} prevInputIds List of previous input ids
|
|
226
|
-
* @returns {Map<string, number[]>} Map of generated n-grams
|
|
227
|
-
*/
|
|
228
|
-
getNgrams(prevInputIds) {
|
|
229
|
-
const curLen = prevInputIds.length;
|
|
230
|
-
const ngrams = [];
|
|
231
|
-
for (let j = 0; j < curLen + 1 - this.no_repeat_ngram_size; ++j) {
|
|
232
|
-
const ngram = [];
|
|
233
|
-
for (let k = 0; k < this.no_repeat_ngram_size; ++k) {
|
|
234
|
-
ngram.push(prevInputIds[j + k]);
|
|
235
|
-
}
|
|
236
|
-
ngrams.push(ngram);
|
|
237
|
-
}
|
|
238
|
-
const generatedNgram = /* @__PURE__ */ new Map();
|
|
239
|
-
for (const ngram of ngrams) {
|
|
240
|
-
const prevNgram = ngram.slice(0, ngram.length - 1);
|
|
241
|
-
const prevNgramKey = JSON.stringify(prevNgram);
|
|
242
|
-
const prevNgramValue = generatedNgram.get(prevNgramKey) ?? [];
|
|
243
|
-
prevNgramValue.push(ngram[ngram.length - 1]);
|
|
244
|
-
generatedNgram.set(prevNgramKey, prevNgramValue);
|
|
245
|
-
}
|
|
246
|
-
return generatedNgram;
|
|
247
|
-
}
|
|
248
|
-
/**
|
|
249
|
-
* Generate n-grams from a sequence of token ids.
|
|
250
|
-
* @param {Map<string, number[]>} bannedNgrams Map of banned n-grams
|
|
251
|
-
* @param {number[]} prevInputIds List of previous input ids
|
|
252
|
-
* @returns {number[]} Map of generated n-grams
|
|
253
|
-
*/
|
|
254
|
-
getGeneratedNgrams(bannedNgrams, prevInputIds) {
|
|
255
|
-
const ngramIdx = prevInputIds.slice(prevInputIds.length + 1 - this.no_repeat_ngram_size, prevInputIds.length);
|
|
256
|
-
const banned = bannedNgrams.get(JSON.stringify(ngramIdx)) ?? [];
|
|
257
|
-
return banned;
|
|
258
|
-
}
|
|
259
|
-
/**
|
|
260
|
-
* Calculate banned n-gram tokens
|
|
261
|
-
* @param {number[]} prevInputIds List of previous input ids
|
|
262
|
-
* @returns {number[]} Map of generated n-grams
|
|
263
|
-
*/
|
|
264
|
-
calcBannedNgramTokens(prevInputIds) {
|
|
265
|
-
const bannedTokens = [];
|
|
266
|
-
if (prevInputIds.length + 1 < this.no_repeat_ngram_size) {
|
|
267
|
-
return bannedTokens;
|
|
268
|
-
} else {
|
|
269
|
-
const generatedNgrams = this.getNgrams(prevInputIds);
|
|
270
|
-
const bannedTokens2 = this.getGeneratedNgrams(generatedNgrams, prevInputIds);
|
|
271
|
-
return bannedTokens2;
|
|
272
|
-
}
|
|
273
|
-
}
|
|
274
|
-
/**
|
|
275
|
-
* Apply the no-repeat-ngram processor to the logits.
|
|
276
|
-
* @param {Array} input_ids The input IDs.
|
|
277
|
-
* @param {Object} logits The logits.
|
|
278
|
-
* @returns {Object} The logits with no-repeat-ngram processing.
|
|
279
|
-
*/
|
|
280
|
-
_call(input_ids, logits) {
|
|
281
|
-
const bannedTokens = this.calcBannedNgramTokens(input_ids);
|
|
282
|
-
for (const token of bannedTokens) {
|
|
283
|
-
logits.data[token] = -Infinity;
|
|
284
|
-
}
|
|
285
|
-
return logits;
|
|
286
|
-
}
|
|
287
|
-
}
|
|
288
|
-
class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
|
|
289
|
-
/**
|
|
290
|
-
* Create a RepetitionPenaltyLogitsProcessor.
|
|
291
|
-
* @param {number} penalty The penalty to apply for repeated tokens.
|
|
292
|
-
*/
|
|
293
|
-
constructor(penalty) {
|
|
294
|
-
super();
|
|
295
|
-
this.penalty = penalty;
|
|
296
|
-
}
|
|
297
|
-
/**
|
|
298
|
-
* Apply the repetition penalty to the logits.
|
|
299
|
-
* @param {Array} input_ids The input IDs.
|
|
300
|
-
* @param {Object} logits The logits.
|
|
301
|
-
* @returns {Object} The logits with repetition penalty processing.
|
|
302
|
-
*/
|
|
303
|
-
_call(input_ids, logits) {
|
|
304
|
-
for (const input_id of input_ids) {
|
|
305
|
-
if (logits.data[input_id] < 0) {
|
|
306
|
-
logits.data[input_id] *= this.penalty;
|
|
307
|
-
} else {
|
|
308
|
-
logits.data[input_id] /= this.penalty;
|
|
309
|
-
}
|
|
310
|
-
}
|
|
311
|
-
return logits;
|
|
312
|
-
}
|
|
313
|
-
}
|
|
314
|
-
class MinLengthLogitsProcessor extends LogitsProcessor {
|
|
315
|
-
/**
|
|
316
|
-
* Create a MinLengthLogitsProcessor.
|
|
317
|
-
* @param {number} min_length The minimum length below which the score of `eos_token_id` is set to negative infinity.
|
|
318
|
-
* @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token.
|
|
319
|
-
*/
|
|
320
|
-
constructor(min_length, eos_token_id) {
|
|
321
|
-
super();
|
|
322
|
-
this.min_length = min_length;
|
|
323
|
-
this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
|
|
324
|
-
}
|
|
325
|
-
/**
|
|
326
|
-
* Apply logit processor.
|
|
327
|
-
* @param {Array} input_ids The input IDs.
|
|
328
|
-
* @param {Object} logits The logits.
|
|
329
|
-
* @returns {Object} The processed logits.
|
|
330
|
-
*/
|
|
331
|
-
_call(input_ids, logits) {
|
|
332
|
-
if (input_ids.length < this.min_length) {
|
|
333
|
-
for (const eos_token of this.eos_token_id) {
|
|
334
|
-
logits.data[eos_token] = -Infinity;
|
|
335
|
-
}
|
|
336
|
-
}
|
|
337
|
-
return logits;
|
|
338
|
-
}
|
|
339
|
-
}
|
|
340
|
-
class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
|
|
341
|
-
/**
|
|
342
|
-
* Create a MinNewTokensLengthLogitsProcessor.
|
|
343
|
-
* @param {number} prompt_length_to_skip The input tokens length.
|
|
344
|
-
* @param {number} min_new_tokens The minimum *new* tokens length below which the score of `eos_token_id` is set to negative infinity.
|
|
345
|
-
* @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token.
|
|
346
|
-
*/
|
|
347
|
-
constructor(prompt_length_to_skip, min_new_tokens, eos_token_id) {
|
|
348
|
-
super();
|
|
349
|
-
this.prompt_length_to_skip = prompt_length_to_skip;
|
|
350
|
-
this.min_new_tokens = min_new_tokens;
|
|
351
|
-
this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
|
|
352
|
-
}
|
|
353
|
-
/**
|
|
354
|
-
* Apply logit processor.
|
|
355
|
-
* @param {Array} input_ids The input IDs.
|
|
356
|
-
* @param {Object} logits The logits.
|
|
357
|
-
* @returns {Object} The processed logits.
|
|
358
|
-
*/
|
|
359
|
-
_call(input_ids, logits) {
|
|
360
|
-
const new_tokens_length = input_ids.length - this.prompt_length_to_skip;
|
|
361
|
-
if (new_tokens_length < this.min_new_tokens) {
|
|
362
|
-
for (const eos_token of this.eos_token_id) {
|
|
363
|
-
logits.data[eos_token] = -Infinity;
|
|
364
|
-
}
|
|
365
|
-
}
|
|
366
|
-
return logits;
|
|
367
|
-
}
|
|
368
|
-
}
|
|
369
|
-
class NoBadWordsLogitsProcessor extends LogitsProcessor {
|
|
370
|
-
/**
|
|
371
|
-
* Create a `NoBadWordsLogitsProcessor`.
|
|
372
|
-
* @param {number[][]} bad_words_ids List of list of token ids that are not allowed to be generated.
|
|
373
|
-
* @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
|
374
|
-
*/
|
|
375
|
-
constructor(bad_words_ids, eos_token_id) {
|
|
376
|
-
super();
|
|
377
|
-
this.bad_words_ids = bad_words_ids;
|
|
378
|
-
this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
|
|
379
|
-
}
|
|
380
|
-
/**
|
|
381
|
-
* Apply logit processor.
|
|
382
|
-
* @param {Array} input_ids The input IDs.
|
|
383
|
-
* @param {Object} logits The logits.
|
|
384
|
-
* @returns {Object} The processed logits.
|
|
385
|
-
*/
|
|
386
|
-
_call(input_ids, logits) {
|
|
387
|
-
for (const bad_word_ids of this.bad_words_ids) {
|
|
388
|
-
let mark = true;
|
|
389
|
-
for (let i = 1; i <= bad_word_ids.length - 1 && bad_word_ids.length < input_ids.length; ++i) {
|
|
390
|
-
if (bad_word_ids.at(-i - 1) !== input_ids.at(-i)) {
|
|
391
|
-
mark = false;
|
|
392
|
-
break;
|
|
393
|
-
}
|
|
394
|
-
}
|
|
395
|
-
if (mark) {
|
|
396
|
-
logits.data[bad_word_ids.at(-1)] = -Infinity;
|
|
397
|
-
}
|
|
398
|
-
}
|
|
399
|
-
return logits;
|
|
400
|
-
}
|
|
401
|
-
}
|
|
402
|
-
const GenerationConfig = (
    /** @type {any} */
    class {
        /**
         * Create a new GenerationConfig object.
         * @param {GenerationConfigType} kwargs
         */
        constructor(kwargs = {}) {
            // Table of every supported option with its default. Each option
            // falls back only when the caller supplied null/undefined
            // (nullish coalescing), so falsy values like 0 or false are kept.
            const defaults = {
                max_length: 20,
                max_new_tokens: null,
                min_length: 0,
                min_new_tokens: null,
                early_stopping: false,
                max_time: null,
                do_sample: false,
                num_beams: 1,
                num_beam_groups: 1,
                penalty_alpha: null,
                use_cache: true,
                temperature: 1,
                top_k: 50,
                top_p: 1,
                typical_p: 1,
                epsilon_cutoff: 0,
                eta_cutoff: 0,
                diversity_penalty: 0,
                repetition_penalty: 1,
                encoder_repetition_penalty: 1,
                length_penalty: 1,
                no_repeat_ngram_size: 0,
                bad_words_ids: null,
                force_words_ids: null,
                renormalize_logits: false,
                constraints: null,
                forced_bos_token_id: null,
                forced_eos_token_id: null,
                remove_invalid_values: false,
                exponential_decay_length_penalty: null,
                suppress_tokens: null,
                begin_suppress_tokens: null,
                forced_decoder_ids: null,
                num_return_sequences: 1,
                output_attentions: false,
                output_hidden_states: false,
                output_scores: false,
                return_dict_in_generate: false,
                pad_token_id: null,
                bos_token_id: null,
                eos_token_id: null,
                encoder_no_repeat_ngram_size: 0,
                decoder_start_token_id: null
            };
            for (const [name, fallback] of Object.entries(defaults)) {
                this[name] = kwargs[name] ?? fallback;
            }
            // Extra kwargs forwarded verbatim to the generation call. A fresh
            // object per instance, so instances never share state.
            this.generation_kwargs = kwargs.generation_kwargs ?? {};
        }
    }
);
|
|
457
|
-
class Sampler extends Callable {
    /**
     * Creates a new Sampler object with the specified generation config.
     * @param {GenerationConfigType} generation_config The generation config.
     */
    constructor(generation_config) {
        super();
        this.generation_config = generation_config;
    }
    /**
     * Executes the sampler, using the specified logits.
     * @param {Tensor} logits
     * @param {number} index
     * @returns {void}
     */
    _call(logits, index = -1) {
        // Delegate to the subclass-provided sampling strategy.
        return this.sample(logits, index);
    }
    /**
     * Abstract method for sampling the logits.
     * @param {Tensor} logits
     * @param {number} index
     * @throws {Error} Always — subclasses must override.
     */
    sample(logits, index) {
        throw Error("sample should be implemented in subclasses.");
    }
    /**
     * Returns the specified logits as an array, with temperature applied.
     * @param {Tensor} logits
     * @param {number} index
     * @returns {Float32Array}
     */
    getLogits(logits, index) {
        const vocabSize = logits.dims.at(-1);
        let logs = (
            /** @type {Float32Array} */
            logits.data
        );
        if (index === -1) {
            // Take the distribution for the last position.
            logs = logs.slice(-vocabSize);
        } else {
            const start = index * vocabSize;
            logs = logs.slice(start, start + vocabSize);
        }
        // Temperature scaling; skipped for temperature <= 0 so we never
        // divide by zero.
        if (this.generation_config.temperature > 0) {
            logs = logs.map((x) => x / this.generation_config.temperature);
        }
        return logs;
    }
    /**
     * Selects an item randomly based on the specified probabilities.
     * @param {Array} probabilities An array of probabilities to use for selection.
     * @returns {number} The index of the selected item.
     */
    randomSelect(probabilities) {
        // Draw a point in [0, total) and find which bucket it lands in.
        const total = probabilities.reduce((acc, p) => acc + p, 0);
        let threshold = Math.random() * total;
        for (let i = 0; i < probabilities.length; ++i) {
            threshold -= probabilities[i];
            if (threshold <= 0) {
                return i;
            }
        }
        // Fallback (e.g. floating-point rounding): pick the first item.
        return 0;
    }
    /**
     * Returns a Sampler object based on the specified options.
     * @param {GenerationConfigType} generation_config An object containing options for the sampler.
     * @returns {Sampler} A Sampler object.
     */
    static getSampler(generation_config) {
        if (generation_config.do_sample) {
            return new MultinomialSampler(generation_config);
        }
        if (generation_config.num_beams > 1) {
            return new BeamSearchSampler(generation_config);
        }
        // Greedy search: only a single returned sequence makes sense.
        if (generation_config.num_return_sequences > 1) {
            throw Error(`num_return_sequences has to be 1 when doing greedy search, but is ${generation_config.num_return_sequences}.`);
        }
        return new GreedySampler(generation_config);
    }
}
|
|
541
|
-
class GreedySampler extends Sampler {
    /**
     * Sample the maximum probability of a given logits tensor.
     * @param {Tensor} logits
     * @param {number} [index=-1]
     * @returns {Array} An array with a single tuple, containing the index of the maximum value and a meaningless score (since this is a greedy search).
     */
    sample(logits, index = -1) {
        const logs = this.getLogits(logits, index);
        const argmax = max(logs)[1];
        // Greedy search yields one candidate; the score slot is a
        // placeholder 0 since no real probability is needed.
        return [[argmax, 0]];
    }
}
|
|
556
|
-
class MultinomialSampler extends Sampler {
    /**
     * Sample from the logits.
     * @param {Tensor} logits
     * @param {number} index
     * @returns {Array}
     */
    sample(logits, index = -1) {
        const vocabSize = logits.dims.at(-1);
        // Restrict sampling to the top-k most likely tokens when configured.
        const topK = this.generation_config.top_k;
        const k = topK > 0 ? Math.min(topK, vocabSize) : vocabSize;
        const logs = this.getLogits(logits, index);
        const topLogits = getTopItems(logs, k);
        const probabilities = softmax(topLogits.map((pair) => pair[1]));
        // Draw one multinomial sample per beam: [token id, log-probability].
        return Array.from({ length: this.generation_config.num_beams }, () => {
            const pick = this.randomSelect(probabilities);
            return [topLogits[pick][0], Math.log(probabilities[pick])];
        });
    }
}
|
|
582
|
-
class BeamSearchSampler extends Sampler {
    /**
     * Sample from the logits.
     * @param {Tensor} logits
     * @param {number} index
     * @returns {Array}
     */
    sample(logits, index = -1) {
        const vocabSize = logits.dims.at(-1);
        const topK = this.generation_config.top_k;
        const k = topK > 0 ? Math.min(topK, vocabSize) : vocabSize;
        const logs = this.getLogits(logits, index);
        const topLogits = getTopItems(logs, k);
        const probabilities = softmax(topLogits.map((pair) => pair[1]));
        // Deterministically keep the i-th best candidate for each beam:
        // [token id, log-probability].
        const beams = [];
        for (let i = 0; i < this.generation_config.num_beams; ++i) {
            beams.push([topLogits[i][0], Math.log(probabilities[i])]);
        }
        return beams;
    }
}
|
|
607
|
-
export {
|
|
608
|
-
ForceTokensLogitsProcessor,
|
|
609
|
-
ForcedBOSTokenLogitsProcessor,
|
|
610
|
-
ForcedEOSTokenLogitsProcessor,
|
|
611
|
-
GenerationConfig,
|
|
612
|
-
LogitsProcessor,
|
|
613
|
-
LogitsProcessorList,
|
|
614
|
-
MinLengthLogitsProcessor,
|
|
615
|
-
MinNewTokensLengthLogitsProcessor,
|
|
616
|
-
NoBadWordsLogitsProcessor,
|
|
617
|
-
NoRepeatNGramLogitsProcessor,
|
|
618
|
-
RepetitionPenaltyLogitsProcessor,
|
|
619
|
-
Sampler,
|
|
620
|
-
SuppressTokensAtBeginLogitsProcessor,
|
|
621
|
-
WhisperTimeStampLogitsProcessor
|
|
622
|
-
};
|
|
623
|
-
//# sourceMappingURL=generation.js.map
|