@huggingface/transformers 3.0.0-alpha.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +202 -0
- package/README.md +376 -0
- package/dist/ort-wasm-simd-threaded.jsep.wasm +0 -0
- package/dist/transformers.cjs +30741 -0
- package/dist/transformers.cjs.map +1 -0
- package/dist/transformers.js +33858 -0
- package/dist/transformers.js.map +1 -0
- package/dist/transformers.min.cjs +173 -0
- package/dist/transformers.min.cjs.map +1 -0
- package/dist/transformers.min.js +231 -0
- package/dist/transformers.min.js.map +1 -0
- package/package.json +92 -0
- package/src/backends/onnx.js +151 -0
- package/src/configs.js +360 -0
- package/src/env.js +152 -0
- package/src/generation/configuration_utils.js +381 -0
- package/src/generation/logits_process.js +716 -0
- package/src/generation/logits_sampler.js +204 -0
- package/src/generation/parameters.js +35 -0
- package/src/generation/stopping_criteria.js +156 -0
- package/src/generation/streamers.js +212 -0
- package/src/models/whisper/common_whisper.js +151 -0
- package/src/models/whisper/generation_whisper.js +89 -0
- package/src/models.js +7028 -0
- package/src/ops/registry.js +92 -0
- package/src/pipelines.js +3341 -0
- package/src/processors.js +2614 -0
- package/src/tokenizers.js +4395 -0
- package/src/transformers.js +28 -0
- package/src/utils/audio.js +704 -0
- package/src/utils/constants.js +2 -0
- package/src/utils/core.js +149 -0
- package/src/utils/data-structures.js +445 -0
- package/src/utils/devices.js +11 -0
- package/src/utils/dtypes.js +62 -0
- package/src/utils/generic.js +35 -0
- package/src/utils/hub.js +671 -0
- package/src/utils/image.js +745 -0
- package/src/utils/maths.js +1050 -0
- package/src/utils/tensor.js +1378 -0
- package/types/backends/onnx.d.ts +26 -0
- package/types/backends/onnx.d.ts.map +1 -0
- package/types/configs.d.ts +59 -0
- package/types/configs.d.ts.map +1 -0
- package/types/env.d.ts +106 -0
- package/types/env.d.ts.map +1 -0
- package/types/generation/configuration_utils.d.ts +320 -0
- package/types/generation/configuration_utils.d.ts.map +1 -0
- package/types/generation/logits_process.d.ts +354 -0
- package/types/generation/logits_process.d.ts.map +1 -0
- package/types/generation/logits_sampler.d.ts +51 -0
- package/types/generation/logits_sampler.d.ts.map +1 -0
- package/types/generation/parameters.d.ts +47 -0
- package/types/generation/parameters.d.ts.map +1 -0
- package/types/generation/stopping_criteria.d.ts +81 -0
- package/types/generation/stopping_criteria.d.ts.map +1 -0
- package/types/generation/streamers.d.ts +81 -0
- package/types/generation/streamers.d.ts.map +1 -0
- package/types/models/whisper/common_whisper.d.ts +8 -0
- package/types/models/whisper/common_whisper.d.ts.map +1 -0
- package/types/models/whisper/generation_whisper.d.ts +76 -0
- package/types/models/whisper/generation_whisper.d.ts.map +1 -0
- package/types/models.d.ts +3845 -0
- package/types/models.d.ts.map +1 -0
- package/types/ops/registry.d.ts +11 -0
- package/types/ops/registry.d.ts.map +1 -0
- package/types/pipelines.d.ts +2403 -0
- package/types/pipelines.d.ts.map +1 -0
- package/types/processors.d.ts +917 -0
- package/types/processors.d.ts.map +1 -0
- package/types/tokenizers.d.ts +999 -0
- package/types/tokenizers.d.ts.map +1 -0
- package/types/transformers.d.ts +13 -0
- package/types/transformers.d.ts.map +1 -0
- package/types/utils/audio.d.ts +130 -0
- package/types/utils/audio.d.ts.map +1 -0
- package/types/utils/constants.d.ts +2 -0
- package/types/utils/constants.d.ts.map +1 -0
- package/types/utils/core.d.ts +91 -0
- package/types/utils/core.d.ts.map +1 -0
- package/types/utils/data-structures.d.ts +236 -0
- package/types/utils/data-structures.d.ts.map +1 -0
- package/types/utils/devices.d.ts +8 -0
- package/types/utils/devices.d.ts.map +1 -0
- package/types/utils/dtypes.d.ts +22 -0
- package/types/utils/dtypes.d.ts.map +1 -0
- package/types/utils/generic.d.ts +11 -0
- package/types/utils/generic.d.ts.map +1 -0
- package/types/utils/hub.d.ts +191 -0
- package/types/utils/hub.d.ts.map +1 -0
- package/types/utils/image.d.ts +119 -0
- package/types/utils/image.d.ts.map +1 -0
- package/types/utils/maths.d.ts +280 -0
- package/types/utils/maths.d.ts.map +1 -0
- package/types/utils/tensor.d.ts +392 -0
- package/types/utils/tensor.d.ts.map +1 -0
package/package.json
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
{
|
|
2
|
+
"name": "@huggingface/transformers",
|
|
3
|
+
"version": "3.0.0-alpha.0",
|
|
4
|
+
"description": "State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!",
|
|
5
|
+
"main": "./src/transformers.js",
|
|
6
|
+
"types": "./types/transformers.d.ts",
|
|
7
|
+
"type": "module",
|
|
8
|
+
"exports": {
|
|
9
|
+
"node": {
|
|
10
|
+
"import": "./dist/transformers.js",
|
|
11
|
+
"require": "./dist/transformers.cjs"
|
|
12
|
+
},
|
|
13
|
+
"default": "./src/transformers.js"
|
|
14
|
+
},
|
|
15
|
+
"scripts": {
|
|
16
|
+
"format": "prettier --write .",
|
|
17
|
+
"format:check": "prettier --check .",
|
|
18
|
+
"typegen": "tsc ./src/transformers.js --allowJs --declaration --emitDeclarationOnly --declarationMap --outDir types",
|
|
19
|
+
"dev": "webpack serve --no-client-overlay",
|
|
20
|
+
"build": "webpack && npm run typegen",
|
|
21
|
+
"generate-tests": "python -m tests.generate_tests",
|
|
22
|
+
"test": "node --experimental-vm-modules node_modules/jest/bin/jest.js --verbose --maxConcurrency 1",
|
|
23
|
+
"readme": "python ./docs/scripts/build_readme.py",
|
|
24
|
+
"docs-api": "node ./docs/scripts/generate.js",
|
|
25
|
+
"docs-preview": "doc-builder preview transformers.js ./docs/source/ --not_python_module",
|
|
26
|
+
"docs-build": "doc-builder build transformers.js ./docs/source/ --not_python_module --build_dir ./docs/build/ --repo_owner xenova"
|
|
27
|
+
},
|
|
28
|
+
"repository": {
|
|
29
|
+
"type": "git",
|
|
30
|
+
"url": "git+https://github.com/xenova/transformers.js.git"
|
|
31
|
+
},
|
|
32
|
+
"keywords": [
|
|
33
|
+
"transformers",
|
|
34
|
+
"transformers.js",
|
|
35
|
+
"huggingface",
|
|
36
|
+
"hugging face",
|
|
37
|
+
"machine learning",
|
|
38
|
+
"deep learning",
|
|
39
|
+
"artificial intelligence",
|
|
40
|
+
"AI",
|
|
41
|
+
"ML"
|
|
42
|
+
],
|
|
43
|
+
"author": "Hugging Face",
|
|
44
|
+
"license": "Apache-2.0",
|
|
45
|
+
"bugs": {
|
|
46
|
+
"url": "https://github.com/xenova/transformers.js/issues"
|
|
47
|
+
},
|
|
48
|
+
"homepage": "https://github.com/xenova/transformers.js#readme",
|
|
49
|
+
"dependencies": {
|
|
50
|
+
"@huggingface/jinja": "^0.3.0",
|
|
51
|
+
"onnxruntime-node": "1.18.0",
|
|
52
|
+
"onnxruntime-web": "1.19.0-dev.20240804-ee2fe87e2d",
|
|
53
|
+
"sharp": "^0.33.2"
|
|
54
|
+
},
|
|
55
|
+
"devDependencies": {
|
|
56
|
+
"@types/jest": "^29.5.1",
|
|
57
|
+
"@webgpu/types": "^0.1.44",
|
|
58
|
+
"catharsis": "github:xenova/catharsis",
|
|
59
|
+
"jest": "^29.5.0",
|
|
60
|
+
"jest-environment-node": "^29.5.0",
|
|
61
|
+
"jsdoc-to-markdown": "^8.0.1",
|
|
62
|
+
"prettier": "3.3.3",
|
|
63
|
+
"typescript": "^5.2.2",
|
|
64
|
+
"wavefile": "^11.0.0",
|
|
65
|
+
"webpack": "^5.80.0",
|
|
66
|
+
"webpack-cli": "^5.0.2",
|
|
67
|
+
"webpack-dev-server": "^4.13.3"
|
|
68
|
+
},
|
|
69
|
+
"overrides": {
|
|
70
|
+
"semver": "^7.6.3",
|
|
71
|
+
"protobufjs": "^7.2.6"
|
|
72
|
+
},
|
|
73
|
+
"files": [
|
|
74
|
+
"src",
|
|
75
|
+
"dist",
|
|
76
|
+
"types",
|
|
77
|
+
"README.md",
|
|
78
|
+
"LICENSE"
|
|
79
|
+
],
|
|
80
|
+
"browser": {
|
|
81
|
+
"fs": false,
|
|
82
|
+
"path": false,
|
|
83
|
+
"url": false,
|
|
84
|
+
"sharp": false,
|
|
85
|
+
"onnxruntime-node": false
|
|
86
|
+
},
|
|
87
|
+
"publishConfig": {
|
|
88
|
+
"access": "public"
|
|
89
|
+
},
|
|
90
|
+
"jsdelivr": "./dist/transformers.min.js",
|
|
91
|
+
"unpkg": "./dist/transformers.min.js"
|
|
92
|
+
}
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file Handler file for choosing the correct version of ONNX Runtime, based on the environment.
|
|
3
|
+
* Ideally, we could import the `onnxruntime-web` and `onnxruntime-node` packages only when needed,
|
|
4
|
+
* but dynamic imports don't seem to work with the current webpack version and/or configuration.
|
|
5
|
+
* This is possibly due to the experimental nature of top-level await statements.
|
|
6
|
+
* So, we just import both packages, and use the appropriate one based on the environment:
|
|
7
|
+
* - When running in node, we use `onnxruntime-node`.
|
|
8
|
+
* - When running in the browser, we use `onnxruntime-web` (`onnxruntime-node` is not bundled).
|
|
9
|
+
*
|
|
10
|
+
* This module is not directly exported, but can be accessed through the environment variables:
|
|
11
|
+
* ```javascript
|
|
12
|
+
* import { env } from '@huggingface/transformers';
|
|
13
|
+
* console.log(env.backends.onnx);
|
|
14
|
+
* ```
|
|
15
|
+
*
|
|
16
|
+
* @module backends/onnx
|
|
17
|
+
*/
|
|
18
|
+
|
|
19
|
+
import { env, apis } from '../env.js';
|
|
20
|
+
|
|
21
|
+
// NOTE: Import order matters here. We need to import `onnxruntime-node` before `onnxruntime-web`.
|
|
22
|
+
// For `onnxruntime-node`, we select the default export if it exists, otherwise we use the named export.
|
|
23
|
+
import * as ONNX_NODE from 'onnxruntime-node';
|
|
24
|
+
import * as ONNX_WEB from 'onnxruntime-web/webgpu';
|
|
25
|
+
|
|
26
|
+
export { Tensor } from 'onnxruntime-common';
|
|
27
|
+
|
|
28
|
+
/**
 * Execution providers that may be requested explicitly (via `device`) in the current environment.
 * @type {import('../utils/devices.js').DeviceType[]}
 */
