@huggingface/transformers 3.0.0-alpha.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. package/LICENSE +202 -0
  2. package/README.md +376 -0
  3. package/dist/ort-wasm-simd-threaded.jsep.wasm +0 -0
  4. package/dist/transformers.cjs +30741 -0
  5. package/dist/transformers.cjs.map +1 -0
  6. package/dist/transformers.js +33858 -0
  7. package/dist/transformers.js.map +1 -0
  8. package/dist/transformers.min.cjs +173 -0
  9. package/dist/transformers.min.cjs.map +1 -0
  10. package/dist/transformers.min.js +231 -0
  11. package/dist/transformers.min.js.map +1 -0
  12. package/package.json +92 -0
  13. package/src/backends/onnx.js +151 -0
  14. package/src/configs.js +360 -0
  15. package/src/env.js +152 -0
  16. package/src/generation/configuration_utils.js +381 -0
  17. package/src/generation/logits_process.js +716 -0
  18. package/src/generation/logits_sampler.js +204 -0
  19. package/src/generation/parameters.js +35 -0
  20. package/src/generation/stopping_criteria.js +156 -0
  21. package/src/generation/streamers.js +212 -0
  22. package/src/models/whisper/common_whisper.js +151 -0
  23. package/src/models/whisper/generation_whisper.js +89 -0
  24. package/src/models.js +7028 -0
  25. package/src/ops/registry.js +92 -0
  26. package/src/pipelines.js +3341 -0
  27. package/src/processors.js +2614 -0
  28. package/src/tokenizers.js +4395 -0
  29. package/src/transformers.js +28 -0
  30. package/src/utils/audio.js +704 -0
  31. package/src/utils/constants.js +2 -0
  32. package/src/utils/core.js +149 -0
  33. package/src/utils/data-structures.js +445 -0
  34. package/src/utils/devices.js +11 -0
  35. package/src/utils/dtypes.js +62 -0
  36. package/src/utils/generic.js +35 -0
  37. package/src/utils/hub.js +671 -0
  38. package/src/utils/image.js +745 -0
  39. package/src/utils/maths.js +1050 -0
  40. package/src/utils/tensor.js +1378 -0
  41. package/types/backends/onnx.d.ts +26 -0
  42. package/types/backends/onnx.d.ts.map +1 -0
  43. package/types/configs.d.ts +59 -0
  44. package/types/configs.d.ts.map +1 -0
  45. package/types/env.d.ts +106 -0
  46. package/types/env.d.ts.map +1 -0
  47. package/types/generation/configuration_utils.d.ts +320 -0
  48. package/types/generation/configuration_utils.d.ts.map +1 -0
  49. package/types/generation/logits_process.d.ts +354 -0
  50. package/types/generation/logits_process.d.ts.map +1 -0
  51. package/types/generation/logits_sampler.d.ts +51 -0
  52. package/types/generation/logits_sampler.d.ts.map +1 -0
  53. package/types/generation/parameters.d.ts +47 -0
  54. package/types/generation/parameters.d.ts.map +1 -0
  55. package/types/generation/stopping_criteria.d.ts +81 -0
  56. package/types/generation/stopping_criteria.d.ts.map +1 -0
  57. package/types/generation/streamers.d.ts +81 -0
  58. package/types/generation/streamers.d.ts.map +1 -0
  59. package/types/models/whisper/common_whisper.d.ts +8 -0
  60. package/types/models/whisper/common_whisper.d.ts.map +1 -0
  61. package/types/models/whisper/generation_whisper.d.ts +76 -0
  62. package/types/models/whisper/generation_whisper.d.ts.map +1 -0
  63. package/types/models.d.ts +3845 -0
  64. package/types/models.d.ts.map +1 -0
  65. package/types/ops/registry.d.ts +11 -0
  66. package/types/ops/registry.d.ts.map +1 -0
  67. package/types/pipelines.d.ts +2403 -0
  68. package/types/pipelines.d.ts.map +1 -0
  69. package/types/processors.d.ts +917 -0
  70. package/types/processors.d.ts.map +1 -0
  71. package/types/tokenizers.d.ts +999 -0
  72. package/types/tokenizers.d.ts.map +1 -0
  73. package/types/transformers.d.ts +13 -0
  74. package/types/transformers.d.ts.map +1 -0
  75. package/types/utils/audio.d.ts +130 -0
  76. package/types/utils/audio.d.ts.map +1 -0
  77. package/types/utils/constants.d.ts +2 -0
  78. package/types/utils/constants.d.ts.map +1 -0
  79. package/types/utils/core.d.ts +91 -0
  80. package/types/utils/core.d.ts.map +1 -0
  81. package/types/utils/data-structures.d.ts +236 -0
  82. package/types/utils/data-structures.d.ts.map +1 -0
  83. package/types/utils/devices.d.ts +8 -0
  84. package/types/utils/devices.d.ts.map +1 -0
  85. package/types/utils/dtypes.d.ts +22 -0
  86. package/types/utils/dtypes.d.ts.map +1 -0
  87. package/types/utils/generic.d.ts +11 -0
  88. package/types/utils/generic.d.ts.map +1 -0
  89. package/types/utils/hub.d.ts +191 -0
  90. package/types/utils/hub.d.ts.map +1 -0
  91. package/types/utils/image.d.ts +119 -0
  92. package/types/utils/image.d.ts.map +1 -0
  93. package/types/utils/maths.d.ts +280 -0
  94. package/types/utils/maths.d.ts.map +1 -0
  95. package/types/utils/tensor.d.ts +392 -0
  96. package/types/utils/tensor.d.ts.map +1 -0
