@huggingface/transformers 3.0.0-alpha.9 → 3.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +33 -22
- package/dist/ort-wasm-simd-threaded.jsep.wasm +0 -0
- package/dist/transformers.cjs +2515 -2525
- package/dist/transformers.cjs.map +1 -1
- package/dist/transformers.js +3529 -3455
- package/dist/transformers.js.map +1 -1
- package/dist/transformers.min.cjs +25 -25
- package/dist/transformers.min.cjs.map +1 -1
- package/dist/transformers.min.js +39 -40
- package/dist/transformers.min.js.map +1 -1
- package/dist/transformers.min.mjs +56 -57
- package/dist/transformers.min.mjs.map +1 -1
- package/dist/transformers.mjs +2551 -2538
- package/dist/transformers.mjs.map +1 -1
- package/package.json +14 -13
- package/src/backends/onnx.js +24 -19
- package/src/configs.js +19 -4
- package/src/env.js +5 -9
- package/src/generation/logits_process.js +40 -37
- package/src/models.js +326 -514
- package/src/ops/registry.js +14 -3
- package/src/pipelines.js +5 -4
- package/src/processors.js +390 -351
- package/src/tokenizers.js +140 -175
- package/src/utils/constants.js +1 -1
- package/src/utils/core.js +12 -0
- package/src/utils/data-structures.js +13 -11
- package/src/utils/hub.js +1 -1
- package/src/utils/maths.js +14 -5
- package/src/utils/tensor.js +60 -13
- package/types/backends/onnx.d.ts +5 -2
- package/types/backends/onnx.d.ts.map +1 -1
- package/types/configs.d.ts +29 -3
- package/types/configs.d.ts.map +1 -1
- package/types/env.d.ts +4 -2
- package/types/env.d.ts.map +1 -1
- package/types/generation/logits_process.d.ts.map +1 -1
- package/types/models.d.ts +116 -289
- package/types/models.d.ts.map +1 -1
- package/types/ops/registry.d.ts +6 -6
- package/types/ops/registry.d.ts.map +1 -1
- package/types/pipelines.d.ts +1 -2
- package/types/pipelines.d.ts.map +1 -1
- package/types/processors.d.ts +55 -51
- package/types/processors.d.ts.map +1 -1
- package/types/tokenizers.d.ts +23 -32
- package/types/tokenizers.d.ts.map +1 -1
- package/types/utils/constants.d.ts +1 -1
- package/types/utils/constants.d.ts.map +1 -1
- package/types/utils/core.d.ts +7 -0
- package/types/utils/core.d.ts.map +1 -1
- package/types/utils/data-structures.d.ts +6 -6
- package/types/utils/data-structures.d.ts.map +1 -1
- package/types/utils/hub.d.ts +1 -1
- package/types/utils/hub.d.ts.map +1 -1
- package/types/utils/maths.d.ts +2 -2
- package/types/utils/maths.d.ts.map +1 -1
- package/types/utils/tensor.d.ts +27 -1
- package/types/utils/tensor.d.ts.map +1 -1
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@huggingface/transformers",
|
|
3
|
-
"version": "3.0.0
|
|
3
|
+
"version": "3.0.0",
|
|
4
4
|
"description": "State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!",
|
|
5
5
|
"main": "./src/transformers.js",
|
|
6
6
|
"types": "./types/transformers.d.ts",
|
|
@@ -21,22 +21,27 @@
|
|
|
21
21
|
"default": "./dist/transformers.js"
|
|
22
22
|
}
|
|
23
23
|
},
|
|
24
|
+
"imports": {
|
|
25
|
+
"#onnxruntime-webgpu": {
|
|
26
|
+
"node": "onnxruntime-web",
|
|
27
|
+
"default": "onnxruntime-web/webgpu"
|
|
28
|
+
}
|
|
29
|
+
},
|
|
24
30
|
"scripts": {
|
|
25
31
|
"format": "prettier --write .",
|
|
26
32
|
"format:check": "prettier --check .",
|
|
27
33
|
"typegen": "tsc ./src/transformers.js --allowJs --declaration --emitDeclarationOnly --declarationMap --outDir types",
|
|
28
34
|
"dev": "webpack serve --no-client-overlay",
|
|
29
35
|
"build": "webpack && npm run typegen",
|
|
30
|
-
"
|
|
31
|
-
"test": "node --experimental-vm-modules node_modules/jest/bin/jest.js --verbose --maxConcurrency 1",
|
|
36
|
+
"test": "node --experimental-vm-modules node_modules/jest/bin/jest.js --verbose",
|
|
32
37
|
"readme": "python ./docs/scripts/build_readme.py",
|
|
33
38
|
"docs-api": "node ./docs/scripts/generate.js",
|
|
34
39
|
"docs-preview": "doc-builder preview transformers.js ./docs/source/ --not_python_module",
|
|
35
|
-
"docs-build": "doc-builder build transformers.js ./docs/source/ --not_python_module --build_dir ./docs/build/
|
|
40
|
+
"docs-build": "doc-builder build transformers.js ./docs/source/ --not_python_module --build_dir ./docs/build/"
|
|
36
41
|
},
|
|
37
42
|
"repository": {
|
|
38
43
|
"type": "git",
|
|
39
|
-
"url": "git+https://github.com/
|
|
44
|
+
"url": "git+https://github.com/huggingface/transformers.js.git"
|
|
40
45
|
},
|
|
41
46
|
"keywords": [
|
|
42
47
|
"transformers",
|
|
@@ -52,13 +57,13 @@
|
|
|
52
57
|
"author": "Hugging Face",
|
|
53
58
|
"license": "Apache-2.0",
|
|
54
59
|
"bugs": {
|
|
55
|
-
"url": "https://github.com/
|
|
60
|
+
"url": "https://github.com/huggingface/transformers.js/issues"
|
|
56
61
|
},
|
|
57
|
-
"homepage": "https://github.com/
|
|
62
|
+
"homepage": "https://github.com/huggingface/transformers.js#readme",
|
|
58
63
|
"dependencies": {
|
|
59
64
|
"@huggingface/jinja": "^0.3.0",
|
|
60
|
-
"onnxruntime-node": "1.19.
|
|
61
|
-
"onnxruntime-web": "1.20.0-dev.
|
|
65
|
+
"onnxruntime-node": "1.19.2",
|
|
66
|
+
"onnxruntime-web": "1.20.0-dev.20241016-2b8fc5529b",
|
|
62
67
|
"sharp": "^0.33.5"
|
|
63
68
|
},
|
|
64
69
|
"devDependencies": {
|
|
@@ -75,10 +80,6 @@
|
|
|
75
80
|
"webpack-cli": "^5.0.2",
|
|
76
81
|
"webpack-dev-server": "^4.13.3"
|
|
77
82
|
},
|
|
78
|
-
"overrides": {
|
|
79
|
-
"semver": "^7.6.3",
|
|
80
|
-
"protobufjs": "^7.2.6"
|
|
81
|
-
},
|
|
82
83
|
"files": [
|
|
83
84
|
"src",
|
|
84
85
|
"dist",
|
package/src/backends/onnx.js
CHANGED
|
@@ -21,7 +21,12 @@ import { env, apis } from '../env.js';
|
|
|
21
21
|
// NOTE: Import order matters here. We need to import `onnxruntime-node` before `onnxruntime-web`.
|
|
22
22
|
// In either case, we select the default export if it exists, otherwise we use the named export.
|
|
23
23
|
import * as ONNX_NODE from 'onnxruntime-node';
|
|
24
|
-
|
|
24
|
+
|
|
25
|
+
// Use subpath-imports to ensure Node.js and browser interoperability.
|
|
26
|
+
// See package.json and https://nodejs.org/api/packages.html#subpath-imports
|
|
27
|
+
// for more information.
|
|
28
|
+
// @ts-ignore
|
|
29
|
+
import * as ONNX_WEB from '#onnxruntime-webgpu';
|
|
25
30
|
|
|
26
31
|
export { Tensor } from 'onnxruntime-common';
|
|
27
32
|
|
|
@@ -54,7 +59,13 @@ const supportedDevices = [];
|
|
|
54
59
|
/** @type {ONNXExecutionProviders[]} */
|
|
55
60
|
let defaultDevices;
|
|
56
61
|
let ONNX;
|
|
57
|
-
|
|
62
|
+
const ORT_SYMBOL = Symbol.for('onnxruntime');
|
|
63
|
+
|
|
64
|
+
if (ORT_SYMBOL in globalThis) {
|
|
65
|
+
// If the JS runtime exposes their own ONNX runtime, use it
|
|
66
|
+
ONNX = globalThis[ORT_SYMBOL];
|
|
67
|
+
|
|
68
|
+
} else if (apis.IS_NODE_ENV) {
|
|
58
69
|
ONNX = ONNX_NODE.default ?? ONNX_NODE;
|
|
59
70
|
|
|
60
71
|
// Updated as of ONNX Runtime 1.18.0
|
|
@@ -112,7 +123,7 @@ export function deviceToExecutionProviders(device = null) {
|
|
|
112
123
|
case "auto":
|
|
113
124
|
return supportedDevices;
|
|
114
125
|
case "gpu":
|
|
115
|
-
return supportedDevices.filter(x =>
|
|
126
|
+
return supportedDevices.filter(x =>
|
|
116
127
|
["webgpu", "cuda", "dml", "webnn-gpu"].includes(x),
|
|
117
128
|
);
|
|
118
129
|
}
|
|
@@ -137,9 +148,10 @@ let wasmInitPromise = null;
|
|
|
137
148
|
* Create an ONNX inference session.
|
|
138
149
|
* @param {Uint8Array} buffer The ONNX model buffer.
|
|
139
150
|
* @param {import('onnxruntime-common').InferenceSession.SessionOptions} session_options ONNX inference session options.
|
|
140
|
-
* @
|
|
151
|
+
* @param {Object} session_config ONNX inference session configuration.
|
|
152
|
+
* @returns {Promise<import('onnxruntime-common').InferenceSession & { config: Object}>} The ONNX inference session.
|
|
141
153
|
*/
|
|
142
|
-
export async function createInferenceSession(buffer, session_options) {
|
|
154
|
+
export async function createInferenceSession(buffer, session_options, session_config) {
|
|
143
155
|
if (wasmInitPromise) {
|
|
144
156
|
// A previous session has already initialized the WASM runtime
|
|
145
157
|
// so we wait for it to resolve before creating this new session.
|
|
@@ -148,7 +160,9 @@ export async function createInferenceSession(buffer, session_options) {
|
|
|
148
160
|
|
|
149
161
|
const sessionPromise = InferenceSession.create(buffer, session_options);
|
|
150
162
|
wasmInitPromise ??= sessionPromise;
|
|
151
|
-
|
|
163
|
+
const session = await sessionPromise;
|
|
164
|
+
session.config = session_config;
|
|
165
|
+
return session;
|
|
152
166
|
}
|
|
153
167
|
|
|
154
168
|
/**
|
|
@@ -160,6 +174,7 @@ export function isONNXTensor(x) {
|
|
|
160
174
|
return x instanceof ONNX.Tensor;
|
|
161
175
|
}
|
|
162
176
|
|
|
177
|
+
/** @type {import('onnxruntime-common').Env} */
|
|
163
178
|
// @ts-ignore
|
|
164
179
|
const ONNX_ENV = ONNX?.env;
|
|
165
180
|
if (ONNX_ENV?.wasm) {
|
|
@@ -174,24 +189,14 @@ if (ONNX_ENV?.wasm) {
|
|
|
174
189
|
// TODO: Add support for loading WASM files from cached buffer when we upgrade to onnxruntime-web@1.19.0
|
|
175
190
|
// https://github.com/microsoft/onnxruntime/pull/21534
|
|
176
191
|
|
|
177
|
-
//
|
|
178
|
-
//
|
|
179
|
-
ONNX_ENV.wasm.proxy =
|
|
192
|
+
// Users may wish to proxy the WASM backend to prevent the UI from freezing,
|
|
193
|
+
// However, this is not necessary when using WebGPU, so we default to false.
|
|
194
|
+
ONNX_ENV.wasm.proxy = false;
|
|
180
195
|
|
|
181
196
|
// https://developer.mozilla.org/en-US/docs/Web/API/crossOriginIsolated
|
|
182
197
|
if (typeof crossOriginIsolated === 'undefined' || !crossOriginIsolated) {
|
|
183
198
|
ONNX_ENV.wasm.numThreads = 1;
|
|
184
199
|
}
|
|
185
|
-
|
|
186
|
-
// Running in a browser-environment
|
|
187
|
-
// TODO: Check if 1.17.1 fixes this issue.
|
|
188
|
-
// SIMD for WebAssembly does not operate correctly in some recent versions of iOS (16.4.x).
|
|
189
|
-
// As a temporary fix, we disable it for now.
|
|
190
|
-
// For more information, see: https://github.com/microsoft/onnxruntime/issues/15644
|
|
191
|
-
const isIOS = typeof navigator !== 'undefined' && /iP(hone|od|ad).+16_4.+AppleWebKit/.test(navigator.userAgent);
|
|
192
|
-
if (isIOS) {
|
|
193
|
-
ONNX_ENV.wasm.simd = false;
|
|
194
|
-
}
|
|
195
200
|
}
|
|
196
201
|
|
|
197
202
|
if (ONNX_ENV?.webgpu) {
|
package/src/configs.js
CHANGED
|
@@ -73,6 +73,7 @@ function getNormalizedConfig(config) {
|
|
|
73
73
|
// Decoder-only models
|
|
74
74
|
case 'gpt2':
|
|
75
75
|
case 'gptj':
|
|
76
|
+
case 'jais':
|
|
76
77
|
case 'codegen':
|
|
77
78
|
case 'gpt_bigcode':
|
|
78
79
|
mapping['num_heads'] = 'n_head';
|
|
@@ -90,6 +91,7 @@ function getNormalizedConfig(config) {
|
|
|
90
91
|
mapping['hidden_size'] = 'hidden_size';
|
|
91
92
|
break;
|
|
92
93
|
case 'llama':
|
|
94
|
+
case 'granite':
|
|
93
95
|
case 'cohere':
|
|
94
96
|
case 'mistral':
|
|
95
97
|
case 'starcoder2':
|
|
@@ -295,16 +297,23 @@ export function getKeyValueShapes(config, {
|
|
|
295
297
|
export class PretrainedConfig {
|
|
296
298
|
// NOTE: Typo in original
|
|
297
299
|
|
|
300
|
+
/** @type {string|null} */
|
|
301
|
+
model_type = null;
|
|
302
|
+
|
|
303
|
+
/** @type {boolean} */
|
|
304
|
+
is_encoder_decoder = false;
|
|
305
|
+
|
|
306
|
+
/** @type {number} */
|
|
298
307
|
max_position_embeddings;
|
|
299
308
|
|
|
309
|
+
/** @type {TransformersJSConfig} */
|
|
310
|
+
'transformers.js_config';
|
|
311
|
+
|
|
300
312
|
/**
|
|
301
313
|
* Create a new PreTrainedTokenizer instance.
|
|
302
314
|
* @param {Object} configJSON The JSON of the config.
|
|
303
315
|
*/
|
|
304
316
|
constructor(configJSON) {
|
|
305
|
-
this.model_type = null;
|
|
306
|
-
this.is_encoder_decoder = false;
|
|
307
|
-
|
|
308
317
|
Object.assign(this, configJSON);
|
|
309
318
|
this.normalized_config = getNormalizedConfig(this);
|
|
310
319
|
}
|
|
@@ -356,5 +365,11 @@ export class AutoConfig {
|
|
|
356
365
|
/**
|
|
357
366
|
* Transformers.js-specific configuration, possibly present in config.json under the key `transformers.js_config`.
|
|
358
367
|
* @typedef {Object} TransformersJSConfig
|
|
359
|
-
* @property {import('./
|
|
368
|
+
* @property {import('./utils/tensor.js').DataType|Record<import('./utils/dtypes.js').DataType, import('./utils/tensor.js').DataType>} [kv_cache_dtype] The data type of the key-value cache.
|
|
369
|
+
* @property {Record<string, number>} [free_dimension_overrides] Override the free dimensions of the model.
|
|
370
|
+
* See https://onnxruntime.ai/docs/tutorials/web/env-flags-and-session-options.html#freedimensionoverrides
|
|
371
|
+
* for more information.
|
|
372
|
+
* @property {import('./utils/devices.js').DeviceType} [device] The default device to use for the model.
|
|
373
|
+
* @property {import('./utils/dtypes.js').DataType} [dtype] The default data type to use for the model.
|
|
374
|
+
* @property {boolean|Record<string, boolean>} [use_external_data_format=false] Whether to load the model using the external data format (used for models >= 2GB in size).
|
|
360
375
|
*/
|
package/src/env.js
CHANGED
|
@@ -26,7 +26,7 @@ import fs from 'fs';
|
|
|
26
26
|
import path from 'path';
|
|
27
27
|
import url from 'url';
|
|
28
28
|
|
|
29
|
-
const VERSION = '3.0.0
|
|
29
|
+
const VERSION = '3.0.0';
|
|
30
30
|
|
|
31
31
|
// Check if various APIs are available (depends on environment)
|
|
32
32
|
const IS_BROWSER_ENV = typeof self !== 'undefined';
|
|
@@ -73,26 +73,26 @@ export const apis = Object.freeze({
|
|
|
73
73
|
});
|
|
74
74
|
|
|
75
75
|
const RUNNING_LOCALLY = IS_FS_AVAILABLE && IS_PATH_AVAILABLE;
|
|
76
|
-
const
|
|
76
|
+
const dirname__ = RUNNING_LOCALLY
|
|
77
77
|
? path.dirname(path.dirname(url.fileURLToPath(import.meta.url)))
|
|
78
78
|
: './';
|
|
79
79
|
|
|
80
80
|
// Only used for environments with access to file system
|
|
81
81
|
const DEFAULT_CACHE_DIR = RUNNING_LOCALLY
|
|
82
|
-
? path.join(
|
|
82
|
+
? path.join(dirname__, '/.cache/')
|
|
83
83
|
: null;
|
|
84
84
|
|
|
85
85
|
// Set local model path, based on available APIs
|
|
86
86
|
const DEFAULT_LOCAL_MODEL_PATH = '/models/';
|
|
87
87
|
const localModelPath = RUNNING_LOCALLY
|
|
88
|
-
? path.join(
|
|
88
|
+
? path.join(dirname__, DEFAULT_LOCAL_MODEL_PATH)
|
|
89
89
|
: DEFAULT_LOCAL_MODEL_PATH;
|
|
90
90
|
|
|
91
91
|
/**
|
|
92
92
|
* Global variable given visible to users to control execution. This provides users a simple way to configure Transformers.js.
|
|
93
93
|
* @typedef {Object} TransformersEnvironment
|
|
94
94
|
* @property {string} version This version of Transformers.js.
|
|
95
|
-
* @property {
|
|
95
|
+
* @property {{onnx: Partial<import('onnxruntime-common').Env>}} backends Expose environment variables of different backends,
|
|
96
96
|
* allowing users to set these variables if they want to.
|
|
97
97
|
* @property {boolean} allowRemoteModels Whether to allow loading of remote files, defaults to `true`.
|
|
98
98
|
* If set to `false`, it will have the same effect as setting `local_files_only=true` when loading pipelines, models, tokenizers, processors, etc.
|
|
@@ -119,12 +119,8 @@ export const env = {
|
|
|
119
119
|
backends: {
|
|
120
120
|
// onnxruntime-web/onnxruntime-node
|
|
121
121
|
onnx: {},
|
|
122
|
-
|
|
123
|
-
// TensorFlow.js
|
|
124
|
-
tfjs: {},
|
|
125
122
|
},
|
|
126
123
|
|
|
127
|
-
|
|
128
124
|
/////////////////// Model settings ///////////////////
|
|
129
125
|
allowRemoteModels: true,
|
|
130
126
|
remoteHost: 'https://huggingface.co/',
|
|
@@ -156,9 +156,9 @@ export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
|
|
|
156
156
|
_call(input_ids, logits) {
|
|
157
157
|
for (let i = 0; i < input_ids.length; ++i) {
|
|
158
158
|
if (input_ids[i].length === 1) {
|
|
159
|
-
const
|
|
160
|
-
|
|
161
|
-
|
|
159
|
+
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
|
|
160
|
+
batch_logits_data.fill(-Infinity);
|
|
161
|
+
batch_logits_data[this.bos_token_id] = 0;
|
|
162
162
|
}
|
|
163
163
|
}
|
|
164
164
|
return logits;
|
|
@@ -189,11 +189,10 @@ export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
|
|
|
189
189
|
_call(input_ids, logits) {
|
|
190
190
|
for (let i = 0; i < input_ids.length; ++i) {
|
|
191
191
|
if (input_ids[i].length === this.max_length - 1) {
|
|
192
|
-
const
|
|
193
|
-
|
|
194
|
-
|
|
192
|
+
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
|
|
193
|
+
batch_logits_data.fill(-Infinity);
|
|
195
194
|
for (const eos_token of this.eos_token_id) {
|
|
196
|
-
|
|
195
|
+
batch_logits_data[eos_token] = 0;
|
|
197
196
|
}
|
|
198
197
|
}
|
|
199
198
|
}
|
|
@@ -227,9 +226,9 @@ export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor {
|
|
|
227
226
|
_call(input_ids, logits) {
|
|
228
227
|
for (let i = 0; i < input_ids.length; ++i) {
|
|
229
228
|
if (input_ids[i].length === this.begin_index) {
|
|
230
|
-
const
|
|
229
|
+
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
|
|
231
230
|
for (const token_id of this.begin_suppress_tokens) {
|
|
232
|
-
|
|
231
|
+
batch_logits_data[token_id] = -Infinity;
|
|
233
232
|
}
|
|
234
233
|
}
|
|
235
234
|
}
|
|
@@ -271,15 +270,14 @@ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
|
|
|
271
270
|
*/
|
|
272
271
|
_call(input_ids, logits) {
|
|
273
272
|
for (let i = 0; i < input_ids.length; ++i) {
|
|
274
|
-
const
|
|
275
|
-
const logitsData = /** @type {Float32Array} */(batch_logits.data);
|
|
273
|
+
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
|
|
276
274
|
|
|
277
275
|
// suppress <|notimestamps|> which is handled by without_timestamps
|
|
278
|
-
|
|
276
|
+
batch_logits_data[this.no_timestamps_token_id] = -Infinity;
|
|
279
277
|
|
|
280
278
|
if (input_ids[i].length === this.begin_index - 1) {
|
|
281
|
-
|
|
282
|
-
|
|
279
|
+
batch_logits_data.fill(-Infinity);
|
|
280
|
+
batch_logits_data[this.timestamp_begin] = 0;
|
|
283
281
|
continue;
|
|
284
282
|
}
|
|
285
283
|
|
|
@@ -290,25 +288,25 @@ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
|
|
|
290
288
|
|
|
291
289
|
if (last_was_timestamp) {
|
|
292
290
|
if (penultimate_was_timestamp) { // has to be non-timestamp
|
|
293
|
-
|
|
291
|
+
batch_logits_data.subarray(this.timestamp_begin).fill(-Infinity);
|
|
294
292
|
} else { // cannot be normal text tokens
|
|
295
|
-
|
|
293
|
+
batch_logits_data.subarray(0, this.eos_token_id).fill(-Infinity);
|
|
296
294
|
}
|
|
297
295
|
}
|
|
298
296
|
|
|
299
297
|
// apply the `max_initial_timestamp` option
|
|
300
298
|
if (input_ids[i].length === this.begin_index && this.max_initial_timestamp_index !== null) {
|
|
301
299
|
const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index;
|
|
302
|
-
|
|
300
|
+
batch_logits_data.subarray(last_allowed + 1).fill(-Infinity);
|
|
303
301
|
}
|
|
304
302
|
|
|
305
303
|
// if sum of probability over timestamps is above any other token, sample timestamp
|
|
306
|
-
const logprobs = log_softmax(
|
|
304
|
+
const logprobs = log_softmax(batch_logits_data);
|
|
307
305
|
const timestamp_logprob = Math.log(logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b));
|
|
308
306
|
const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0];
|
|
309
307
|
|
|
310
308
|
if (timestamp_logprob > max_text_token_logprob) {
|
|
311
|
-
|
|
309
|
+
batch_logits_data.subarray(0, this.timestamp_begin).fill(-Infinity);
|
|
312
310
|
}
|
|
313
311
|
}
|
|
314
312
|
|
|
@@ -397,10 +395,10 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
|
|
|
397
395
|
*/
|
|
398
396
|
_call(input_ids, logits) {
|
|
399
397
|
for (let i = 0; i < input_ids.length; ++i) {
|
|
400
|
-
const
|
|
398
|
+
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
|
|
401
399
|
const bannedTokens = this.calcBannedNgramTokens(input_ids[i]);
|
|
402
400
|
for (const token of bannedTokens) {
|
|
403
|
-
|
|
401
|
+
batch_logits_data[token] = -Infinity;
|
|
404
402
|
}
|
|
405
403
|
}
|
|
406
404
|
return logits;
|
|
@@ -432,13 +430,13 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
|
|
|
432
430
|
// many times in the output will be penalised more.
|
|
433
431
|
|
|
434
432
|
for (let i = 0; i < input_ids.length; ++i) {
|
|
435
|
-
const
|
|
436
|
-
|
|
433
|
+
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
|
|
437
434
|
for (const input_id of input_ids[i]) {
|
|
438
|
-
|
|
439
|
-
|
|
435
|
+
const token = Number(input_id);
|
|
436
|
+
if (batch_logits_data[token] < 0) {
|
|
437
|
+
batch_logits_data[token] *= this.penalty;
|
|
440
438
|
} else {
|
|
441
|
-
|
|
439
|
+
batch_logits_data[token] /= this.penalty;
|
|
442
440
|
}
|
|
443
441
|
}
|
|
444
442
|
}
|
|
@@ -471,9 +469,10 @@ export class MinLengthLogitsProcessor extends LogitsProcessor {
|
|
|
471
469
|
_call(input_ids, logits) {
|
|
472
470
|
for (let i = 0; i < input_ids.length; ++i) {
|
|
473
471
|
if (input_ids[i].length < this.min_length) {
|
|
474
|
-
const
|
|
472
|
+
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
|
|
473
|
+
|
|
475
474
|
for (const eos_token of this.eos_token_id) {
|
|
476
|
-
|
|
475
|
+
batch_logits_data[eos_token] = -Infinity;
|
|
477
476
|
}
|
|
478
477
|
}
|
|
479
478
|
}
|
|
@@ -509,9 +508,10 @@ export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
|
|
|
509
508
|
for (let i = 0; i < input_ids.length; ++i) {
|
|
510
509
|
const new_tokens_length = input_ids[i].length - this.prompt_length_to_skip;
|
|
511
510
|
if (new_tokens_length < this.min_new_tokens) {
|
|
512
|
-
const
|
|
511
|
+
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
|
|
512
|
+
|
|
513
513
|
for (const eos_token of this.eos_token_id) {
|
|
514
|
-
|
|
514
|
+
batch_logits_data[eos_token] = -Infinity;
|
|
515
515
|
}
|
|
516
516
|
}
|
|
517
517
|
}
|
|
@@ -539,23 +539,26 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor {
|
|
|
539
539
|
*/
|
|
540
540
|
_call(input_ids, logits) {
|
|
541
541
|
for (let i = 0; i < input_ids.length; ++i) {
|
|
542
|
-
const
|
|
542
|
+
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
|
|
543
|
+
const ids = input_ids[i];
|
|
543
544
|
for (const bad_word_ids of this.bad_words_ids) {
|
|
544
545
|
// Whether to modify the logits of the last token in the bad word id sequence
|
|
545
546
|
let mark = true;
|
|
546
547
|
|
|
547
548
|
// For each bad word in the list, if the current sequence of input ids ends with this sequence (excluding the last),
|
|
548
549
|
// then we set the logits of the last bad word id to -Infinity.
|
|
549
|
-
for (let
|
|
550
|
+
for (let j = 1; j <= bad_word_ids.length - 1 && bad_word_ids.length < ids.length; ++j) {
|
|
550
551
|
|
|
551
|
-
|
|
552
|
+
// NOTE: We use != instead of !== to compare bigint and number
|
|
553
|
+
// @ts-ignore
|
|
554
|
+
if (bad_word_ids.at(-j - 1) != ids.at(-j)) {
|
|
552
555
|
// We have found a mismatch
|
|
553
556
|
mark = false;
|
|
554
557
|
break;
|
|
555
558
|
}
|
|
556
559
|
}
|
|
557
560
|
if (mark) {
|
|
558
|
-
|
|
561
|
+
batch_logits_data[bad_word_ids.at(-1)] = -Infinity;
|
|
559
562
|
}
|
|
560
563
|
}
|
|
561
564
|
}
|
|
@@ -650,9 +653,9 @@ export class TemperatureLogitsWarper extends LogitsWarper {
|
|
|
650
653
|
* @returns {Object} The processed logits.
|
|
651
654
|
*/
|
|
652
655
|
_call(input_ids, logits) {
|
|
653
|
-
const
|
|
654
|
-
for (let i = 0; i <
|
|
655
|
-
|
|
656
|
+
const batch_logits_data = /** @type {Float32Array} */(logits.data);
|
|
657
|
+
for (let i = 0; i < batch_logits_data.length; ++i) {
|
|
658
|
+
batch_logits_data[i] /= this.temperature;
|
|
656
659
|
}
|
|
657
660
|
return logits;
|
|
658
661
|
}
|