const supportedExecutionProviders = [];

/**
 * Execution providers used when no `device` is requested.
 * @type {import('../utils/devices.js').DeviceType[]}
 */
let defaultExecutionProviders;

// The ONNX Runtime module actually used: `onnxruntime-web` outside Node.js, `onnxruntime-node` in Node.js.
let ONNX;

if (!apis.IS_NODE_ENV) {
    // Non-Node environment: use the web runtime. WebGPU is only offered when the platform reports it
    // as available; WASM is always offered and is the default.
    ONNX = ONNX_WEB;
    if (apis.IS_WEBGPU_AVAILABLE) {
        supportedExecutionProviders.push('webgpu');
    }
    supportedExecutionProviders.push('wasm');
    defaultExecutionProviders = ['wasm'];
} else {
    // Node.js: use the native runtime, preferring its default export when one exists.
    ONNX = ONNX_NODE.default ?? ONNX_NODE;
    supportedExecutionProviders.push('cpu');
    defaultExecutionProviders = ['cpu'];
}

// @ts-ignore
const InferenceSession = ONNX.InferenceSession;
|
|
49
|
+
|
|
50
|
+
/**
 * Map a device to the execution providers to use for the given device.
 * @param {import("../utils/devices.js").DeviceType} [device=null] (Optional) The device to run the inference on.
 *   When omitted (or falsy), the default execution providers of the current environment are used.
 * @returns {import("../utils/devices.js").DeviceType[]} A new array with the execution providers to use for the given device.
 *   A fresh array is returned on every call, so callers may modify it freely.
 * @throws {Error} If `device` is not one of the execution providers supported in the current environment.
 */