package/package.json ADDED
@@ -0,0 +1,92 @@
1
+ {
2
+ "name": "@huggingface/transformers",
3
+ "version": "3.0.0-alpha.0",
4
+ "description": "State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!",
5
+ "main": "./src/transformers.js",
6
+ "types": "./types/transformers.d.ts",
7
+ "type": "module",
8
+ "exports": {
9
+ "node": {
10
+ "import": "./dist/transformers.js",
11
+ "require": "./dist/transformers.cjs"
12
+ },
13
+ "default": "./src/transformers.js"
14
+ },
15
+ "scripts": {
16
+ "format": "prettier --write .",
17
+ "format:check": "prettier --check .",
18
+ "typegen": "tsc ./src/transformers.js --allowJs --declaration --emitDeclarationOnly --declarationMap --outDir types",
19
+ "dev": "webpack serve --no-client-overlay",
20
+ "build": "webpack && npm run typegen",
21
+ "generate-tests": "python -m tests.generate_tests",
22
+ "test": "node --experimental-vm-modules node_modules/jest/bin/jest.js --verbose --maxConcurrency 1",
23
+ "readme": "python ./docs/scripts/build_readme.py",
24
+ "docs-api": "node ./docs/scripts/generate.js",
25
+ "docs-preview": "doc-builder preview transformers.js ./docs/source/ --not_python_module",
26
+ "docs-build": "doc-builder build transformers.js ./docs/source/ --not_python_module --build_dir ./docs/build/ --repo_owner xenova"
27
+ },
28
+ "repository": {
29
+ "type": "git",
30
+ "url": "git+https://github.com/xenova/transformers.js.git"
31
+ },
32
+ "keywords": [
33
+ "transformers",
34
+ "transformers.js",
35
+ "huggingface",
36
+ "hugging face",
37
+ "machine learning",
38
+ "deep learning",
39
+ "artificial intelligence",
40
+ "AI",
41
+ "ML"
42
+ ],
43
+ "author": "Hugging Face",
44
+ "license": "Apache-2.0",
45
+ "bugs": {
46
+ "url": "https://github.com/xenova/transformers.js/issues"
47
+ },
48
+ "homepage": "https://github.com/xenova/transformers.js#readme",
49
+ "dependencies": {
50
+ "@huggingface/jinja": "^0.3.0",
51
+ "onnxruntime-node": "1.18.0",
52
+ "onnxruntime-web": "1.19.0-dev.20240804-ee2fe87e2d",
53
+ "sharp": "^0.33.2"
54
+ },
55
+ "devDependencies": {
56
+ "@types/jest": "^29.5.1",
57
+ "@webgpu/types": "^0.1.44",
58
+ "catharsis": "github:xenova/catharsis",
59
+ "jest": "^29.5.0",
60
+ "jest-environment-node": "^29.5.0",
61
+ "jsdoc-to-markdown": "^8.0.1",
62
+ "prettier": "3.3.3",
63
+ "typescript": "^5.2.2",
64
+ "wavefile": "^11.0.0",
65
+ "webpack": "^5.80.0",
66
+ "webpack-cli": "^5.0.2",
67
+ "webpack-dev-server": "^4.13.3"
68
+ },
69
+ "overrides": {
70
+ "semver": "^7.6.3",
71
+ "protobufjs": "^7.2.6"
72
+ },
73
+ "files": [
74
+ "src",
75
+ "dist",
76
+ "types",
77
+ "README.md",
78
+ "LICENSE"
79
+ ],
80
+ "browser": {
81
+ "fs": false,
82
+ "path": false,
83
+ "url": false,
84
+ "sharp": false,
85
+ "onnxruntime-node": false
86
+ },
87
+ "publishConfig": {
88
+ "access": "public"
89
+ },
90
+ "jsdelivr": "./dist/transformers.min.js",
91
+ "unpkg": "./dist/transformers.min.js"
92
+ }
/**
 * @file Handler file for choosing the correct version of ONNX Runtime, based on the environment.
 * Ideally, we could import the `onnxruntime-web` and `onnxruntime-node` packages only when needed,
 * but dynamic imports don't seem to work with the current webpack version and/or configuration.
 * This is possibly due to the experimental nature of top-level await statements.
 * So, we just import both packages, and use the appropriate one based on the environment:
 *   - When running in node, we use `onnxruntime-node`.
 *   - When running in the browser, we use `onnxruntime-web` (`onnxruntime-node` is not bundled).
 *
 * This module is not directly exported, but can be accessed through the environment variables:
 * ```javascript
 * import { env } from '@huggingface/transformers';
 * console.log(env.backends.onnx);
 * ```
 *
 * @module backends/onnx
 */

