@huggingface/transformers 3.0.0-alpha.2 → 3.0.0-alpha.20

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. package/README.md +19 -9
  2. package/dist/ort-wasm-simd-threaded.jsep.wasm +0 -0
  3. package/dist/transformers.cjs +2402 -2039
  4. package/dist/transformers.cjs.map +1 -1
  5. package/dist/transformers.js +3423 -2999
  6. package/dist/transformers.js.map +1 -1
  7. package/dist/transformers.min.cjs +37 -43
  8. package/dist/transformers.min.cjs.map +1 -1
  9. package/dist/transformers.min.js +39 -40
  10. package/dist/transformers.min.js.map +1 -1
  11. package/dist/transformers.min.mjs +63 -70
  12. package/dist/transformers.min.mjs.map +1 -1
  13. package/dist/transformers.mjs +2452 -2063
  14. package/dist/transformers.mjs.map +1 -1
  15. package/package.json +23 -13
  16. package/src/backends/onnx.js +98 -36
  17. package/src/configs.js +18 -4
  18. package/src/env.js +9 -9
  19. package/src/generation/logits_process.js +40 -37
  20. package/src/generation/streamers.js +3 -3
  21. package/src/models.js +238 -74
  22. package/src/ops/registry.js +14 -3
  23. package/src/pipelines.js +5 -4
  24. package/src/processors.js +390 -351
  25. package/src/tokenizers.js +139 -174
  26. package/src/utils/core.js +12 -0
  27. package/src/utils/data-structures.js +13 -11
  28. package/src/utils/devices.js +15 -4
  29. package/src/utils/dtypes.js +1 -3
  30. package/src/utils/hub.js +18 -17
  31. package/src/utils/maths.js +14 -5
  32. package/src/utils/tensor.js +23 -0
  33. package/types/backends/onnx.d.ts +6 -5
  34. package/types/backends/onnx.d.ts.map +1 -1
  35. package/types/configs.d.ts +29 -3
  36. package/types/configs.d.ts.map +1 -1
  37. package/types/env.d.ts +6 -2
  38. package/types/env.d.ts.map +1 -1
  39. package/types/generation/logits_process.d.ts.map +1 -1
  40. package/types/models.d.ts +108 -2
  41. package/types/models.d.ts.map +1 -1
  42. package/types/ops/registry.d.ts +6 -6
  43. package/types/ops/registry.d.ts.map +1 -1
  44. package/types/pipelines.d.ts.map +1 -1
  45. package/types/processors.d.ts +55 -51
  46. package/types/processors.d.ts.map +1 -1
  47. package/types/tokenizers.d.ts +23 -32
  48. package/types/tokenizers.d.ts.map +1 -1
  49. package/types/utils/core.d.ts +7 -0
  50. package/types/utils/core.d.ts.map +1 -1
  51. package/types/utils/data-structures.d.ts +6 -6
  52. package/types/utils/data-structures.d.ts.map +1 -1
  53. package/types/utils/devices.d.ts +11 -1
  54. package/types/utils/devices.d.ts.map +1 -1
  55. package/types/utils/dtypes.d.ts +0 -3
  56. package/types/utils/dtypes.d.ts.map +1 -1
  57. package/types/utils/hub.d.ts +2 -41
  58. package/types/utils/hub.d.ts.map +1 -1
  59. package/types/utils/maths.d.ts +2 -2
  60. package/types/utils/maths.d.ts.map +1 -1
  61. package/types/utils/tensor.d.ts +13 -1
  62. package/types/utils/tensor.d.ts.map +1 -1