export function deviceToExecutionProviders(device) {
    // TODO: Use mapping from device to execution providers for overloaded devices (e.g., 'gpu' or 'cpu').
    if (!device) {
        // Copy the module-level defaults so that callers mutating the result cannot affect later calls.
        return [...defaultExecutionProviders];
    }
    // User has specified a device: it must be supported in this environment.
    if (!supportedExecutionProviders.includes(device)) {
        throw new Error(`Unsupported device: "${device}". Should be one of: ${supportedExecutionProviders.join(', ')}.`);
    }
    return [device];
}
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
/**
 * To prevent multiple calls to `initWasm()`, the first call to `createInferenceSession` stores its
 * session-creation Promise here. Subsequent calls wait for that Promise to settle before creating
 * their own InferenceSession. If that first creation fails, the Promise is cleared again, so a later
 * call becomes the new initializer instead of re-throwing the original error forever.
 * @type {Promise<any>|null}
 */
let wasmInitPromise = null;

/**
 * Create an ONNX inference session.
 * @param {Uint8Array} buffer The ONNX model buffer.
 * @param {Object} session_options ONNX inference session options.
 * @returns {Promise<import('onnxruntime-common').InferenceSession>} The ONNX inference session.
 */
export async function createInferenceSession(buffer, session_options) {
    if (wasmInitPromise) {
        // A previous session has already started initializing the WASM runtime, so we wait for it to
        // settle before creating this new session. If it failed, that error was already delivered to
        // its own caller and must not make this (unrelated) session creation fail too.
        await wasmInitPromise.catch(() => { });
    }

    const sessionPromise = InferenceSession.create(buffer, session_options);
    if (!wasmInitPromise) {
        wasmInitPromise = sessionPromise;
        // If this initializing session fails, forget it so the next call can retry initialization.
        sessionPromise.catch(() => {
            if (wasmInitPromise === sessionPromise) {
                wasmInitPromise = null;
            }
        });
    }
    return await sessionPromise;
}
|
|
93
|
+
|
|
94
|
+
/**
 * Determine whether a value is a tensor of the selected ONNX Runtime module.
 * @param {any} x The value to inspect.
 * @returns {boolean} `true` when `x` is an instance of `ONNX.Tensor`, `false` otherwise.
 */