import { env, apis } from '../env.js';

// NOTE: Import order matters here. We need to import `onnxruntime-node` before `onnxruntime-web`.
// In either case, we select the default export if it exists, otherwise we use the named export.
import * as ONNX_NODE from 'onnxruntime-node';
import * as ONNX_WEB from 'onnxruntime-web/webgpu';

export { Tensor } from 'onnxruntime-common';

/** @type {import('../utils/devices.js').DeviceType[]} */
const supportedExecutionProviders = [];

/** @type {import('../utils/devices.js').DeviceType[]} */
let defaultExecutionProviders;

let ONNX;
if (apis.IS_NODE_ENV) {
    // Node.js: use the native runtime, which exposes only a CPU execution provider here.
    ONNX = ONNX_NODE.default ?? ONNX_NODE;
    supportedExecutionProviders.push('cpu');
    defaultExecutionProviders = ['cpu'];
} else {
    // Browser / web worker: use the web runtime. WebGPU is offered when available,
    // but WASM remains the default execution provider.
    ONNX = ONNX_WEB;
    if (apis.IS_WEBGPU_AVAILABLE) {
        supportedExecutionProviders.push('webgpu');
    }
    supportedExecutionProviders.push('wasm');
    defaultExecutionProviders = ['wasm'];
}

// @ts-ignore
const InferenceSession = ONNX.InferenceSession;
49
+
50
/**
 * Map a device to the execution providers to use for the given device.
 * @param {import("../utils/devices.js").DeviceType} [device=null] (Optional) The device to run the inference on.
 * @returns {import("../utils/devices.js").DeviceType[]} The execution providers to use for the given device.
 * @throws {Error} If the requested device is not supported in the current environment.
 */