package/package.json CHANGED
@@ -1,16 +1,31 @@
1
1
  {
2
2
  "name": "@huggingface/transformers",
3
- "version": "3.0.0-alpha.2",
3
+ "version": "3.0.0-alpha.20",
4
4
  "description": "State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!",
5
5
  "main": "./src/transformers.js",
6
6
  "types": "./types/transformers.d.ts",
7
7
  "type": "module",
8
8
  "exports": {
9
9
  "node": {
10
- "import": "./dist/transformers.mjs",
11
- "require": "./dist/transformers.cjs"
10
+ "import": {
11
+ "types": "./types/transformers.d.ts",
12
+ "default": "./dist/transformers.mjs"
13
+ },
14
+ "require": {
15
+ "types": "./types/transformers.d.ts",
16
+ "default": "./dist/transformers.cjs"
17
+ }
12
18
  },
13
- "default": "./dist/transformers.js"
19
+ "default": {
20
+ "types": "./types/transformers.d.ts",
21
+ "default": "./dist/transformers.js"
22
+ }
23
+ },
24
+ "imports": {
25
+ "#onnxruntime-webgpu": {
26
+ "node": "onnxruntime-web",
27
+ "default": "onnxruntime-web/webgpu"
28
+ }
14
29
  },
15
30
  "scripts": {
16
31
  "format": "prettier --write .",
@@ -18,8 +33,7 @@
18
33
  "typegen": "tsc ./src/transformers.js --allowJs --declaration --emitDeclarationOnly --declarationMap --outDir types",
19
34
  "dev": "webpack serve --no-client-overlay",
20
35
  "build": "webpack && npm run typegen",
21
- "generate-tests": "python -m tests.generate_tests",
22
- "test": "node --experimental-vm-modules node_modules/jest/bin/jest.js --verbose --maxConcurrency 1",
36
+ "test": "node --experimental-vm-modules node_modules/jest/bin/jest.js --verbose",
23
37
  "readme": "python ./docs/scripts/build_readme.py",
24
38
  "docs-api": "node ./docs/scripts/generate.js",
25
39
  "docs-preview": "doc-builder preview transformers.js ./docs/source/ --not_python_module",
@@ -48,9 +62,9 @@
48
62
  "homepage": "https://github.com/xenova/transformers.js#readme",
49
63
  "dependencies": {
50
64
  "@huggingface/jinja": "^0.3.0",
51
- "onnxruntime-node": "1.18.0",
52
- "onnxruntime-web": "1.19.0-dev.20240804-ee2fe87e2d",
53
- "sharp": "^0.33.2"
65
+ "onnxruntime-node": "1.19.2",
66
+ "onnxruntime-web": "1.20.0-dev.20240928-1bda91fc57",
67
+ "sharp": "^0.33.5"
54
68
  },
55
69
  "devDependencies": {
56
70
  "@types/jest": "^29.5.1",
@@ -66,10 +80,6 @@
66
80
  "webpack-cli": "^5.0.2",
67
81
  "webpack-dev-server": "^4.13.3"
68
82
  },
69
- "overrides": {
70
- "semver": "^7.6.3",
71
- "protobufjs": "^7.2.6"
72
- },
73
83
  "files": [
74
84
  "src",
75
85
  "dist",
@@ -21,27 +21,89 @@ import { env, apis } from '../env.js';
21
21
  // NOTE: Import order matters here. We need to import `onnxruntime-node` before `onnxruntime-web`.
22
22
  // In either case, we select the default export if it exists, otherwise we use the named export.
23
23
  import * as ONNX_NODE from 'onnxruntime-node';
24
- import * as ONNX_WEB from 'onnxruntime-web/webgpu';
24
+
25
+ // Use subpath-imports to ensure Node.js and browser interoperability.
26
+ // See package.json and https://nodejs.org/api/packages.html#subpath-imports
27
+ // for more information.
28
+ // @ts-ignore
29
+ import * as ONNX_WEB from '#onnxruntime-webgpu';
25
30
 
26
31
  export { Tensor } from 'onnxruntime-common';
27
32
 
28
- /** @type {import('../utils/devices.js').DeviceType[]} */
29
- const supportedExecutionProviders = [];
33
+ /**
34
+ * @typedef {import('onnxruntime-common').InferenceSession.ExecutionProviderConfig} ONNXExecutionProviders
35
+ */
36
+
37
+ /** @type {Record<import("../utils/devices.js").DeviceType, ONNXExecutionProviders>} */
38
+ const DEVICE_TO_EXECUTION_PROVIDER_MAPPING = Object.freeze({
39
+ auto: null, // Auto-detect based on device and environment
40
+ gpu: null, // Auto-detect GPU
41
+ cpu: 'cpu', // CPU
42
+ wasm: 'wasm', // WebAssembly
43
+ webgpu: 'webgpu', // WebGPU
44
+ cuda: 'cuda', // CUDA
45
+ dml: 'dml', // DirectML
46
+
47
+ webnn: { name: 'webnn', deviceType: 'cpu' }, // WebNN (default)
48
+ 'webnn-npu': { name: 'webnn', deviceType: 'npu' }, // WebNN NPU
49
+ 'webnn-gpu': { name: 'webnn', deviceType: 'gpu' }, // WebNN GPU
50
+ 'webnn-cpu': { name: 'webnn', deviceType: 'cpu' }, // WebNN CPU
51
+ });
52
+
53
+ /**
54
+ * The list of supported devices, sorted by priority/performance.
55
+ * @type {import("../utils/devices.js").DeviceType[]}
56
+ */
57
+ const supportedDevices = [];
30
58
 
31
- /** @type {import('../utils/devices.js').DeviceType[]} */
32
- let defaultExecutionProviders;
59
+ /** @type {ONNXExecutionProviders[]} */
60
+ let defaultDevices;
33
61
  let ONNX;
34
- if (apis.IS_NODE_ENV) {
62
+ const ORT_SYMBOL = Symbol.for('onnxruntime');
63
+
64
+ if (ORT_SYMBOL in globalThis) {
65
+ // If the JS runtime exposes their own ONNX runtime, use it
66
+ ONNX = globalThis[ORT_SYMBOL];
67
+
68
+ } else if (apis.IS_NODE_ENV) {
35
69
  ONNX = ONNX_NODE.default ?? ONNX_NODE;
36
- supportedExecutionProviders.push('cpu');
37
- defaultExecutionProviders = ['cpu'];
70
+
71
+ // Updated as of ONNX Runtime 1.18.0
72
+ // The following table lists the supported versions of ONNX Runtime Node.js binding provided with pre-built binaries.
73
+ // | EPs/Platforms | Windows x64 | Windows arm64 | Linux x64 | Linux arm64 | MacOS x64 | MacOS arm64 |
74
+ // | ------------- | ----------- | ------------- | ----------------- | ----------- | --------- | ----------- |
75
+ // | CPU | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
76
+ // | DirectML | ✔️ | ✔️ | ❌ | ❌ | ❌ | ❌ |
77
+ // | CUDA | ❌ | ❌ | ✔️ (CUDA v11.8) | ❌ | ❌ | ❌ |
78
+ switch (process.platform) {
79
+ case 'win32': // Windows x64 and Windows arm64
80
+ supportedDevices.push('dml');
81
+ break;
82
+ case 'linux': // Linux x64 and Linux arm64
83
+ if (process.arch === 'x64') {
84
+ supportedDevices.push('cuda');
85
+ }
86
+ break;
87
+ case 'darwin': // MacOS x64 and MacOS arm64
88
+ break;
89
+ }
90
+
91
+ supportedDevices.push('cpu');
92
+ defaultDevices = ['cpu'];
38
93
  } else {
39
94
  ONNX = ONNX_WEB;
95
+
96
+ if (apis.IS_WEBNN_AVAILABLE) {
97
+ // TODO: Only push supported providers (depending on available hardware)
98
+ supportedDevices.push('webnn-npu', 'webnn-gpu', 'webnn-cpu', 'webnn');
99
+ }
100
+
40
101
  if (apis.IS_WEBGPU_AVAILABLE) {
41
- supportedExecutionProviders.push('webgpu');
102
+ supportedDevices.push('webgpu');
42
103
  }
43
- supportedExecutionProviders.push('wasm');
44
- defaultExecutionProviders = ['wasm'];
104
+
105
+ supportedDevices.push('wasm');
106
+ defaultDevices = ['wasm'];
45
107
  }
46
108
 
47
109
  // @ts-ignore
@@ -49,19 +111,28 @@ const InferenceSession = ONNX.InferenceSession;
49
111
 
50
112
  /**
51
113
  * Map a device to the execution providers to use for the given device.
52
- * @param {import("../utils/devices.js").DeviceType} [device=null] (Optional) The device to run the inference on.
53
- * @returns {import("../utils/devices.js").DeviceType[]} The execution providers to use for the given device.
114
+ * @param {import("../utils/devices.js").DeviceType|"auto"|null} [device=null] (Optional) The device to run the inference on.
115
+ * @returns {ONNXExecutionProviders[]} The execution providers to use for the given device.
54
116
  */
55
- export function deviceToExecutionProviders(device) {
56
- // TODO: Use mapping from device to execution providers for overloaded devices (e.g., 'gpu' or 'cpu').
57
- let executionProviders = defaultExecutionProviders;
58
- if (device) { // User has specified a device
59
- if (!supportedExecutionProviders.includes(device)) {
60
- throw new Error(`Unsupported device: "${device}". Should be one of: ${supportedExecutionProviders.join(', ')}.`)
61
- }
62
- executionProviders = [device];
117
+ export function deviceToExecutionProviders(device = null) {
118
+ // Use the default execution providers if the user hasn't specified anything
119
+ if (!device) return defaultDevices;
120
+
121
+ // Handle overloaded cases
122
+ switch (device) {
123
+ case "auto":
124
+ return supportedDevices;
125
+ case "gpu":
126
+ return supportedDevices.filter(x =>
127
+ ["webgpu", "cuda", "dml", "webnn-gpu"].includes(x),
128
+ );
63
129
  }
64
- return executionProviders;
130
+
131
+ if (supportedDevices.includes(device)) {
132
+ return [DEVICE_TO_EXECUTION_PROVIDER_MAPPING[device] ?? device];
133
+ }
134
+
135
+ throw new Error(`Unsupported device: "${device}". Should be one of: ${supportedDevices.join(', ')}.`)
65
136
  }
66
137
 
67
138
 
@@ -76,7 +147,7 @@ let wasmInitPromise = null;
76
147
  /**
77
148
  * Create an ONNX inference session.
78
149
  * @param {Uint8Array} buffer The ONNX model buffer.
79
- * @param {Object} session_options ONNX inference session options.
150
+ * @param {import('onnxruntime-common').InferenceSession.SessionOptions} session_options ONNX inference session options.
80
151
  * @returns {Promise<import('onnxruntime-common').InferenceSession>} The ONNX inference session.
81
152
  */
82
153
  export async function createInferenceSession(buffer, session_options) {
@@ -100,6 +171,7 @@ export function isONNXTensor(x) {
100
171
  return x instanceof ONNX.Tensor;
101
172
  }
102
173
 
174
+ /** @type {import('onnxruntime-common').Env} */
103
175
  // @ts-ignore
104
176
  const ONNX_ENV = ONNX?.env;
105
177
  if (ONNX_ENV?.wasm) {
@@ -114,24 +186,14 @@ if (ONNX_ENV?.wasm) {
114
186
  // TODO: Add support for loading WASM files from cached buffer when we upgrade to onnxruntime-web@1.19.0
115
187
  // https://github.com/microsoft/onnxruntime/pull/21534
116
188
 
117
- // Proxy the WASM backend to prevent the UI from freezing
118
- // NOTE: This is only needed when running in a non-worker browser environment.
119
- ONNX_ENV.wasm.proxy = !apis.IS_WEBWORKER_ENV;
189
+ // Users may wish to proxy the WASM backend to prevent the UI from freezing,
190
+ // However, this is not necessary when using WebGPU, so we default to false.
191
+ ONNX_ENV.wasm.proxy = false;
120
192
 
121
193
  // https://developer.mozilla.org/en-US/docs/Web/API/crossOriginIsolated
122
194
  if (typeof crossOriginIsolated === 'undefined' || !crossOriginIsolated) {
123
195
  ONNX_ENV.wasm.numThreads = 1;
124
196
  }
125
-
126
- // Running in a browser-environment
127
- // TODO: Check if 1.17.1 fixes this issue.
128
- // SIMD for WebAssembly does not operate correctly in some recent versions of iOS (16.4.x).
129
- // As a temporary fix, we disable it for now.
130
- // For more information, see: https://github.com/microsoft/onnxruntime/issues/15644
131
- const isIOS = typeof navigator !== 'undefined' && /iP(hone|od|ad).+16_4.+AppleWebKit/.test(navigator.userAgent);
132
- if (isIOS) {
133
- ONNX_ENV.wasm.simd = false;
134
- }
135
197
  }
136
198
 
137
199
  if (ONNX_ENV?.webgpu) {
package/src/configs.js CHANGED
@@ -73,6 +73,7 @@ function getNormalizedConfig(config) {
73
73
  // Decoder-only models
74
74
  case 'gpt2':
75
75
  case 'gptj':
76
+ case 'jais':
76
77
  case 'codegen':
77
78
  case 'gpt_bigcode':
78
79
  mapping['num_heads'] = 'n_head';
@@ -295,16 +296,23 @@ export function getKeyValueShapes(config, {
295
296
  export class PretrainedConfig {
296
297
  // NOTE: Typo in original
297
298
 
299
+ /** @type {string|null} */
300
+ model_type = null;
301
+
302
+ /** @type {boolean} */
303
+ is_encoder_decoder = false;
304
+
305
+ /** @type {number} */
298
306
  max_position_embeddings;
299
307
 
308
+ /** @type {TransformersJSConfig} */
309
+ 'transformers.js_config';
310
+
300
311
  /**
301
312
  * Create a new PreTrainedTokenizer instance.
302
313
  * @param {Object} configJSON The JSON of the config.
303
314
  */
304
315
  constructor(configJSON) {
305
- this.model_type = null;
306
- this.is_encoder_decoder = false;
307
-
308
316
  Object.assign(this, configJSON);
309
317
  this.normalized_config = getNormalizedConfig(this);
310
318
  }
@@ -356,5 +364,11 @@ export class AutoConfig {
356
364
  /**
357
365
  * Transformers.js-specific configuration, possibly present in config.json under the key `transformers.js_config`.
358
366
  * @typedef {Object} TransformersJSConfig
359
- * @property {import('./transformers.js').DataType} [kv_cache_dtype]
367
+ * @property {import('./utils/tensor.js').DataType} [kv_cache_dtype] The data type of the key-value cache.
368
+ * @property {Record<string, number>} [free_dimension_overrides] Override the free dimensions of the model.
369
+ * See https://onnxruntime.ai/docs/tutorials/web/env-flags-and-session-options.html#freedimensionoverrides
370
+ * for more information.
371
+ * @property {import('./utils/devices.js').DeviceType} [device] The default device to use for the model.
372
+ * @property {import('./utils/dtypes.js').DataType} [dtype] The default data type to use for the model.
373
+ * @property {boolean|Record<string, boolean>} [use_external_data_format=false] Whether to load the model using the external data format (used for models >= 2GB in size).
360
374
  */
package/src/env.js CHANGED
@@ -26,13 +26,14 @@ import fs from 'fs';
26
26
  import path from 'path';
27
27
  import url from 'url';
28
28
 
29
- const VERSION = '3.0.0-alpha.2';
29
+ const VERSION = '3.0.0-alpha.20';
30
30
 
31
31
  // Check if various APIs are available (depends on environment)
32
32
  const IS_BROWSER_ENV = typeof self !== 'undefined';
33
33
  const IS_WEBWORKER_ENV = IS_BROWSER_ENV && self.constructor.name === 'DedicatedWorkerGlobalScope';
34
34
  const IS_WEB_CACHE_AVAILABLE = IS_BROWSER_ENV && 'caches' in self;
35
35
  const IS_WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator;
36
+ const IS_WEBNN_AVAILABLE = typeof navigator !== 'undefined' && 'ml' in navigator;
36
37
 
37
38
  const IS_PROCESS_AVAILABLE = typeof process !== 'undefined';
38
39
  const IS_NODE_ENV = IS_PROCESS_AVAILABLE && process?.release?.name === 'node';
@@ -55,6 +56,9 @@ export const apis = Object.freeze({
55
56
  /** Whether the WebGPU API is available */
56
57
  IS_WEBGPU_AVAILABLE,
57
58
 
59
+ /** Whether the WebNN API is available */
60
+ IS_WEBNN_AVAILABLE,
61
+
58
62
  /** Whether the Node.js process API is available */
59
63
  IS_PROCESS_AVAILABLE,
60
64
 
@@ -69,26 +73,26 @@ export const apis = Object.freeze({
69
73
  });
70
74
 
71
75
  const RUNNING_LOCALLY = IS_FS_AVAILABLE && IS_PATH_AVAILABLE;
72
- const __dirname = RUNNING_LOCALLY
76
+ const dirname__ = RUNNING_LOCALLY
73
77
  ? path.dirname(path.dirname(url.fileURLToPath(import.meta.url)))
74
78
  : './';
75
79
 
76
80
  // Only used for environments with access to file system
77
81
  const DEFAULT_CACHE_DIR = RUNNING_LOCALLY
78
- ? path.join(__dirname, '/.cache/')
82
+ ? path.join(dirname__, '/.cache/')
79
83
  : null;
80
84
 
81
85
  // Set local model path, based on available APIs
82
86
  const DEFAULT_LOCAL_MODEL_PATH = '/models/';
83
87
  const localModelPath = RUNNING_LOCALLY
84
- ? path.join(__dirname, DEFAULT_LOCAL_MODEL_PATH)
88
+ ? path.join(dirname__, DEFAULT_LOCAL_MODEL_PATH)
85
89
  : DEFAULT_LOCAL_MODEL_PATH;
86
90
 
87
91
  /**
88
92
  * Global variable given visible to users to control execution. This provides users a simple way to configure Transformers.js.
89
93
  * @typedef {Object} TransformersEnvironment
90
94
  * @property {string} version This version of Transformers.js.
91
- * @property {Object} backends Expose environment variables of different backends,
95
+ * @property {{onnx: Partial<import('onnxruntime-common').Env>}} backends Expose environment variables of different backends,
92
96
  * allowing users to set these variables if they want to.
93
97
  * @property {boolean} allowRemoteModels Whether to allow loading of remote files, defaults to `true`.
94
98
  * If set to `false`, it will have the same effect as setting `local_files_only=true` when loading pipelines, models, tokenizers, processors, etc.
@@ -115,12 +119,8 @@ export const env = {
115
119
  backends: {
116
120
  // onnxruntime-web/onnxruntime-node
117
121
  onnx: {},
118
-
119
- // TensorFlow.js
120
- tfjs: {},
121
122
  },
122
123
 
123
-
124
124
  /////////////////// Model settings ///////////////////
125
125
  allowRemoteModels: true,
126
126
  remoteHost: 'https://huggingface.co/',
@@ -156,9 +156,9 @@ export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
156
156
  _call(input_ids, logits) {
157
157
  for (let i = 0; i < input_ids.length; ++i) {
158
158
  if (input_ids[i].length === 1) {
159
- const batch_logits = logits[i];
160
- batch_logits.data.fill(-Infinity);
161
- batch_logits.data[this.bos_token_id] = 0;
159
+ const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
160
+ batch_logits_data.fill(-Infinity);
161
+ batch_logits_data[this.bos_token_id] = 0;
162
162
  }
163
163
  }
164
164
  return logits;
@@ -189,11 +189,10 @@ export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
189
189
  _call(input_ids, logits) {
190
190
  for (let i = 0; i < input_ids.length; ++i) {
191
191
  if (input_ids[i].length === this.max_length - 1) {
192
- const batch_logits = logits[i];
193
- batch_logits.data.fill(-Infinity);
194
-
192
+ const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
193
+ batch_logits_data.fill(-Infinity);
195
194
  for (const eos_token of this.eos_token_id) {
196
- batch_logits.data[eos_token] = 0;
195
+ batch_logits_data[eos_token] = 0;
197
196
  }
198
197
  }
199
198
  }
@@ -227,9 +226,9 @@ export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor {
227
226
  _call(input_ids, logits) {
228
227
  for (let i = 0; i < input_ids.length; ++i) {
229
228
  if (input_ids[i].length === this.begin_index) {
230
- const batch_logits = logits[i];
229
+ const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
231
230
  for (const token_id of this.begin_suppress_tokens) {
232
- batch_logits.data[token_id] = -Infinity;
231
+ batch_logits_data[token_id] = -Infinity;
233
232
  }
234
233
  }
235
234
  }
@@ -271,15 +270,14 @@ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
271
270
  */
272
271
  _call(input_ids, logits) {
273
272
  for (let i = 0; i < input_ids.length; ++i) {
274
- const batch_logits = logits[i];
275
- const logitsData = /** @type {Float32Array} */(batch_logits.data);
273
+ const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
276
274
 
277
275
  // suppress <|notimestamps|> which is handled by without_timestamps
278
- logitsData[this.no_timestamps_token_id] = -Infinity;
276
+ batch_logits_data[this.no_timestamps_token_id] = -Infinity;
279
277
 
280
278
  if (input_ids[i].length === this.begin_index - 1) {
281
- logitsData.fill(-Infinity);
282
- logitsData[this.timestamp_begin] = 0;
279
+ batch_logits_data.fill(-Infinity);
280
+ batch_logits_data[this.timestamp_begin] = 0;
283
281
  continue;
284
282
  }
285
283
 
@@ -290,25 +288,25 @@ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
290
288
 
291
289
  if (last_was_timestamp) {
292
290
  if (penultimate_was_timestamp) { // has to be non-timestamp
293
- logitsData.subarray(this.timestamp_begin).fill(-Infinity);
291
+ batch_logits_data.subarray(this.timestamp_begin).fill(-Infinity);
294
292
  } else { // cannot be normal text tokens
295
- logitsData.subarray(0, this.eos_token_id).fill(-Infinity);
293
+ batch_logits_data.subarray(0, this.eos_token_id).fill(-Infinity);
296
294
  }
297
295
  }
298
296
 
299
297
  // apply the `max_initial_timestamp` option
300
298
  if (input_ids[i].length === this.begin_index && this.max_initial_timestamp_index !== null) {
301
299
  const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index;
302
- logitsData.subarray(last_allowed + 1).fill(-Infinity);
300
+ batch_logits_data.subarray(last_allowed + 1).fill(-Infinity);
303
301
  }
304
302
 
305
303
  // if sum of probability over timestamps is above any other token, sample timestamp
306
- const logprobs = log_softmax(logitsData);
304
+ const logprobs = log_softmax(batch_logits_data);
307
305
  const timestamp_logprob = Math.log(logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b));
308
306
  const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0];
309
307
 
310
308
  if (timestamp_logprob > max_text_token_logprob) {
311
- logitsData.subarray(0, this.timestamp_begin).fill(-Infinity);
309
+ batch_logits_data.subarray(0, this.timestamp_begin).fill(-Infinity);
312
310
  }
313
311
  }
314
312
 
@@ -397,10 +395,10 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
397
395
  */
398
396
  _call(input_ids, logits) {
399
397
  for (let i = 0; i < input_ids.length; ++i) {
400
- const batch_logits = logits[i];
398
+ const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
401
399
  const bannedTokens = this.calcBannedNgramTokens(input_ids[i]);
402
400
  for (const token of bannedTokens) {
403
- batch_logits.data[token] = -Infinity;
401
+ batch_logits_data[token] = -Infinity;
404
402
  }
405
403
  }
406
404
  return logits;
@@ -432,13 +430,13 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
432
430
  // many times in the output will be penalised more.
433
431
 
434
432
  for (let i = 0; i < input_ids.length; ++i) {
435
- const batch_logits = logits[i];
436
-
433
+ const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
437
434
  for (const input_id of input_ids[i]) {
438
- if (batch_logits.data[input_id] < 0) {
439
- batch_logits.data[input_id] *= this.penalty;
435
+ const token = Number(input_id);
436
+ if (batch_logits_data[token] < 0) {
437
+ batch_logits_data[token] *= this.penalty;
440
438
  } else {
441
- batch_logits.data[input_id] /= this.penalty;
439
+ batch_logits_data[token] /= this.penalty;
442
440
  }
443
441
  }
444
442
  }
@@ -471,9 +469,10 @@ export class MinLengthLogitsProcessor extends LogitsProcessor {
471
469
  _call(input_ids, logits) {
472
470
  for (let i = 0; i < input_ids.length; ++i) {
473
471
  if (input_ids[i].length < this.min_length) {
474
- const batch_logits = logits[i];
472
+ const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
473
+
475
474
  for (const eos_token of this.eos_token_id) {
476
- batch_logits.data[eos_token] = -Infinity;
475
+ batch_logits_data[eos_token] = -Infinity;
477
476
  }
478
477
  }
479
478
  }
@@ -509,9 +508,10 @@ export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
509
508
  for (let i = 0; i < input_ids.length; ++i) {
510
509
  const new_tokens_length = input_ids[i].length - this.prompt_length_to_skip;
511
510
  if (new_tokens_length < this.min_new_tokens) {
512
- const batch_logits = logits[i];
511
+ const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
512
+
513
513
  for (const eos_token of this.eos_token_id) {
514
- batch_logits[eos_token] = -Infinity;
514
+ batch_logits_data[eos_token] = -Infinity;
515
515
  }
516
516
  }
517
517
  }
@@ -539,23 +539,26 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor {
539
539
  */
540
540
  _call(input_ids, logits) {
541
541
  for (let i = 0; i < input_ids.length; ++i) {
542
- const batch_logits = logits[i];
542
+ const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
543
+ const ids = input_ids[i];
543
544
  for (const bad_word_ids of this.bad_words_ids) {
544
545
  // Whether to modify the logits of the last token in the bad word id sequence
545
546
  let mark = true;
546
547
 
547
548
  // For each bad word in the list, if the current sequence of input ids ends with this sequence (excluding the last),
548
549
  // then we set the logits of the last bad word id to -Infinity.
549
- for (let i = 1; i <= bad_word_ids.length - 1 && bad_word_ids.length < input_ids[i].length; ++i) {
550
+ for (let j = 1; j <= bad_word_ids.length - 1 && bad_word_ids.length < ids.length; ++j) {
550
551
 
551
- if (bad_word_ids.at(-i - 1) !== Number(input_ids[i].at(-i))) {
552
+ // NOTE: We use != instead of !== to compare bigint and number
553
+ // @ts-ignore
554
+ if (bad_word_ids.at(-j - 1) != ids.at(-j)) {
552
555
  // We have found a mismatch
553
556
  mark = false;
554
557
  break;
555
558
  }
556
559
  }
557
560
  if (mark) {
558
- batch_logits[bad_word_ids.at(-1)] = -Infinity;
561
+ batch_logits_data[bad_word_ids.at(-1)] = -Infinity;
559
562
  }
560
563
  }
561
564
  }
@@ -650,9 +653,9 @@ export class TemperatureLogitsWarper extends LogitsWarper {
650
653
  * @returns {Object} The processed logits.
651
654
  */
652
655
  _call(input_ids, logits) {
653
- const logitsData = /** @type {Float32Array} */(logits.data);
654
- for (let i = 0; i < logitsData.length; ++i) {
655
- logitsData[i] /= this.temperature;
656
+ const batch_logits_data = /** @type {Float32Array} */(logits.data);
657
+ for (let i = 0; i < batch_logits_data.length; ++i) {
658
+ batch_logits_data[i] /= this.temperature;
656
659
  }
657
660
  return logits;
658
661
  }
@@ -65,14 +65,14 @@ export class TextStreamer extends BaseStreamer {
65
65
  throw Error('TextStreamer only supports batch size of 1');
66
66
  }
67
67
 
68
- const tokens = value[0];
69
- this.token_callback_function?.(tokens)
70
-
71
68
  if (this.skip_prompt && this.next_tokens_are_prompt) {
72
69
  this.next_tokens_are_prompt = false;
73
70
  return;
74
71
  }
75
72
 
73
+ const tokens = value[0];
74
+ this.token_callback_function?.(tokens)
75
+
76
76
  // Add the new token to the cache and decodes the entire thing.
77
77
  this.token_cache = mergeArrays(this.token_cache, tokens);
78
78
  const text = this.tokenizer.decode(this.token_cache, this.decode_kwargs);