export function isONNXTensor(x) {
    const { Tensor: OnnxTensor } = ONNX;
    return x instanceof OnnxTensor;
}
|
|
102
|
+
|
|
103
|
+
// @ts-ignore
const ONNX_ENV = ONNX?.env;
if (ONNX_ENV?.wasm) {
    // Initialize the WASM backend with suitable default settings.

    // (Optional) Set the path to the .wasm files. This is needed when running in a web worker.
    // https://onnxruntime.ai/docs/api/js/interfaces/Env.WebAssemblyFlags.html#wasmPaths
    // Remote .wasm files are easier for new users, but in practice users should probably
    // self-host the necessary .wasm files. Example (currently disabled):
    // ONNX_ENV.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.19.0-dev.20240804-ee2fe87e2d/dist/';

    // TODO: Add support for loading WASM files from a cached buffer (requires onnxruntime-web >= 1.19.0).
    // https://github.com/microsoft/onnxruntime/pull/21534

    // Proxy the WASM backend to prevent the UI from freezing.
    // This is only needed in a non-worker *browser* environment. In Node.js the selected runtime is
    // `onnxruntime-node` (only the 'cpu' execution provider is used there), so proxying must not be
    // enabled or reported (e.g., by `isONNXProxy()`).
    ONNX_ENV.wasm.proxy = !apis.IS_NODE_ENV && !apis.IS_WEBWORKER_ENV;

    // Multi-threaded WASM requires a cross-origin isolated context; otherwise use a single thread.
    // https://developer.mozilla.org/en-US/docs/Web/API/crossOriginIsolated
    if (typeof crossOriginIsolated === 'undefined' || !crossOriginIsolated) {
        ONNX_ENV.wasm.numThreads = 1;
    }

    // SIMD for WebAssembly does not operate correctly in some versions of iOS (16.4.x).
    // As a temporary fix, we disable it when such a user agent is detected.
    // For more information, see: https://github.com/microsoft/onnxruntime/issues/15644
    // TODO: Check whether newer onnxruntime-web versions (e.g., 1.17.1+) fix this issue.
    const isIOS = typeof navigator !== 'undefined' && /iP(hone|od|ad).+16_4.+AppleWebKit/.test(navigator.userAgent);
    if (isIOS) {
        ONNX_ENV.wasm.simd = false;
    }
}