export function deviceToExecutionProviders(device) {
    // TODO: Use mapping from device to execution providers for overloaded devices (e.g., 'gpu' or 'cpu').
    if (!device) {
        // No device specified: fall back to the environment's default providers.
        return defaultExecutionProviders;
    }
    // User has specified a device; validate it against the supported set.
    if (!supportedExecutionProviders.includes(device)) {
        throw new Error(`Unsupported device: "${device}". Should be one of: ${supportedExecutionProviders.join(', ')}.`);
    }
    return [device];
}
66
/**
 * To prevent multiple calls to `initWasm()`, we store the first call in a Promise
 * that is resolved when the first InferenceSession is created. Subsequent calls
 * will wait for this Promise to resolve before creating their own InferenceSession.
 * @type {Promise<any>|null}
 */
let wasmInitPromise = null;

/**
 * Create an ONNX inference session.
 * @param {Uint8Array} buffer The ONNX model buffer.
 * @param {Object} session_options ONNX inference session options.
 * @returns {Promise<import('onnxruntime-common').InferenceSession>} The ONNX inference session.
 */
export async function createInferenceSession(buffer, session_options) {
    // An earlier session kicked off WASM runtime initialization;
    // wait for it to finish before creating this new session.
    if (wasmInitPromise) {
        await wasmInitPromise;
    }

    const creation = InferenceSession.create(buffer, session_options);
    // The very first session "owns" the init promise that later calls await.
    wasmInitPromise ??= creation;
    return await creation;
}
93
/**
 * Check if an object is an ONNX tensor.
 * @param {any} value The object to check
 * @returns {boolean} Whether the object is an ONNX tensor.
 */
export function isONNXTensor(value) {
    // Checks against the Tensor class of whichever runtime was selected above.
    return value instanceof ONNX.Tensor;
}
102
// @ts-ignore
const ONNX_ENV = ONNX?.env;
if (ONNX_ENV?.wasm) {
    // Initialize wasm backend with suitable default settings.

    // (Optional) Set path to wasm files. This is needed when running in a web worker.
    // https://onnxruntime.ai/docs/api/js/interfaces/Env.WebAssemblyFlags.html#wasmPaths
    // We use remote wasm files by default to make it easier for newer users.
    // In practice, users should probably self-host the necessary .wasm files.
    // ONNX_ENV.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.19.0-dev.20240804-ee2fe87e2d/dist/';

    // TODO: Add support for loading WASM files from cached buffer when we upgrade to onnxruntime-web@1.19.0
    // https://github.com/microsoft/onnxruntime/pull/21534

    // Proxy the WASM backend to prevent the UI from freezing.
    // NOTE: This is only needed when running in a non-worker browser environment.
    ONNX_ENV.wasm.proxy = !apis.IS_WEBWORKER_ENV;

    // Multi-threading requires cross-origin isolation; fall back to a single thread otherwise.
    // https://developer.mozilla.org/en-US/docs/Web/API/crossOriginIsolated
    if (typeof crossOriginIsolated === 'undefined' || !crossOriginIsolated) {
        ONNX_ENV.wasm.numThreads = 1;
    }

    // Running in a browser-environment
    // TODO: Check if 1.17.1 fixes this issue.
    // SIMD for WebAssembly does not operate correctly in some recent versions of iOS (16.4.x).
    // As a temporary fix, we disable it for now.
    // For more information, see: https://github.com/microsoft/onnxruntime/issues/15644
    if (typeof navigator !== 'undefined' && /iP(hone|od|ad).+16_4.+AppleWebKit/.test(navigator.userAgent)) {
        ONNX_ENV.wasm.simd = false;
    }
}

if (ONNX_ENV?.webgpu) {
    ONNX_ENV.webgpu.powerPreference = 'high-performance';
}
140
/**
 * Check if ONNX's WASM backend is being proxied.
 * @returns {boolean} Whether ONNX's WASM backend is being proxied.
 */
