@huggingface/transformers 3.0.0-alpha.0 → 3.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +6 -5
- package/dist/ort-wasm-simd-threaded.jsep.wasm +0 -0
- package/dist/transformers.cjs +317 -235
- package/dist/transformers.cjs.map +1 -1
- package/dist/transformers.js +1198 -1035
- package/dist/transformers.js.map +1 -1
- package/dist/transformers.min.cjs +34 -40
- package/dist/transformers.min.cjs.map +1 -1
- package/dist/transformers.min.js +32 -32
- package/dist/transformers.min.js.map +1 -1
- package/dist/transformers.min.mjs +168 -0
- package/dist/transformers.min.mjs.map +1 -0
- package/dist/transformers.mjs +31358 -0
- package/dist/transformers.mjs.map +1 -0
- package/package.json +16 -7
- package/src/backends/onnx.js +86 -35
- package/src/env.js +6 -6
- package/src/generation/logits_process.js +39 -36
- package/src/generation/streamers.js +3 -3
- package/src/models.js +23 -10
- package/src/processors.js +79 -67
- package/src/utils/devices.js +15 -4
- package/src/utils/dtypes.js +1 -3
- package/src/utils/hub.js +17 -16
- package/types/backends/onnx.d.ts +6 -5
- package/types/backends/onnx.d.ts.map +1 -1
- package/types/env.d.ts +6 -2
- package/types/env.d.ts.map +1 -1
- package/types/generation/logits_process.d.ts.map +1 -1
- package/types/models.d.ts +8 -0
- package/types/models.d.ts.map +1 -1
- package/types/processors.d.ts +15 -1
- package/types/processors.d.ts.map +1 -1
- package/types/utils/devices.d.ts +11 -1
- package/types/utils/devices.d.ts.map +1 -1
- package/types/utils/dtypes.d.ts +0 -3
- package/types/utils/dtypes.d.ts.map +1 -1
- package/types/utils/hub.d.ts +1 -40
- package/types/utils/hub.d.ts.map +1 -1
- package/types/utils/tensor.d.ts +1 -1
package/package.json
CHANGED
|
@@ -1,16 +1,25 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@huggingface/transformers",
|
|
3
|
-
"version": "3.0.0-alpha.0",
|
|
3
|
+
"version": "3.0.0-alpha.10",
|
|
4
4
|
"description": "State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!",
|
|
5
5
|
"main": "./src/transformers.js",
|
|
6
6
|
"types": "./types/transformers.d.ts",
|
|
7
7
|
"type": "module",
|
|
8
8
|
"exports": {
|
|
9
9
|
"node": {
|
|
10
|
-
"import":
|
|
11
|
-
|
|
10
|
+
"import": {
|
|
11
|
+
"types": "./types/transformers.d.ts",
|
|
12
|
+
"default": "./dist/transformers.min.mjs"
|
|
13
|
+
},
|
|
14
|
+
"require": {
|
|
15
|
+
"types": "./types/transformers.d.ts",
|
|
16
|
+
"default": "./dist/transformers.min.cjs"
|
|
17
|
+
}
|
|
12
18
|
},
|
|
13
|
-
"default":
|
|
19
|
+
"default": {
|
|
20
|
+
"types": "./types/transformers.d.ts",
|
|
21
|
+
"default": "./src/transformers.js"
|
|
22
|
+
}
|
|
14
23
|
},
|
|
15
24
|
"scripts": {
|
|
16
25
|
"format": "prettier --write .",
|
|
@@ -48,9 +57,9 @@
|
|
|
48
57
|
"homepage": "https://github.com/xenova/transformers.js#readme",
|
|
49
58
|
"dependencies": {
|
|
50
59
|
"@huggingface/jinja": "^0.3.0",
|
|
51
|
-
"onnxruntime-node": "1.
|
|
52
|
-
"onnxruntime-web": "1.
|
|
53
|
-
"sharp": "^0.33.
|
|
60
|
+
"onnxruntime-node": "1.19.0",
|
|
61
|
+
"onnxruntime-web": "1.20.0-dev.20240827-1d059b8702",
|
|
62
|
+
"sharp": "^0.33.5"
|
|
54
63
|
},
|
|
55
64
|
"devDependencies": {
|
|
56
65
|
"@types/jest": "^29.5.1",
|
package/src/backends/onnx.js
CHANGED
|
@@ -25,23 +25,74 @@ import * as ONNX_WEB from 'onnxruntime-web/webgpu';
|
|
|
25
25
|
|
|
26
26
|
export { Tensor } from 'onnxruntime-common';
|
|
27
27
|
|
|
28
|
-
/**
|
|
29
|
-
|
|
28
|
+
/**
|
|
29
|
+
* @typedef {import('onnxruntime-common').InferenceSession.ExecutionProviderConfig} ONNXExecutionProviders
|
|
30
|
+
*/
|
|
31
|
+
|
|
32
|
+
/** @type {Record<import("../utils/devices.js").DeviceType, ONNXExecutionProviders>} */
|
|
33
|
+
const DEVICE_TO_EXECUTION_PROVIDER_MAPPING = Object.freeze({
|
|
34
|
+
auto: null, // Auto-detect based on device and environment
|
|
35
|
+
gpu: null, // Auto-detect GPU
|
|
36
|
+
cpu: 'cpu', // CPU
|
|
37
|
+
wasm: 'wasm', // WebAssembly
|
|
38
|
+
webgpu: 'webgpu', // WebGPU
|
|
39
|
+
cuda: 'cuda', // CUDA
|
|
40
|
+
dml: 'dml', // DirectML
|
|
41
|
+
|
|
42
|
+
webnn: { name: 'webnn', deviceType: 'cpu' }, // WebNN (default)
|
|
43
|
+
'webnn-npu': { name: 'webnn', deviceType: 'npu' }, // WebNN NPU
|
|
44
|
+
'webnn-gpu': { name: 'webnn', deviceType: 'gpu' }, // WebNN GPU
|
|
45
|
+
'webnn-cpu': { name: 'webnn', deviceType: 'cpu' }, // WebNN CPU
|
|
46
|
+
});
|
|
47
|
+
|
|
48
|
+
/**
|
|
49
|
+
* The list of supported devices, sorted by priority/performance.
|
|
50
|
+
* @type {import("../utils/devices.js").DeviceType[]}
|
|
51
|
+
*/
|
|
52
|
+
const supportedDevices = [];
|
|
30
53
|
|
|
31
|
-
/** @type {
|
|
32
|
-
let
|
|
54
|
+
/** @type {ONNXExecutionProviders[]} */
|
|
55
|
+
let defaultDevices;
|
|
33
56
|
let ONNX;
|
|
34
57
|
if (apis.IS_NODE_ENV) {
|
|
35
58
|
ONNX = ONNX_NODE.default ?? ONNX_NODE;
|
|
36
|
-
|
|
37
|
-
|
|
59
|
+
|
|
60
|
+
// Updated as of ONNX Runtime 1.18.0
|
|
61
|
+
// The following table lists the supported versions of ONNX Runtime Node.js binding provided with pre-built binaries.
|
|
62
|
+
// | EPs/Platforms | Windows x64 | Windows arm64 | Linux x64 | Linux arm64 | MacOS x64 | MacOS arm64 |
|
|
63
|
+
// | ------------- | ----------- | ------------- | ----------------- | ----------- | --------- | ----------- |
|
|
64
|
+
// | CPU | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
|
|
65
|
+
// | DirectML | ✔️ | ✔️ | ❌ | ❌ | ❌ | ❌ |
|
|
66
|
+
// | CUDA | ❌ | ❌ | ✔️ (CUDA v11.8) | ❌ | ❌ | ❌ |
|
|
67
|
+
switch (process.platform) {
|
|
68
|
+
case 'win32': // Windows x64 and Windows arm64
|
|
69
|
+
supportedDevices.push('dml');
|
|
70
|
+
break;
|
|
71
|
+
case 'linux': // Linux x64 and Linux arm64
|
|
72
|
+
if (process.arch === 'x64') {
|
|
73
|
+
supportedDevices.push('cuda');
|
|
74
|
+
}
|
|
75
|
+
break;
|
|
76
|
+
case 'darwin': // MacOS x64 and MacOS arm64
|
|
77
|
+
break;
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
supportedDevices.push('cpu');
|
|
81
|
+
defaultDevices = ['cpu'];
|
|
38
82
|
} else {
|
|
39
83
|
ONNX = ONNX_WEB;
|
|
84
|
+
|
|
85
|
+
if (apis.IS_WEBNN_AVAILABLE) {
|
|
86
|
+
// TODO: Only push supported providers (depending on available hardware)
|
|
87
|
+
supportedDevices.push('webnn-npu', 'webnn-gpu', 'webnn-cpu', 'webnn');
|
|
88
|
+
}
|
|
89
|
+
|
|
40
90
|
if (apis.IS_WEBGPU_AVAILABLE) {
|
|
41
|
-
|
|
91
|
+
supportedDevices.push('webgpu');
|
|
42
92
|
}
|
|
43
|
-
|
|
44
|
-
|
|
93
|
+
|
|
94
|
+
supportedDevices.push('wasm');
|
|
95
|
+
defaultDevices = ['wasm'];
|
|
45
96
|
}
|
|
46
97
|
|
|
47
98
|
// @ts-ignore
|
|
@@ -49,19 +100,28 @@ const InferenceSession = ONNX.InferenceSession;
|
|
|
49
100
|
|
|
50
101
|
/**
|
|
51
102
|
* Map a device to the execution providers to use for the given device.
|
|
52
|
-
* @param {import("../utils/devices.js").DeviceType} [device=null] (Optional) The device to run the inference on.
|
|
53
|
-
* @returns {
|
|
103
|
+
* @param {import("../utils/devices.js").DeviceType|"auto"|null} [device=null] (Optional) The device to run the inference on.
|
|
104
|
+
* @returns {ONNXExecutionProviders[]} The execution providers to use for the given device.
|
|
54
105
|
*/
|
|
55
|
-
export function deviceToExecutionProviders(device) {
|
|
56
|
-
//
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
106
|
+
export function deviceToExecutionProviders(device = null) {
|
|
107
|
+
// Use the default execution providers if the user hasn't specified anything
|
|
108
|
+
if (!device) return defaultDevices;
|
|
109
|
+
|
|
110
|
+
// Handle overloaded cases
|
|
111
|
+
switch (device) {
|
|
112
|
+
case "auto":
|
|
113
|
+
return supportedDevices;
|
|
114
|
+
case "gpu":
|
|
115
|
+
return supportedDevices.filter(x =>
|
|
116
|
+
["webgpu", "cuda", "dml", "webnn-gpu"].includes(x),
|
|
117
|
+
);
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
if (supportedDevices.includes(device)) {
|
|
121
|
+
return [DEVICE_TO_EXECUTION_PROVIDER_MAPPING[device] ?? device];
|
|
63
122
|
}
|
|
64
|
-
|
|
123
|
+
|
|
124
|
+
throw new Error(`Unsupported device: "${device}". Should be one of: ${supportedDevices.join(', ')}.`)
|
|
65
125
|
}
|
|
66
126
|
|
|
67
127
|
|
|
@@ -76,7 +136,7 @@ let wasmInitPromise = null;
|
|
|
76
136
|
/**
|
|
77
137
|
* Create an ONNX inference session.
|
|
78
138
|
* @param {Uint8Array} buffer The ONNX model buffer.
|
|
79
|
-
* @param {
|
|
139
|
+
* @param {import('onnxruntime-common').InferenceSession.SessionOptions} session_options ONNX inference session options.
|
|
80
140
|
* @returns {Promise<import('onnxruntime-common').InferenceSession>} The ONNX inference session.
|
|
81
141
|
*/
|
|
82
142
|
export async function createInferenceSession(buffer, session_options) {
|
|
@@ -100,6 +160,7 @@ export function isONNXTensor(x) {
|
|
|
100
160
|
return x instanceof ONNX.Tensor;
|
|
101
161
|
}
|
|
102
162
|
|
|
163
|
+
/** @type {import('onnxruntime-common').Env} */
|
|
103
164
|
// @ts-ignore
|
|
104
165
|
const ONNX_ENV = ONNX?.env;
|
|
105
166
|
if (ONNX_ENV?.wasm) {
|
|
@@ -109,29 +170,19 @@ if (ONNX_ENV?.wasm) {
|
|
|
109
170
|
// https://onnxruntime.ai/docs/api/js/interfaces/Env.WebAssemblyFlags.html#wasmPaths
|
|
110
171
|
// We use remote wasm files by default to make it easier for newer users.
|
|
111
172
|
// In practice, users should probably self-host the necessary .wasm files.
|
|
112
|
-
|
|
173
|
+
ONNX_ENV.wasm.wasmPaths = `https://cdn.jsdelivr.net/npm/@huggingface/transformers@${env.version}/dist/`;
|
|
113
174
|
|
|
114
175
|
// TODO: Add support for loading WASM files from cached buffer when we upgrade to onnxruntime-web@1.19.0
|
|
115
176
|
// https://github.com/microsoft/onnxruntime/pull/21534
|
|
116
177
|
|
|
117
|
-
//
|
|
118
|
-
//
|
|
119
|
-
ONNX_ENV.wasm.proxy =
|
|
178
|
+
// Users may wish to proxy the WASM backend to prevent the UI from freezing,
|
|
179
|
+
// However, this is not necessary when using WebGPU, so we default to false.
|
|
180
|
+
ONNX_ENV.wasm.proxy = false;
|
|
120
181
|
|
|
121
182
|
// https://developer.mozilla.org/en-US/docs/Web/API/crossOriginIsolated
|
|
122
183
|
if (typeof crossOriginIsolated === 'undefined' || !crossOriginIsolated) {
|
|
123
184
|
ONNX_ENV.wasm.numThreads = 1;
|
|
124
185
|
}
|
|
125
|
-
|
|
126
|
-
// Running in a browser-environment
|
|
127
|
-
// TODO: Check if 1.17.1 fixes this issue.
|
|
128
|
-
// SIMD for WebAssembly does not operate correctly in some recent versions of iOS (16.4.x).
|
|
129
|
-
// As a temporary fix, we disable it for now.
|
|
130
|
-
// For more information, see: https://github.com/microsoft/onnxruntime/issues/15644
|
|
131
|
-
const isIOS = typeof navigator !== 'undefined' && /iP(hone|od|ad).+16_4.+AppleWebKit/.test(navigator.userAgent);
|
|
132
|
-
if (isIOS) {
|
|
133
|
-
ONNX_ENV.wasm.simd = false;
|
|
134
|
-
}
|
|
135
186
|
}
|
|
136
187
|
|
|
137
188
|
if (ONNX_ENV?.webgpu) {
|
package/src/env.js
CHANGED
|
@@ -26,13 +26,14 @@ import fs from 'fs';
|
|
|
26
26
|
import path from 'path';
|
|
27
27
|
import url from 'url';
|
|
28
28
|
|
|
29
|
-
const VERSION = '3.0.0-alpha.0';
|
|
29
|
+
const VERSION = '3.0.0-alpha.10';
|
|
30
30
|
|
|
31
31
|
// Check if various APIs are available (depends on environment)
|
|
32
32
|
const IS_BROWSER_ENV = typeof self !== 'undefined';
|
|
33
33
|
const IS_WEBWORKER_ENV = IS_BROWSER_ENV && self.constructor.name === 'DedicatedWorkerGlobalScope';
|
|
34
34
|
const IS_WEB_CACHE_AVAILABLE = IS_BROWSER_ENV && 'caches' in self;
|
|
35
35
|
const IS_WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator;
|
|
36
|
+
const IS_WEBNN_AVAILABLE = typeof navigator !== 'undefined' && 'ml' in navigator;
|
|
36
37
|
|
|
37
38
|
const IS_PROCESS_AVAILABLE = typeof process !== 'undefined';
|
|
38
39
|
const IS_NODE_ENV = IS_PROCESS_AVAILABLE && process?.release?.name === 'node';
|
|
@@ -55,6 +56,9 @@ export const apis = Object.freeze({
|
|
|
55
56
|
/** Whether the WebGPU API is available */
|
|
56
57
|
IS_WEBGPU_AVAILABLE,
|
|
57
58
|
|
|
59
|
+
/** Whether the WebNN API is available */
|
|
60
|
+
IS_WEBNN_AVAILABLE,
|
|
61
|
+
|
|
58
62
|
/** Whether the Node.js process API is available */
|
|
59
63
|
IS_PROCESS_AVAILABLE,
|
|
60
64
|
|
|
@@ -88,7 +92,7 @@ const localModelPath = RUNNING_LOCALLY
|
|
|
88
92
|
* Global variable given visible to users to control execution. This provides users a simple way to configure Transformers.js.
|
|
89
93
|
* @typedef {Object} TransformersEnvironment
|
|
90
94
|
* @property {string} version This version of Transformers.js.
|
|
91
|
-
* @property {
|
|
95
|
+
* @property {{onnx: Partial<import('onnxruntime-common').Env>}} backends Expose environment variables of different backends,
|
|
92
96
|
* allowing users to set these variables if they want to.
|
|
93
97
|
* @property {boolean} allowRemoteModels Whether to allow loading of remote files, defaults to `true`.
|
|
94
98
|
* If set to `false`, it will have the same effect as setting `local_files_only=true` when loading pipelines, models, tokenizers, processors, etc.
|
|
@@ -115,12 +119,8 @@ export const env = {
|
|
|
115
119
|
backends: {
|
|
116
120
|
// onnxruntime-web/onnxruntime-node
|
|
117
121
|
onnx: {},
|
|
118
|
-
|
|
119
|
-
// TensorFlow.js
|
|
120
|
-
tfjs: {},
|
|
121
122
|
},
|
|
122
123
|
|
|
123
|
-
|
|
124
124
|
/////////////////// Model settings ///////////////////
|
|
125
125
|
allowRemoteModels: true,
|
|
126
126
|
remoteHost: 'https://huggingface.co/',
|
|
@@ -156,9 +156,9 @@ export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
|
|
|
156
156
|
_call(input_ids, logits) {
|
|
157
157
|
for (let i = 0; i < input_ids.length; ++i) {
|
|
158
158
|
if (input_ids[i].length === 1) {
|
|
159
|
-
const
|
|
160
|
-
|
|
161
|
-
|
|
159
|
+
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
|
|
160
|
+
batch_logits_data.fill(-Infinity);
|
|
161
|
+
batch_logits_data[this.bos_token_id] = 0;
|
|
162
162
|
}
|
|
163
163
|
}
|
|
164
164
|
return logits;
|
|
@@ -189,11 +189,10 @@ export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
|
|
|
189
189
|
_call(input_ids, logits) {
|
|
190
190
|
for (let i = 0; i < input_ids.length; ++i) {
|
|
191
191
|
if (input_ids[i].length === this.max_length - 1) {
|
|
192
|
-
const
|
|
193
|
-
|
|
194
|
-
|
|
192
|
+
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
|
|
193
|
+
batch_logits_data.fill(-Infinity);
|
|
195
194
|
for (const eos_token of this.eos_token_id) {
|
|
196
|
-
|
|
195
|
+
batch_logits_data[eos_token] = 0;
|
|
197
196
|
}
|
|
198
197
|
}
|
|
199
198
|
}
|
|
@@ -227,9 +226,9 @@ export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor {
|
|
|
227
226
|
_call(input_ids, logits) {
|
|
228
227
|
for (let i = 0; i < input_ids.length; ++i) {
|
|
229
228
|
if (input_ids[i].length === this.begin_index) {
|
|
230
|
-
const
|
|
229
|
+
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
|
|
231
230
|
for (const token_id of this.begin_suppress_tokens) {
|
|
232
|
-
|
|
231
|
+
batch_logits_data[token_id] = -Infinity;
|
|
233
232
|
}
|
|
234
233
|
}
|
|
235
234
|
}
|
|
@@ -271,15 +270,14 @@ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
|
|
|
271
270
|
*/
|
|
272
271
|
_call(input_ids, logits) {
|
|
273
272
|
for (let i = 0; i < input_ids.length; ++i) {
|
|
274
|
-
const
|
|
275
|
-
const logitsData = /** @type {Float32Array} */(batch_logits.data);
|
|
273
|
+
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
|
|
276
274
|
|
|
277
275
|
// suppress <|notimestamps|> which is handled by without_timestamps
|
|
278
|
-
|
|
276
|
+
batch_logits_data[this.no_timestamps_token_id] = -Infinity;
|
|
279
277
|
|
|
280
278
|
if (input_ids[i].length === this.begin_index - 1) {
|
|
281
|
-
|
|
282
|
-
|
|
279
|
+
batch_logits_data.fill(-Infinity);
|
|
280
|
+
batch_logits_data[this.timestamp_begin] = 0;
|
|
283
281
|
continue;
|
|
284
282
|
}
|
|
285
283
|
|
|
@@ -290,25 +288,25 @@ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
|
|
|
290
288
|
|
|
291
289
|
if (last_was_timestamp) {
|
|
292
290
|
if (penultimate_was_timestamp) { // has to be non-timestamp
|
|
293
|
-
|
|
291
|
+
batch_logits_data.subarray(this.timestamp_begin).fill(-Infinity);
|
|
294
292
|
} else { // cannot be normal text tokens
|
|
295
|
-
|
|
293
|
+
batch_logits_data.subarray(0, this.eos_token_id).fill(-Infinity);
|
|
296
294
|
}
|
|
297
295
|
}
|
|
298
296
|
|
|
299
297
|
// apply the `max_initial_timestamp` option
|
|
300
298
|
if (input_ids[i].length === this.begin_index && this.max_initial_timestamp_index !== null) {
|
|
301
299
|
const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index;
|
|
302
|
-
|
|
300
|
+
batch_logits_data.subarray(last_allowed + 1).fill(-Infinity);
|
|
303
301
|
}
|
|
304
302
|
|
|
305
303
|
// if sum of probability over timestamps is above any other token, sample timestamp
|
|
306
|
-
const logprobs = log_softmax(
|
|
304
|
+
const logprobs = log_softmax(batch_logits_data);
|
|
307
305
|
const timestamp_logprob = Math.log(logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b));
|
|
308
306
|
const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0];
|
|
309
307
|
|
|
310
308
|
if (timestamp_logprob > max_text_token_logprob) {
|
|
311
|
-
|
|
309
|
+
batch_logits_data.subarray(0, this.timestamp_begin).fill(-Infinity);
|
|
312
310
|
}
|
|
313
311
|
}
|
|
314
312
|
|
|
@@ -397,10 +395,10 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
|
|
|
397
395
|
*/
|
|
398
396
|
_call(input_ids, logits) {
|
|
399
397
|
for (let i = 0; i < input_ids.length; ++i) {
|
|
400
|
-
const
|
|
398
|
+
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
|
|
401
399
|
const bannedTokens = this.calcBannedNgramTokens(input_ids[i]);
|
|
402
400
|
for (const token of bannedTokens) {
|
|
403
|
-
|
|
401
|
+
batch_logits_data[token] = -Infinity;
|
|
404
402
|
}
|
|
405
403
|
}
|
|
406
404
|
return logits;
|
|
@@ -432,13 +430,13 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
|
|
|
432
430
|
// many times in the output will be penalised more.
|
|
433
431
|
|
|
434
432
|
for (let i = 0; i < input_ids.length; ++i) {
|
|
435
|
-
const
|
|
436
|
-
|
|
433
|
+
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
|
|
437
434
|
for (const input_id of input_ids[i]) {
|
|
438
|
-
|
|
439
|
-
|
|
435
|
+
const token = Number(input_id);
|
|
436
|
+
if (batch_logits_data[token] < 0) {
|
|
437
|
+
batch_logits_data[token] *= this.penalty;
|
|
440
438
|
} else {
|
|
441
|
-
|
|
439
|
+
batch_logits_data[token] /= this.penalty;
|
|
442
440
|
}
|
|
443
441
|
}
|
|
444
442
|
}
|
|
@@ -471,9 +469,10 @@ export class MinLengthLogitsProcessor extends LogitsProcessor {
|
|
|
471
469
|
_call(input_ids, logits) {
|
|
472
470
|
for (let i = 0; i < input_ids.length; ++i) {
|
|
473
471
|
if (input_ids[i].length < this.min_length) {
|
|
474
|
-
const
|
|
472
|
+
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
|
|
473
|
+
|
|
475
474
|
for (const eos_token of this.eos_token_id) {
|
|
476
|
-
|
|
475
|
+
batch_logits_data[eos_token] = -Infinity;
|
|
477
476
|
}
|
|
478
477
|
}
|
|
479
478
|
}
|
|
@@ -509,9 +508,10 @@ export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
|
|
|
509
508
|
for (let i = 0; i < input_ids.length; ++i) {
|
|
510
509
|
const new_tokens_length = input_ids[i].length - this.prompt_length_to_skip;
|
|
511
510
|
if (new_tokens_length < this.min_new_tokens) {
|
|
512
|
-
const
|
|
511
|
+
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
|
|
512
|
+
|
|
513
513
|
for (const eos_token of this.eos_token_id) {
|
|
514
|
-
|
|
514
|
+
batch_logits_data[eos_token] = -Infinity;
|
|
515
515
|
}
|
|
516
516
|
}
|
|
517
517
|
}
|
|
@@ -539,7 +539,8 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor {
|
|
|
539
539
|
*/
|
|
540
540
|
_call(input_ids, logits) {
|
|
541
541
|
for (let i = 0; i < input_ids.length; ++i) {
|
|
542
|
-
const
|
|
542
|
+
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
|
|
543
|
+
|
|
543
544
|
for (const bad_word_ids of this.bad_words_ids) {
|
|
544
545
|
// Whether to modify the logits of the last token in the bad word id sequence
|
|
545
546
|
let mark = true;
|
|
@@ -548,14 +549,16 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor {
|
|
|
548
549
|
// then we set the logits of the last bad word id to -Infinity.
|
|
549
550
|
for (let i = 1; i <= bad_word_ids.length - 1 && bad_word_ids.length < input_ids[i].length; ++i) {
|
|
550
551
|
|
|
551
|
-
|
|
552
|
+
// NOTE: We use != instead of !== to compare bigint and number
|
|
553
|
+
// @ts-ignore
|
|
554
|
+
if (bad_word_ids.at(-i - 1) != input_ids[i].at(-i)) {
|
|
552
555
|
// We have found a mismatch
|
|
553
556
|
mark = false;
|
|
554
557
|
break;
|
|
555
558
|
}
|
|
556
559
|
}
|
|
557
560
|
if (mark) {
|
|
558
|
-
|
|
561
|
+
batch_logits_data[bad_word_ids.at(-1)] = -Infinity;
|
|
559
562
|
}
|
|
560
563
|
}
|
|
561
564
|
}
|
|
@@ -650,9 +653,9 @@ export class TemperatureLogitsWarper extends LogitsWarper {
|
|
|
650
653
|
* @returns {Object} The processed logits.
|
|
651
654
|
*/
|
|
652
655
|
_call(input_ids, logits) {
|
|
653
|
-
const
|
|
654
|
-
for (let i = 0; i <
|
|
655
|
-
|
|
656
|
+
const batch_logits_data = /** @type {Float32Array} */(logits.data);
|
|
657
|
+
for (let i = 0; i < batch_logits_data.length; ++i) {
|
|
658
|
+
batch_logits_data[i] /= this.temperature;
|
|
656
659
|
}
|
|
657
660
|
return logits;
|
|
658
661
|
}
|
|
@@ -65,14 +65,14 @@ export class TextStreamer extends BaseStreamer {
|
|
|
65
65
|
throw Error('TextStreamer only supports batch size of 1');
|
|
66
66
|
}
|
|
67
67
|
|
|
68
|
-
const tokens = value[0];
|
|
69
|
-
this.token_callback_function?.(tokens)
|
|
70
|
-
|
|
71
68
|
if (this.skip_prompt && this.next_tokens_are_prompt) {
|
|
72
69
|
this.next_tokens_are_prompt = false;
|
|
73
70
|
return;
|
|
74
71
|
}
|
|
75
72
|
|
|
73
|
+
const tokens = value[0];
|
|
74
|
+
this.token_callback_function?.(tokens)
|
|
75
|
+
|
|
76
76
|
// Add the new token to the cache and decodes the entire thing.
|
|
77
77
|
this.token_cache = mergeArrays(this.token_cache, tokens);
|
|
78
78
|
const text = this.tokenizer.decode(this.token_cache, this.decode_kwargs);
|
package/src/models.js
CHANGED
|
@@ -157,9 +157,10 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
|
157
157
|
}
|
|
158
158
|
|
|
159
159
|
// If the device is not specified, we use the default (supported) execution providers.
|
|
160
|
-
const
|
|
161
|
-
|
|
160
|
+
const selectedDevice = /** @type {import("./utils/devices.js").DeviceType} */(
|
|
161
|
+
device ?? (apis.IS_NODE_ENV ? 'cpu' : 'wasm')
|
|
162
162
|
);
|
|
163
|
+
const executionProviders = deviceToExecutionProviders(selectedDevice);
|
|
163
164
|
|
|
164
165
|
// If options.dtype is specified, we use it to choose the suffix for the model file.
|
|
165
166
|
// Otherwise, we use the default dtype for the device.
|
|
@@ -168,19 +169,21 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
|
168
169
|
if (dtype && dtype.hasOwnProperty(fileName)) {
|
|
169
170
|
dtype = dtype[fileName];
|
|
170
171
|
} else {
|
|
171
|
-
dtype = DEFAULT_DEVICE_DTYPE_MAPPING[
|
|
172
|
-
console.warn(`dtype not specified for "${fileName}". Using the default dtype for this device (${
|
|
172
|
+
dtype = DEFAULT_DEVICE_DTYPE_MAPPING[selectedDevice] ?? DATA_TYPES.fp32;
|
|
173
|
+
console.warn(`dtype not specified for "${fileName}". Using the default dtype (${dtype}) for this device (${selectedDevice}).`);
|
|
173
174
|
}
|
|
174
175
|
}
|
|
175
176
|
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
throw new Error(`
|
|
177
|
+
const selectedDtype = /** @type {import("./utils/dtypes.js").DataType} */(dtype);
|
|
178
|
+
|
|
179
|
+
if (!DEFAULT_DTYPE_SUFFIX_MAPPING.hasOwnProperty(selectedDtype)) {
|
|
180
|
+
throw new Error(`Invalid dtype: ${selectedDtype}. Should be one of: ${Object.keys(DATA_TYPES).join(', ')}`);
|
|
181
|
+
} else if (selectedDtype === DATA_TYPES.fp16 && selectedDevice === 'webgpu' && !(await isWebGpuFp16Supported())) {
|
|
182
|
+
throw new Error(`The device (${selectedDevice}) does not support fp16.`);
|
|
180
183
|
}
|
|
181
184
|
|
|
182
185
|
// Construct the model file name
|
|
183
|
-
const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[
|
|
186
|
+
const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype];
|
|
184
187
|
const modelFileName = `${options.subfolder ?? ''}/${fileName}${suffix}.onnx`;
|
|
185
188
|
|
|
186
189
|
const session_options = { ...options.session_options } ?? {};
|
|
@@ -227,7 +230,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
|
227
230
|
session_options.externalData = await Promise.all(externalDataPromises);
|
|
228
231
|
}
|
|
229
232
|
|
|
230
|
-
if (
|
|
233
|
+
if (selectedDevice === 'webgpu') {
|
|
231
234
|
const shapes = getKeyValueShapes(options.config, {
|
|
232
235
|
prefix: 'present',
|
|
233
236
|
});
|
|
@@ -4565,6 +4568,14 @@ export class DepthAnythingForDepthEstimation extends DepthAnythingPreTrainedMode
|
|
|
4565
4568
|
//////////////////////////////////////////////////
|
|
4566
4569
|
|
|
4567
4570
|
|
|
4571
|
+
//////////////////////////////////////////////////
|
|
4572
|
+
export class SapiensPreTrainedModel extends PreTrainedModel { }
|
|
4573
|
+
export class SapiensForSemanticSegmentation extends SapiensPreTrainedModel { }
|
|
4574
|
+
export class SapiensForDepthEstimation extends SapiensPreTrainedModel { }
|
|
4575
|
+
export class SapiensForNormalEstimation extends SapiensPreTrainedModel { }
|
|
4576
|
+
//////////////////////////////////////////////////
|
|
4577
|
+
|
|
4578
|
+
|
|
4568
4579
|
//////////////////////////////////////////////////
|
|
4569
4580
|
export class GLPNPreTrainedModel extends PreTrainedModel { }
|
|
4570
4581
|
|
|
@@ -6535,6 +6546,7 @@ const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([
|
|
|
6535
6546
|
|
|
6536
6547
|
const MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = new Map([
|
|
6537
6548
|
['segformer', ['SegformerForSemanticSegmentation', SegformerForSemanticSegmentation]],
|
|
6549
|
+
['sapiens', ['SapiensForSemanticSegmentation', SapiensForSemanticSegmentation]],
|
|
6538
6550
|
]);
|
|
6539
6551
|
|
|
6540
6552
|
const MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = new Map([
|
|
@@ -6583,6 +6595,7 @@ const MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = new Map([
|
|
|
6583
6595
|
['dpt', ['DPTForDepthEstimation', DPTForDepthEstimation]],
|
|
6584
6596
|
['depth_anything', ['DepthAnythingForDepthEstimation', DepthAnythingForDepthEstimation]],
|
|
6585
6597
|
['glpn', ['GLPNForDepthEstimation', GLPNForDepthEstimation]],
|
|
6598
|
+
['sapiens', ['SapiensForDepthEstimation', SapiensForDepthEstimation]],
|
|
6586
6599
|
])
|
|
6587
6600
|
|
|
6588
6601
|
// NOTE: This is custom to Transformers.js, and is necessary because certain models
|