if (ONNX_ENV?.webgpu) {
    ONNX_ENV.webgpu.powerPreference = 'high-performance';
}
|
|
140
|
+
|
|
141
|
+
/**
 * Check if ONNX's WASM backend is being proxied.
 * @returns {boolean} Whether ONNX's WASM backend is being proxied. This is always a boolean, and it is
 * `false` when the ONNX environment exposes no WASM flags (the raw flag would be `undefined` then).
 */
export function isONNXProxy() {
    // TODO: Update this when allowing non-WASM backends.
    return Boolean(ONNX_ENV?.wasm?.proxy);
}
|
|
149
|
+
|
|
150
|
+
// Expose ONNX environment variables to `env.backends.onnx`
|
|
151
|
+
// NOTE: `ONNX_ENV` was read with optional chaining (`ONNX?.env`), so it may be `undefined` here.
env.backends.onnx = ONNX_ENV;
|
package/src/configs.js
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
1
|
+
|
|
2
|
+
/**
|
|
3
|
+
* @file Helper module for using model configs. For more information, see the corresponding
|
|
4
|
+
* [Python documentation](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoConfig).
|
|
5
|
+
*
|
|
6
|
+
* **Example:** Load an `AutoConfig`.
|
|
7
|
+
*
|
|
8
|
+
* ```javascript
|
|
9
|
+
* import { AutoConfig } from '@huggingface/transformers';
|
|
10
|
+
* const config = await AutoConfig.from_pretrained('bert-base-uncased');
|
|
11
|
+
* console.log(config);
|
|
12
|
+
* // PretrainedConfig {
|
|
13
|
+
* // "model_type": "bert",
|
|
14
|
+
* // "is_encoder_decoder": false,
|
|
15
|
+
* // "architectures": [
|
|
16
|
+
* // "BertForMaskedLM"
|
|
17
|
+
* // ],
|
|
18
|
+
* // "vocab_size": 30522,
|
|
19
|
+
* // "num_attention_heads": 12,
|
|
20
|
+
* // "num_hidden_layers": 12,
|
|
21
|
+
* // "hidden_size": 768,
|
|
22
|
+
* // "max_position_embeddings": 512,
|
|
23
|
+
* // ...
|
|
24
|
+
* // }
|
|
25
|
+
* ```
|
|
26
|
+
*
|
|
27
|
+
* @module configs
|
|
28
|
+
*/
|
|
29
|
+
|
|
30
|
+
import { pick } from './utils/core.js';
|
|
31
|
+
import {
|
|
32
|
+
getModelJSON,
|
|
33
|
+
} from './utils/hub.js';
|
|
34
|
+
|
|
35
|
+
/**
|
|
36
|
+
* @typedef {import('./utils/hub.js').PretrainedOptions} PretrainedOptions
|
|
37
|
+
*/
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
/**
 * Fetch and parse the `config.json` file belonging to a model.
 * @param {string} pretrained_model_name_or_path The model identifier or path to the config directory.
 * @param {PretrainedOptions} options Additional loading options, forwarded unchanged to `getModelJSON`.
 * @returns {Promise<Object>} A promise resolving to the parsed contents of `config.json`.
 */
async function loadConfig(pretrained_model_name_or_path, options) {
    const CONFIG_FILE_NAME = 'config.json';
    const json = await getModelJSON(pretrained_model_name_or_path, CONFIG_FILE_NAME, true, options);
    return json;
}
|
|
49
|
+
|
|
50
|
+
/**
 * Normalize a model config into a common set of keys describing its attention layout.
 *
 * Depending on `config.model_type`, the result may contain (among others):
 * - decoder-only models: `num_heads`, `num_layers`, `hidden_size`, `dim_kv`, `num_attention_heads`;
 * - encoder-decoder models: `num_{encoder,decoder}_layers`, `num_{encoder,decoder}_heads`,
 *   `{encoder,decoder}_hidden_size`, `{encoder,decoder}_dim_kv`.
 * `model_type`, `multi_query` and `is_encoder_decoder` are copied from `config` when defined.
 * Composite models (e.g. `llava`, `moondream1`, `musicgen`) start from their sub-config's normalization.
 *
 * @param {PretrainedConfig} config
 * @returns {Object} The normalized configuration.
 */