export function isONNXProxy() {
    // TODO: Update this when allowing non-WASM backends.
    return ONNX_ENV?.wasm?.proxy;
}

// Expose ONNX environment variables to `env.backends.onnx`
env.backends.onnx = ONNX_ENV;
package/src/configs.js ADDED
@@ -0,0 +1,360 @@
1
+
2
+ /**
3
+ * @file Helper module for using model configs. For more information, see the corresponding
4
+ * [Python documentation](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoConfig).
5
+ *
6
+ * **Example:** Load an `AutoConfig`.
7
+ *
8
+ * ```javascript
9
+ * import { AutoConfig } from '@huggingface/transformers';
10
+ * const config = await AutoConfig.from_pretrained('bert-base-uncased');
11
+ * console.log(config);
12
+ * // PretrainedConfig {
13
+ * // "model_type": "bert",
14
+ * // "is_encoder_decoder": false,
15
+ * // "architectures": [
16
+ * // "BertForMaskedLM"
17
+ * // ],
18
+ * // "vocab_size": 30522
19
+ * // "num_attention_heads": 12,
20
+ * // "num_hidden_layers": 12,
21
+ * // "hidden_size": 768,
22
+ * // "max_position_embeddings": 512,
23
+ * // ...
24
+ * // }
25
+ * ```
26
+ *
27
+ * @module configs
28
+ */
29
+
30
+ import { pick } from './utils/core.js';
31
+ import {
32
+ getModelJSON,
33
+ } from './utils/hub.js';
34
+
35
+ /**
36
+ * @typedef {import('./utils/hub.js').PretrainedOptions} PretrainedOptions
37
+ */
38
+
39
+
40
/**
 * Loads a config from the specified path.
 * @param {string} pretrained_model_name_or_path The path to the config directory.
 * @param {PretrainedOptions} options Additional options for loading the config.
 * @returns {Promise<Object>} A promise that resolves with information about the loaded config.
 */
async function loadConfig(pretrained_model_name_or_path, options) {
    // `true` marks config.json as required: loading fails loudly if it is missing.
    const configJSON = await getModelJSON(pretrained_model_name_or_path, 'config.json', true, options);
    return configJSON;
}
49
+
50
/**
 * Normalizes a model configuration, mapping model-specific attribute names
 * (e.g., `n_head`, `num_attention_heads`, `n_layer`) onto a shared set of keys
 * (`num_heads`, `num_layers`, `hidden_size`, ...) used elsewhere (e.g., when
 * computing past key-value cache shapes).
 * @param {PretrainedConfig} config The configuration to normalize.
 * @returns {Object} The normalized configuration.
 */