function getNormalizedConfig(config) {
    // Maps each normalized key to the name of the corresponding key in `config`.
    const mapping = {};

    let init_normalized_config = {};
    switch (config.model_type) {
        // Sub-configs
        case 'llava':
        case 'paligemma':
        case 'florence2':
            init_normalized_config = getNormalizedConfig(config.text_config);
            break;
        case 'moondream1':
            init_normalized_config = getNormalizedConfig(config.phi_config);
            break;
        case 'musicgen':
            init_normalized_config = getNormalizedConfig(config.decoder);
            break;

        // Decoder-only models
        case 'gpt2':
        case 'gptj':
        case 'codegen':
        case 'gpt_bigcode':
            mapping['num_heads'] = 'n_head';
            mapping['num_layers'] = 'n_layer';
            mapping['hidden_size'] = 'n_embd';
            break;
        case 'gpt_neox':
        case 'stablelm':
        case 'opt':
        case 'phi':
        case 'phi3':
        case 'falcon':
            mapping['num_heads'] = 'num_attention_heads';
            mapping['num_layers'] = 'num_hidden_layers';
            mapping['hidden_size'] = 'hidden_size';
            break;
        case 'llama':
        case 'cohere':
        case 'mistral':
        case 'starcoder2':
        case 'qwen2':
            mapping['num_heads'] = 'num_key_value_heads';
            mapping['num_layers'] = 'num_hidden_layers';
            mapping['hidden_size'] = 'hidden_size';
            mapping['num_attention_heads'] = 'num_attention_heads';
            break;
        case 'gemma':
        case 'gemma2':
            mapping['num_heads'] = 'num_key_value_heads';
            mapping['num_layers'] = 'num_hidden_layers';
            mapping['dim_kv'] = 'head_dim';
            break;
        case 'openelm':
            mapping['num_heads'] = 'num_kv_heads';
            mapping['num_layers'] = 'num_transformer_layers';
            mapping['dim_kv'] = 'head_dim';
            break;
        case 'gpt_neo':
        case 'donut-swin':
            mapping['num_heads'] = 'num_heads';
            mapping['num_layers'] = 'num_layers';
            mapping['hidden_size'] = 'hidden_size';
            break;
        case 'bloom':
            mapping['num_heads'] = 'n_head';
            mapping['num_layers'] = 'n_layer';
            mapping['hidden_size'] = 'hidden_size';
            break;
        case 'mpt':
            mapping['num_heads'] = 'n_heads';
            mapping['num_layers'] = 'n_layers';
            mapping['hidden_size'] = 'd_model';
            break;

        // Encoder-decoder models
        case 't5':
        case 'mt5':
        case 'longt5':
            mapping['num_decoder_layers'] = 'num_decoder_layers';
            mapping['num_decoder_heads'] = 'num_heads';
            mapping['decoder_dim_kv'] = 'd_kv';
            mapping['num_encoder_layers'] = 'num_layers';
            mapping['num_encoder_heads'] = 'num_heads';
            mapping['encoder_dim_kv'] = 'd_kv';
            break;
        case 'bart':
        case 'mbart':
        case 'marian':
        case 'whisper':
        case 'm2m_100':
        case 'blenderbot':
        case 'blenderbot-small':
        case 'florence2_language':
            mapping['num_decoder_layers'] = 'decoder_layers';
            mapping['num_decoder_heads'] = 'decoder_attention_heads';
            mapping['decoder_hidden_size'] = 'd_model';
            mapping['num_encoder_layers'] = 'encoder_layers';
            mapping['num_encoder_heads'] = 'encoder_attention_heads';
            mapping['encoder_hidden_size'] = 'd_model';
            break;
        case 'speecht5':
            mapping['num_decoder_layers'] = 'decoder_layers';
            mapping['num_decoder_heads'] = 'decoder_attention_heads';
            mapping['decoder_hidden_size'] = 'hidden_size';
            mapping['num_encoder_layers'] = 'encoder_layers';
            mapping['num_encoder_heads'] = 'encoder_attention_heads';
            mapping['encoder_hidden_size'] = 'hidden_size';
            break;
        case 'trocr':
            mapping['num_encoder_layers'] = mapping['num_decoder_layers'] = 'decoder_layers';
            mapping['num_encoder_heads'] = mapping['num_decoder_heads'] = 'decoder_attention_heads';
            mapping['encoder_hidden_size'] = mapping['decoder_hidden_size'] = 'd_model';
            break;
        case 'musicgen_decoder':
            mapping['num_encoder_layers'] = mapping['num_decoder_layers'] = 'num_hidden_layers';
            mapping['num_encoder_heads'] = mapping['num_decoder_heads'] = 'num_attention_heads';
            mapping['encoder_hidden_size'] = mapping['decoder_hidden_size'] = 'hidden_size';
            break;

        case 'vision-encoder-decoder': {
            // Braces scope the lexical declarations below to this case.
            const decoderConfig = getNormalizedConfig(config.decoder);

            const add_encoder_pkv = 'num_decoder_layers' in decoderConfig;
            const result = pick(config, ['model_type', 'is_encoder_decoder']);
            if (add_encoder_pkv) {
                // Decoder is part of an encoder-decoder model
                result.num_decoder_layers = decoderConfig.num_decoder_layers;
                result.num_decoder_heads = decoderConfig.num_decoder_heads;
                result.decoder_hidden_size = decoderConfig.decoder_hidden_size;

                result.num_encoder_layers = decoderConfig.num_encoder_layers;
                result.num_encoder_heads = decoderConfig.num_encoder_heads;
                result.encoder_hidden_size = decoderConfig.encoder_hidden_size;
            } else {
                // Decoder is a decoder-only model
                result.num_layers = decoderConfig.num_layers;
                result.num_heads = decoderConfig.num_heads;
                result.hidden_size = decoderConfig.hidden_size;
            }
            return result;
        }
    }

    // NOTE: If `num_attention_heads` is not set, it is assumed to be equal to `num_heads`
    const normalized_config = {
        ...init_normalized_config,
        ...pick(config, ['model_type', 'multi_query', 'is_encoder_decoder']),
    };
    for (const key in mapping) {
        normalized_config[key] = config[mapping[key]];
    }

    if (mapping['num_heads'] === 'num_key_value_heads') {
        // Configs written before grouped-query attention was supported may omit `num_key_value_heads`.
        // As in the Python `LlamaConfig`, a missing value means plain multi-head attention.
        normalized_config['num_heads'] ??= config['num_attention_heads'];
    }
    return normalized_config;
}
|
|
209
|
+
|
|
210
|
+
/**
 * Compute the shapes of the (initially empty) past key/value inputs that a model expects.
 * The past sequence length dimension is always 0.
 *
 * @param {PretrainedConfig} config The model config; only its `normalized_config` is used.
 * @param {Object} [options] Additional options.
 * @param {string} [options.prefix='past_key_values'] Prefix used to build the input names.
 * @param {number} [options.batch_size=1] Batch size to build the shapes for.
 * @returns {Record<string, number[]>} A mapping from input name to shape. Every entry is a distinct
 *   array, so modifying one shape never affects the others.
 */