function getNormalizedConfig(config) {
    // Maps normalized key -> model-specific key in `config`.
    const mapping = {};

    let init_normalized_config = {};
    switch (config.model_type) {
        // Sub-configs: recurse into the relevant nested config.
        case 'llava':
        case 'paligemma':
        case 'florence2':
            init_normalized_config = getNormalizedConfig(config.text_config);
            break;
        case 'moondream1':
            init_normalized_config = getNormalizedConfig(config.phi_config);
            break;
        case 'musicgen':
            init_normalized_config = getNormalizedConfig(config.decoder);
            break;

        // Decoder-only models
        case 'gpt2':
        case 'gptj':
        case 'codegen':
        case 'gpt_bigcode':
            mapping['num_heads'] = 'n_head';
            mapping['num_layers'] = 'n_layer';
            mapping['hidden_size'] = 'n_embd';
            break;
        case 'gpt_neox':
        case 'stablelm':
        case 'opt':
        case 'phi':
        case 'phi3':
        case 'falcon':
            mapping['num_heads'] = 'num_attention_heads';
            mapping['num_layers'] = 'num_hidden_layers';
            mapping['hidden_size'] = 'hidden_size';
            break;
        case 'llama':
        case 'cohere':
        case 'mistral':
        case 'starcoder2':
        case 'qwen2':
            mapping['num_heads'] = 'num_key_value_heads';
            mapping['num_layers'] = 'num_hidden_layers';
            mapping['hidden_size'] = 'hidden_size';
            mapping['num_attention_heads'] = 'num_attention_heads';
            break;
        case 'gemma':
        case 'gemma2':
            mapping['num_heads'] = 'num_key_value_heads';
            mapping['num_layers'] = 'num_hidden_layers';
            mapping['dim_kv'] = 'head_dim';
            break;
        case 'openelm':
            mapping['num_heads'] = 'num_kv_heads';
            mapping['num_layers'] = 'num_transformer_layers';
            mapping['dim_kv'] = 'head_dim';
            break;
        case 'gpt_neo':
        case 'donut-swin':
            mapping['num_heads'] = 'num_heads';
            mapping['num_layers'] = 'num_layers';
            mapping['hidden_size'] = 'hidden_size';
            break;
        case 'bloom':
            mapping['num_heads'] = 'n_head';
            mapping['num_layers'] = 'n_layer';
            mapping['hidden_size'] = 'hidden_size';
            break;
        case 'mpt':
            mapping['num_heads'] = 'n_heads';
            mapping['num_layers'] = 'n_layers';
            mapping['hidden_size'] = 'd_model';
            break;

        // Encoder-decoder models
        case 't5':
        case 'mt5':
        case 'longt5':
            mapping['num_decoder_layers'] = 'num_decoder_layers';
            mapping['num_decoder_heads'] = 'num_heads';
            mapping['decoder_dim_kv'] = 'd_kv';
            mapping['num_encoder_layers'] = 'num_layers';
            mapping['num_encoder_heads'] = 'num_heads';
            mapping['encoder_dim_kv'] = 'd_kv';
            break;
        case 'bart':
        case 'mbart':
        case 'marian':
        case 'whisper':
        case 'm2m_100':
        case 'blenderbot':
        case 'blenderbot-small':
        case 'florence2_language':
            mapping['num_decoder_layers'] = 'decoder_layers';
            mapping['num_decoder_heads'] = 'decoder_attention_heads';
            mapping['decoder_hidden_size'] = 'd_model';
            mapping['num_encoder_layers'] = 'encoder_layers';
            mapping['num_encoder_heads'] = 'encoder_attention_heads';
            mapping['encoder_hidden_size'] = 'd_model';
            break;
        case 'speecht5':
            mapping['num_decoder_layers'] = 'decoder_layers';
            mapping['num_decoder_heads'] = 'decoder_attention_heads';
            mapping['decoder_hidden_size'] = 'hidden_size';
            mapping['num_encoder_layers'] = 'encoder_layers';
            mapping['num_encoder_heads'] = 'encoder_attention_heads';
            mapping['encoder_hidden_size'] = 'hidden_size';
            break;
        case 'trocr':
            // Encoder and decoder share the decoder's attributes.
            mapping['num_encoder_layers'] = mapping['num_decoder_layers'] = 'decoder_layers';
            mapping['num_encoder_heads'] = mapping['num_decoder_heads'] = 'decoder_attention_heads';
            mapping['encoder_hidden_size'] = mapping['decoder_hidden_size'] = 'd_model';
            break;
        case 'musicgen_decoder':
            mapping['num_encoder_layers'] = mapping['num_decoder_layers'] = 'num_hidden_layers';
            mapping['num_encoder_heads'] = mapping['num_decoder_heads'] = 'num_attention_heads';
            mapping['encoder_hidden_size'] = mapping['decoder_hidden_size'] = 'hidden_size';
            break;

        case 'vision-encoder-decoder': {
            // FIX: braces added so these lexical declarations are scoped to this
            // case only (ESLint `no-case-declarations`); previously they leaked
            // into the surrounding switch block.
            const decoderConfig = getNormalizedConfig(config.decoder);

            const add_encoder_pkv = 'num_decoder_layers' in decoderConfig;
            const result = pick(config, ['model_type', 'is_encoder_decoder']);
            if (add_encoder_pkv) {
                // Decoder is part of an encoder-decoder model
                result.num_decoder_layers = decoderConfig.num_decoder_layers;
                result.num_decoder_heads = decoderConfig.num_decoder_heads;
                result.decoder_hidden_size = decoderConfig.decoder_hidden_size;

                result.num_encoder_layers = decoderConfig.num_encoder_layers;
                result.num_encoder_heads = decoderConfig.num_encoder_heads;
                result.encoder_hidden_size = decoderConfig.encoder_hidden_size;
            } else {
                // Decoder is a decoder-only model
                result.num_layers = decoderConfig.num_layers;
                result.num_heads = decoderConfig.num_heads;
                result.hidden_size = decoderConfig.hidden_size;
            }
            return result;
        }
    }

    // NOTE: If `num_attention_heads` is not set, it is assumed to be equal to `num_heads`
    const normalized_config = {
        ...init_normalized_config,
        ...pick(config, ['model_type', 'multi_query', 'is_encoder_decoder']),
    };
    for (const key in mapping) {
        normalized_config[key] = config[mapping[key]];
    }
    return normalized_config;
}
209
+
210
/**
 * Computes the shapes of the (initially empty) past key-value cache tensors
 * for a model, keyed by cache entry name (e.g., `past_key_values.0.key`).
 * @param {PretrainedConfig} config The model configuration (must carry `normalized_config`).
 * @param {Object} options Options object.
 * @param {string} [options.prefix='past_key_values'] Prefix used for the cache entry names.
 * @returns {Record<string, number[]>} Mapping from cache entry name to tensor dimensions.
 */