export function getKeyValueShapes(config, {
    prefix = 'past_key_values',
    batch_size = 1,
} = {}) {
    /** @type {Record<string, number[]>} */
    const decoderFeeds = {};
    const normalized_config = config.normalized_config;

    // Register fresh copies of the key/value shapes under `${name}.key` / `${name}.value`.
    const addKeyValue = (name, keyDims, valueDims = keyDims) => {
        decoderFeeds[`${name}.key`] = [...keyDims];
        decoderFeeds[`${name}.value`] = [...valueDims];
    };

    if (normalized_config.is_encoder_decoder && (
        'num_encoder_heads' in normalized_config && 'num_decoder_heads' in normalized_config
    )) {
        const encoder_dim_kv = normalized_config.encoder_dim_kv ?? (
            normalized_config.encoder_hidden_size / normalized_config.num_encoder_heads
        );
        const decoder_dim_kv = normalized_config.decoder_dim_kv ?? (
            normalized_config.decoder_hidden_size / normalized_config.num_decoder_heads
        );

        const encoder_dims = [batch_size, normalized_config.num_encoder_heads, 0, encoder_dim_kv];
        const decoder_dims = [batch_size, normalized_config.num_decoder_heads, 0, decoder_dim_kv];
        for (let i = 0; i < normalized_config.num_decoder_layers; ++i) {
            addKeyValue(`${prefix}.${i}.encoder`, encoder_dims);
            addKeyValue(`${prefix}.${i}.decoder`, decoder_dims);
        }
    } else { // Decoders
        const num_heads = normalized_config.num_heads;
        const num_layers = normalized_config.num_layers;
        // If `num_attention_heads` is not set, it is assumed to be equal to `num_heads`.
        const dim_kv = normalized_config.dim_kv ?? (
            normalized_config.hidden_size /
            (normalized_config.num_attention_heads ?? num_heads)
        );

        if (normalized_config.model_type === 'falcon') {
            // NOTE: Custom implementation for Falcon
            const dims = [batch_size * num_heads, 0, dim_kv];
            for (let i = 0; i < num_layers; ++i) {
                addKeyValue(`${prefix}.${i}`, dims);
            }
        } else if (normalized_config.multi_query) { // e.g., for `gpt_bigcode`
            // Keys and values are stored together in a single input per layer.
            const dims = [batch_size * num_heads, 0, 2 * dim_kv];
            for (let i = 0; i < num_layers; ++i) {
                decoderFeeds[`${prefix}.${i}.key_value`] = [...dims];
            }
        } else if (normalized_config.model_type === 'bloom') {
            // NOTE: Custom implementation for Bloom
            const keyDims = [batch_size * num_heads, dim_kv, 0]; // [batch_size x num_heads, dim_kv, past_sequence_length]
            const valueDims = [batch_size * num_heads, 0, dim_kv]; // [batch_size x num_heads, past_sequence_length, dim_kv]
            for (let i = 0; i < num_layers; ++i) {
                addKeyValue(`${prefix}.${i}`, keyDims, valueDims);
            }
        } else if (normalized_config.model_type === 'openelm') {
            // `num_heads` is indexed per layer here, i.e. each layer may have its own number of heads.
            for (let i = 0; i < num_layers; ++i) {
                addKeyValue(`${prefix}.${i}`, [batch_size, num_heads[i], 0, dim_kv]);
            }
        } else { // Decoder-only
            const dims = [batch_size, num_heads, 0, dim_kv];
            for (let i = 0; i < num_layers; ++i) {
                addKeyValue(`${prefix}.${i}`, dims);
            }
        }
    }

    return decoderFeeds;
}
|
|
291
|
+
/**
 * Base class for all configuration classes. For more information, see the corresponding
 * [Python documentation](https://huggingface.co/docs/transformers/main/en/main_classes/configuration#transformers.PretrainedConfig).
 */