export function getKeyValueShapes(config, {
    prefix = 'past_key_values',
} = {}) {
    /** @type {Record<string, number[]>} */
    const shapes = {};
    const norm = config.normalized_config;

    // TODO support batches (i.e., batch_size > 1)
    const batch_size = 1;

    const hasEncDecHeads = 'num_encoder_heads' in norm && 'num_decoder_heads' in norm;
    if (norm.is_encoder_decoder && hasEncDecHeads) {
        // Encoder-decoder models: separate cross-attention (encoder) and
        // self-attention (decoder) cache entries per decoder layer.
        const encoder_dim_kv = norm.encoder_dim_kv ?? (norm.encoder_hidden_size / norm.num_encoder_heads);
        const decoder_dim_kv = norm.decoder_dim_kv ?? (norm.decoder_hidden_size / norm.num_decoder_heads);

        const encoder_dims = [batch_size, norm.num_encoder_heads, 0, encoder_dim_kv];
        const decoder_dims = [batch_size, norm.num_decoder_heads, 0, decoder_dim_kv];
        for (let i = 0; i < norm.num_decoder_layers; ++i) {
            shapes[`${prefix}.${i}.encoder.key`] = encoder_dims;
            shapes[`${prefix}.${i}.encoder.value`] = encoder_dims;
            shapes[`${prefix}.${i}.decoder.key`] = decoder_dims;
            shapes[`${prefix}.${i}.decoder.value`] = decoder_dims;
        }
    } else { // Decoders
        const num_heads = norm.num_heads;
        const num_layers = norm.num_layers;
        // NOTE: If `num_attention_heads` is not set, it is assumed to be equal to `num_heads`.
        const dim_kv = norm.dim_kv ?? (norm.hidden_size / (norm.num_attention_heads ?? num_heads));

        if (norm.model_type === 'falcon') {
            // NOTE: Custom implementation for Falcon
            const dims = [batch_size * num_heads, 0, dim_kv];
            for (let i = 0; i < num_layers; ++i) {
                shapes[`${prefix}.${i}.key`] = dims;
                shapes[`${prefix}.${i}.value`] = dims;
            }
        } else if (norm.multi_query) { // e.g., for `gpt_bigcode`
            // Multi-query attention: a single fused key/value tensor per layer.
            const dims = [batch_size * num_heads, 0, 2 * dim_kv];
            for (let i = 0; i < num_layers; ++i) {
                shapes[`${prefix}.${i}.key_value`] = dims;
            }
        } else if (norm.model_type === 'bloom') {
            // NOTE: Custom implementation for Bloom (key cache is transposed).
            const keyDims = [batch_size * num_heads, dim_kv, 0]; // [batch_size x num_heads,64,past_sequence_length]
            const valueDims = [batch_size * num_heads, 0, dim_kv]; // [batch_size x num_heads,past_sequence_length,64]
            for (let i = 0; i < num_layers; ++i) {
                shapes[`${prefix}.${i}.key`] = keyDims;
                shapes[`${prefix}.${i}.value`] = valueDims;
            }
        } else if (norm.model_type === 'openelm') {
            // OpenELM uses a per-layer head count (`num_heads` is an array here).
            for (let i = 0; i < num_layers; ++i) {
                const dims = [batch_size, num_heads[i], 0, dim_kv];
                shapes[`${prefix}.${i}.key`] = dims;
                shapes[`${prefix}.${i}.value`] = dims;
            }
        } else { // Decoder-only
            const dims = [batch_size, num_heads, 0, dim_kv];
            for (let i = 0; i < num_layers; ++i) {
                shapes[`${prefix}.${i}.key`] = dims;
                shapes[`${prefix}.${i}.value`] = dims;
            }
        }
    }

    return shapes;
}
291
/**
 * Base class for all configuration classes. For more information, see the corresponding
 * [Python documentation](https://huggingface.co/docs/transformers/main/en/main_classes/configuration#transformers.PretrainedConfig).
 */