export class PretrainedConfig {
    // Declared as a class field (presumably so generated type declarations include it). Its value is
    // overwritten by `Object.assign` in the constructor when present in the config JSON.
    max_position_embeddings;

    /**
     * Create a new PretrainedConfig instance.
     * All keys of `configJSON` are copied onto the instance (overriding the defaults `model_type = null`
     * and `is_encoder_decoder = false`), after which `normalized_config` is computed from the result.
     * @param {Object} configJSON The JSON of the config.
     */
    constructor(configJSON) {
        this.model_type = null;
        this.is_encoder_decoder = false;

        Object.assign(this, configJSON);
        this.normalized_config = getNormalizedConfig(this);
    }

    /**
     * Loads a pre-trained config from the given `pretrained_model_name_or_path`.
     *
     * @param {string} pretrained_model_name_or_path The path to the pre-trained config.
     * @param {PretrainedOptions} options Additional options for loading the config. If `options.config`
     *   is provided, it is used directly (converted to a `PretrainedConfig` if necessary) and no file is loaded.
     * @throws {Error} Throws an error if the config.json is not found in the `pretrained_model_name_or_path`.
     *
     * @returns {Promise<PretrainedConfig>} A new instance of the `PretrainedConfig` class.
     */
    static async from_pretrained(pretrained_model_name_or_path, {
        progress_callback = null,
        config = null,
        cache_dir = null,
        local_files_only = false,
        revision = 'main',
    } = {}) {
        if (config && !(config instanceof PretrainedConfig)) {
            config = new PretrainedConfig(config);
        }

        const data = config ?? await loadConfig(pretrained_model_name_or_path, {
            progress_callback,
            config,
            cache_dir,
            local_files_only,
            revision,
        });
        return new this(data);
    }
}
|
|
342
|
+
|
|
343
|
+
/**
 * Entry point for loading pretrained configs via its static `from_pretrained` method.
 * All arguments are forwarded unchanged to `PretrainedConfig.from_pretrained`.
 *
 * @example
 * const config = await AutoConfig.from_pretrained('Xenova/bert-base-uncased');
 */
export class AutoConfig {
    /** @type {typeof PretrainedConfig.from_pretrained} */
    static async from_pretrained(...args) {
        const loaded = await PretrainedConfig.from_pretrained(...args);
        return loaded;
    }
}
|
|
355
|
+
|
|
356
|
+
/**
|
|
357
|
+
* Transformers.js-specific configuration, possibly present in config.json under the key `transformers.js_config`.
|
|
358
|
+
* @typedef {Object} TransformersJSConfig
|
|
359
|
+
* @property {import('./transformers.js').DataType} [kv_cache_dtype]
|
|
360
|
+
*/
|