export class PretrainedConfig {
    // Populated (when present in the JSON) by `Object.assign` in the constructor.
    max_position_embeddings;

    /**
     * Create a new PretrainedConfig instance.
     * @param {Object} configJSON The JSON of the config.
     */
    constructor(configJSON) {
        // Defaults, possibly overridden by `configJSON` below.
        this.model_type = null;
        this.is_encoder_decoder = false;

        Object.assign(this, configJSON);
        this.normalized_config = getNormalizedConfig(this);
    }

    /**
     * Loads a pre-trained config from the given `pretrained_model_name_or_path`.
     *
     * @param {string} pretrained_model_name_or_path The path to the pre-trained config.
     * @param {PretrainedOptions} options Additional options for loading the config.
     * @throws {Error} Throws an error if the config.json is not found in the `pretrained_model_name_or_path`.
     *
     * @returns {Promise<PretrainedConfig>} A new instance of the `PretrainedConfig` class.
     */
    static async from_pretrained(pretrained_model_name_or_path, {
        progress_callback = null,
        config = null,
        cache_dir = null,
        local_files_only = false,
        revision = 'main',
    } = {}) {
        // A plain config object was provided: promote it to a PretrainedConfig first.
        if (config && !(config instanceof PretrainedConfig)) {
            config = new PretrainedConfig(config);
        }

        if (config) {
            // Reuse the provided config instead of fetching one.
            return new this(config);
        }

        const data = await loadConfig(pretrained_model_name_or_path, {
            progress_callback,
            config,
            cache_dir,
            local_files_only,
            revision,
        });
        return new this(data);
    }
}
342
+
343
/**
 * Helper class which is used to instantiate pretrained configs with the `from_pretrained` function.
 *
 * @example
 * const config = await AutoConfig.from_pretrained('Xenova/bert-base-uncased');
 */
export class AutoConfig {
    /** @type {typeof PretrainedConfig.from_pretrained} */
    static async from_pretrained(...args) {
        // Delegate directly; AutoConfig exists for API parity with the Python library.
        return await PretrainedConfig.from_pretrained(...args);
    }
}
355
+
356
+ /**
357
+ * Transformers.js-specific configuration, possibly present in config.json under the key `transformers.js_config`.
358
+ * @typedef {Object} TransformersJSConfig
359
+ * @property {import('./transformers.js').DataType} [kv_cache_dtype]
360
+ */